mirror of https://github.com/commaai/tinygrad.git
Fix broken test (#4104)
This commit is contained in:
parent
9a95d87366
commit
2b81d9b334
|
@ -40,18 +40,21 @@ class TestMultiTensor(unittest.TestCase):
|
|||
assert lb.shape == (128,)
|
||||
(X + X).realize()
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "METAL", "metal multi-device is fake")
|
||||
def test_sharded_memory(self):
|
||||
# Buffer may be stuck in track_cross_buffer
|
||||
for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize()
|
||||
mem_base = GlobalCounters.mem_used
|
||||
|
||||
X = Tensor.ones(256).contiguous().realize()
|
||||
assert GlobalCounters.mem_used-mem_base== X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
|
||||
X.shard_((d0, d1, d2, d3)).realize()
|
||||
for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize()
|
||||
assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256 * 4, GlobalCounters.mem_used-mem_base
|
||||
|
||||
X = Tensor.ones(256).contiguous().realize()
|
||||
assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
|
||||
X.shard_((d0, d1, d2, d3), axis=0).realize()
|
||||
for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize()
|
||||
assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base
|
||||
|
||||
X = Tensor.ones(256).realize()
|
||||
|
|
Loading…
Reference in New Issue