From 6e3eb50fd172d7599e9a0aaea88b772460bd3a15 Mon Sep 17 00:00:00 2001
From: Tobias Fischer
Date: Mon, 12 Aug 2024 21:00:48 -0400
Subject: [PATCH] added fix and reg tests (#6060)

---
# Reviewer notes (this region sits between the "---" separator and the
# diffstat, so it is not part of the commit message or of the diff).
#
# tinygrad/tensor.py
#   In the per-axis loop shown in the hunk, the Tensor.arange(...) call that
#   builds `arr` now also passes device=self.device. Nothing else in the hunk
#   changes.
#   NOTE(review): presumably this keeps the index tensor on the same device(s)
#   as `self` when `self` is sharded -- confirm against issue/PR #6060.
#
# test/test_multitensor.py
#   helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3):
#     for every shape in `shps` it draws Tensor.randn(shp), shards that tensor
#     over `devices_2` along axis 0, applies `fxn` to both the unsharded and
#     the sharded tensor, converts both results with .numpy(), and checks that
#     shape and dtype are equal and that the values agree under
#     np.testing.assert_allclose(atol=atol, rtol=rtol).
#     Any failure in those three checks is caught and re-raised as a plain
#     Exception whose message starts with "Failed shape" followed by
#     single_out.shape, i.e. the shape of the *output*, not the input `shp`.
#     NOTE(review): the re-raise has no explicit `from e`, and it turns an
#     AssertionError into a generic Exception -- consider whether unittest
#     should report this as a failure rather than an error.
#   TestTensorOps.test_interpolate:
#     runs the helper on shapes (4,16,16) and (4,24,24) with
#     Tensor.interpolate(x, (19,19)). The whole class is skipped when `CI` is
#     truthy and Device.DEFAULT is one of "GPU", "CUDA", "METAL".
#
# NOTE(review): line breaks and leading indentation below were reconstructed
# from a whitespace-collapsed copy of this patch (assumes 2-space indentation
# and 4-space method bodies). Verify the context lines against the target tree
# before running `git am` / `git apply`.
#
 test/test_multitensor.py | 20 ++++++++++++++++++++
 tinygrad/tensor.py       |  2 +-
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 6ec3038a..76b55136 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -883,5 +883,25 @@ class TestBatchNorm(unittest.TestCase):
     assert synced_si
     assert unsynced_si
 
+def helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3):
+  for shp in shps:
+    single_in = Tensor.randn(shp)
+    multi_in = single_in.shard(devices_2, axis=0)
+
+    single_out = fxn(single_in).numpy()
+    multi_out = fxn(multi_in).numpy()
+
+    try:
+      assert single_out.shape == multi_out.shape, f"shape mismatch: single={single_out.shape} | multi={multi_out.shape}"
+      assert single_out.dtype == multi_out.dtype, f"dtype mismatch: single={single_out.dtype} | multi={multi_out.dtype}"
+      np.testing.assert_allclose(single_out, multi_out, atol=atol, rtol=rtol)
+    except Exception as e:
+      raise Exception(f"Failed shape {single_out.shape}: {e}")
+
+@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
+class TestTensorOps(unittest.TestCase):
+  def test_interpolate(self):
+    helper_test_shard_op([(4,16,16),(4,24,24)], lambda x: Tensor.interpolate(x, (19,19)))
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 77dae28a..cffff734 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1965,7 +1965,7 @@ class Tensor:
     x, expand = self, list(self.shape)
     for i in range(-len(size), 0):
       scale = (self.shape[i] - int(align_corners)) / (size[i] - int(align_corners))
-      arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32), [1] * self.ndim
+      arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim
       index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
       reshape[i] = expand[i] = size[i]
       low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]