From 2b81d9b3348cdd04acfe70d7fc1291af24311215 Mon Sep 17 00:00:00 2001 From: uuuvn <83587632+uuuvn@users.noreply.github.com> Date: Sun, 7 Apr 2024 19:02:12 +0300 Subject: [PATCH] Fix broken test (#4104) --- test/test_multitensor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index fda003e0..974ec92b 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -40,18 +40,21 @@ class TestMultiTensor(unittest.TestCase): assert lb.shape == (128,) (X + X).realize() - @unittest.skipIf(Device.DEFAULT == "METAL", "metal multi-device is fake") def test_sharded_memory(self): + # Buffer may be stuck in track_cross_buffer + for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize() mem_base = GlobalCounters.mem_used X = Tensor.ones(256).contiguous().realize() assert GlobalCounters.mem_used-mem_base== X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base X.shard_((d0, d1, d2, d3)).realize() + for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize() assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256 * 4, GlobalCounters.mem_used-mem_base X = Tensor.ones(256).contiguous().realize() assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base X.shard_((d0, d1, d2, d3), axis=0).realize() + for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize() assert GlobalCounters.mem_used-mem_base == X.dtype.itemsize * 256, GlobalCounters.mem_used-mem_base X = Tensor.ones(256).realize()