don't view origin buffer when sharding (#5122)

* make buffer view optional with a flag

* do not view when sharding to save memory
This commit is contained in:
David Hou 2024-06-25 20:19:09 -07:00 committed by GitHub
parent 89e106686a
commit 666a9c1448
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 1 deletion

View File

@@ -542,6 +542,12 @@ class TestMultiTensor(unittest.TestCase):
assert ast.src[0].src[0].op is BufferOps.LOAD
assert ast.src[0].src[1].op is BufferOps.CONST and ast.src[0].src[1].arg.val == 3
def test_shard_memory(self):
  """Sharding a contiguous 16x16 tensor along axis 0 over 4 devices should leave each
  shard a base buffer of only 4*16 elements (no view kept of the full origin buffer)."""
  devs = (d0, d1, d2, d3)
  x = Tensor.zeros(16, 16).contiguous()
  x.shard_(devs, axis=0)
  for shard in x.lazydata.lbs:
    # each lazybuffer must be its own base (not a view) and sized to one shard
    assert shard is shard.base
    assert shard.buffer.base.size == 4 * 16
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestHandleData(unittest.TestCase):
def test_copied_to_device(self):

View File

@@ -73,7 +73,7 @@ class MultiLazyBuffer:
def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
  # Build a MultiLazyBuffer by replicating `lb` to each device, optionally splitting along `axis`.
  # Materialize views first (unless the buffer is an unrealized unmasked const, which is cheap to keep lazy).
  lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
  # Split along `axis` if given, then copy each piece to its target device.
  sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
  # NOTE(review): the two returns below are the before/after lines of this commit's diff, not live code.
  # The change passes allow_buffer_view=False so the per-device contiguous copy does not keep a view
  # of the origin buffer — presumably freeing the full-size source allocation; confirm against the test above.
  return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous() for lb in sharded_lbs], axis)
  return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)
def copy_to_device(self, device:str) -> LazyBuffer:
if self.axis is None: