added fix and reg tests (#6060)

This commit is contained in:
Tobias Fischer 2024-08-12 21:00:48 -04:00 committed by GitHub
parent 45bd667a78
commit 6e3eb50fd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 1 deletions

View File

@ -883,5 +883,25 @@ class TestBatchNorm(unittest.TestCase):
assert synced_si assert synced_si
assert unsynced_si assert unsynced_si
def helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3):
for shp in shps:
single_in = Tensor.randn(shp)
multi_in = single_in.shard(devices_2, axis=0)
single_out = fxn(single_in).numpy()
multi_out = fxn(multi_in).numpy()
try:
assert single_out.shape == multi_out.shape, f"shape mismatch: single={single_out.shape} | multi={multi_out.shape}"
assert single_out.dtype == multi_out.dtype, f"dtype mismatch: single={single_out.dtype} | multi={multi_out.dtype}"
np.testing.assert_allclose(single_out, multi_out, atol=atol, rtol=rtol)
except Exception as e:
raise Exception(f"Failed shape {single_out.shape}: {e}")
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
class TestTensorOps(unittest.TestCase):
def test_interpolate(self):
helper_test_shard_op([(4,16,16),(4,24,24)], lambda x: Tensor.interpolate(x, (19,19)))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1965,7 +1965,7 @@ class Tensor:
x, expand = self, list(self.shape) x, expand = self, list(self.shape)
for i in range(-len(size), 0): for i in range(-len(size), 0):
scale = (self.shape[i] - int(align_corners)) / (size[i] - int(align_corners)) scale = (self.shape[i] - int(align_corners)) / (size[i] - int(align_corners))
arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32), [1] * self.ndim arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim
index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1) index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
reshape[i] = expand[i] = size[i] reshape[i] = expand[i] = size[i]
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())] low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]