Fix broken test (#4104)

This commit is contained in:
uuuvn 2024-04-07 19:02:12 +03:00 committed by GitHub
parent 9a95d87366
commit 2b81d9b334
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 1 deletions

View File

@ -40,18 +40,21 @@ class TestMultiTensor(unittest.TestCase):
assert lb.shape == (128,)
(X + X).realize()
@unittest.skipIf(Device.DEFAULT == "METAL", "metal multi-device is fake")
def test_sharded_memory(self):
# Buffer may be stuck in track_cross_buffer
for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize()
mem_base = GlobalCounters.mem_used
X = Tensor.ones(256).contiguous().realize()
assert GlobalCounters.mem_used-mem_base== X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
X.shard_((d0, d1, d2, d3)).realize()
for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize()
assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256 * 4, GlobalCounters.mem_used-mem_base
X = Tensor.ones(256).contiguous().realize()
assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
X.shard_((d0, d1, d2, d3), axis=0).realize()
for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize()
assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
X = Tensor.ones(256).realize()