From 8a8aed23d23bf1d7877e207098d50e9946a59692 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 25 Dec 2023 11:33:17 -0500 Subject: [PATCH] test dtypes of return values of cumsum, argmax/min, multinomial (#2933) * test dtypes of return values of cumsum, argmax/min, multinomial cumsum behaves like sum, and functions that return an index return in dtypes.default_int * because webgpu is different --- test/test_dtype.py | 28 ++++++++++++++++++++++++++-- tinygrad/tensor.py | 6 +++--- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index e470b4a2..e1d4e316 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -7,6 +7,9 @@ from tinygrad.tensor import Tensor, dtypes from typing import Any, List from hypothesis import given, settings, strategies as st +core_dtypes = list(DTYPES_DICT.values()) +floats = [dt for dt in core_dtypes if dtypes.is_float(dt)] + def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): # for GPU, cl_khr_fp16 isn't supported # for LLVM, it segfaults because it can't link to the casting function @@ -305,8 +308,14 @@ class TestTypeSpec(unittest.TestCase): assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float -core_dtypes = list(DTYPES_DICT.values()) -floats = [dt for dt in core_dtypes if dtypes.is_float(dt)] + @given(st.sampled_from(core_dtypes), + st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64])) + def test_functions_return_index(self, dtype, default_int, default_float): + dtypes.default_int, dtypes.default_float = default_int, default_float + assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.default_int + assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.default_int + assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.default_int + class TestTypePromotion(unittest.TestCase): @given(st.sampled_from(core_dtypes)) def test_self_promo_to_self(self, dtype): @@ -415,6 +424,21 @@ class TestAutoCastType(unittest.TestCase): assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32 assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64 + def test_cumsum(self): + assert (Tensor([0, 1], dtype=dtypes.bool)).cumsum(0).dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int8)).cumsum(0).dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int16)).cumsum(0).dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int32)).cumsum(0).dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int64)).cumsum(0).dtype == dtypes.int64 + assert (Tensor([0, 1], dtype=dtypes.uint8)).cumsum(0).dtype == dtypes.uint32 + assert (Tensor([0, 1], dtype=dtypes.uint16)).cumsum(0).dtype == dtypes.uint32 + assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32 + assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64 + assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16 + assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32 + assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64 + @given(st.sampled_from(core_dtypes), st.sampled_from(core_dtypes)) def test_matmul(self, dt1, dt2): assert (Tensor([0, 1], dtype=dt1) @ Tensor([0, 1], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b974c438..41ac2909 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -231,7 +231,7 @@ class Tensor: cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1) unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1) indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0)) - return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32) + return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.default_int) # ***** toposort and backward pass ***** @@ -517,11 +517,11 @@ class Tensor: def argmax(self, axis=None, keepdim=False): if axis is None: idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape) - return prod(self.shape) - idx.max() - 1 + return (prod(self.shape) - idx.max() - 1).cast(dtypes.default_int) axis = axis + len(self.shape) if axis < 0 else axis m = self == self.max(axis=axis, keepdim=True) idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) - return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1 + return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.default_int) def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim) @staticmethod