mirror of https://github.com/commaai/tinygrad.git
added fix and reg tests (#6060)
This commit is contained in:
parent
45bd667a78
commit
6e3eb50fd1
|
@ -883,5 +883,25 @@ class TestBatchNorm(unittest.TestCase):
|
|||
assert synced_si
|
||||
assert unsynced_si
|
||||
|
||||
def helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3):
|
||||
for shp in shps:
|
||||
single_in = Tensor.randn(shp)
|
||||
multi_in = single_in.shard(devices_2, axis=0)
|
||||
|
||||
single_out = fxn(single_in).numpy()
|
||||
multi_out = fxn(multi_in).numpy()
|
||||
|
||||
try:
|
||||
assert single_out.shape == multi_out.shape, f"shape mismatch: single={single_out.shape} | multi={multi_out.shape}"
|
||||
assert single_out.dtype == multi_out.dtype, f"dtype mismatch: single={single_out.dtype} | multi={multi_out.dtype}"
|
||||
np.testing.assert_allclose(single_out, multi_out, atol=atol, rtol=rtol)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed shape {single_out.shape}: {e}")
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
|
||||
class TestTensorOps(unittest.TestCase):
|
||||
def test_interpolate(self):
|
||||
helper_test_shard_op([(4,16,16),(4,24,24)], lambda x: Tensor.interpolate(x, (19,19)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -1965,7 +1965,7 @@ class Tensor:
|
|||
x, expand = self, list(self.shape)
|
||||
for i in range(-len(size), 0):
|
||||
scale = (self.shape[i] - int(align_corners)) / (size[i] - int(align_corners))
|
||||
arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32), [1] * self.ndim
|
||||
arr, reshape = Tensor.arange(size[i], dtype=dtypes.float32, device=self.device), [1] * self.ndim
|
||||
index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
|
||||
reshape[i] = expand[i] = size[i]
|
||||
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
|
||||
|
|
Loading…
Reference in New Issue