test dtypes of return values of cumsum, argmax/min, multinomial (#2933)

* test dtypes of return values of cumsum, argmax/min, multinomial

cumsum follows the same dtype rules as sum, and functions that return an index (argmax/argmin/multinomial) return dtypes.default_int (see the sketch below)

* special-case WEBGPU, since it behaves differently
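For reference, a rough sketch of the behavior these tests pin down (not part of the diff; it just restates the assertions added below):

```python
from tinygrad.tensor import Tensor, dtypes

# cumsum accumulates like sum: small ints widen (e.g. int8 -> int32, uint8 -> uint32),
# while floats keep their dtype
assert Tensor([0, 1], dtype=dtypes.int8).cumsum(0).dtype == dtypes.int32
assert Tensor([0, 1], dtype=dtypes.float16).cumsum(0).dtype == dtypes.float16

# functions that return indices follow dtypes.default_int, whatever the input dtype
assert Tensor([0, 1], dtype=dtypes.float32).argmax().dtype == dtypes.default_int
assert Tensor([0, 1], dtype=dtypes.float32).argmin().dtype == dtypes.default_int
assert Tensor([0, 1], dtype=dtypes.float32).multinomial().dtype == dtypes.default_int
```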
chenyu 2023-12-25 11:33:17 -05:00 committed by GitHub
parent 12996d3a7d
commit 8a8aed23d2
2 changed files with 29 additions and 5 deletions

test/test_dtype.py

@@ -7,6 +7,9 @@ from tinygrad.tensor import Tensor, dtypes
 from typing import Any, List
 from hypothesis import given, settings, strategies as st
+core_dtypes = list(DTYPES_DICT.values())
+floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
   # for GPU, cl_khr_fp16 isn't supported
   # for LLVM, it segfaults because it can't link to the casting function
@@ -305,8 +308,14 @@ class TestTypeSpec(unittest.TestCase):
     assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
     assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float
-core_dtypes = list(DTYPES_DICT.values())
-floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
+  @given(st.sampled_from(core_dtypes),
+         st.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), st.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
+  def test_functions_return_index(self, dtype, default_int, default_float):
+    dtypes.default_int, dtypes.default_float = default_int, default_float
+    assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.default_int
+    assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.default_int
+    assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.default_int
 class TestTypePromotion(unittest.TestCase):
   @given(st.sampled_from(core_dtypes))
   def test_self_promo_to_self(self, dtype):
@@ -415,6 +424,21 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
+  def test_cumsum(self):
+    assert (Tensor([0, 1], dtype=dtypes.bool)).cumsum(0).dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int8)).cumsum(0).dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int16)).cumsum(0).dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int32)).cumsum(0).dtype == dtypes.int32
+    assert (Tensor([0, 1], dtype=dtypes.int64)).cumsum(0).dtype == dtypes.int64
+    assert (Tensor([0, 1], dtype=dtypes.uint8)).cumsum(0).dtype == dtypes.uint32
+    assert (Tensor([0, 1], dtype=dtypes.uint16)).cumsum(0).dtype == dtypes.uint32
+    assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
+    assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
+    assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
+    assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
+    assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64
   @given(st.sampled_from(core_dtypes), st.sampled_from(core_dtypes))
   def test_matmul(self, dt1, dt2):
     assert (Tensor([0, 1], dtype=dt1) @ Tensor([0, 1], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2)

tinygrad/tensor.py

@@ -231,7 +231,7 @@ class Tensor:
     cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
     unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
     indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
-    return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
+    return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.default_int)
   # ***** toposort and backward pass *****
@@ -517,11 +517,11 @@
   def argmax(self, axis=None, keepdim=False):
     if axis is None:
       idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
-      return prod(self.shape) - idx.max() - 1
+      return (prod(self.shape) - idx.max() - 1).cast(dtypes.default_int)
     axis = axis + len(self.shape) if axis < 0 else axis
     m = self == self.max(axis=axis, keepdim=True)
     idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
-    return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
+    return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.default_int)
   def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
   @staticmethod
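As background on the two casts added in argmax above: argmax is built from a mask and a reversed arange rather than a dedicated index reduction, so the result comes out of ordinary arithmetic and needs an explicit cast to dtypes.default_int. A minimal NumPy sketch of that reversed-arange trick (illustrative only, not tinygrad code):

```python
import numpy as np

def argmax_flat(x: np.ndarray) -> int:
  # mark every position that holds the maximum value
  mask = (x == x.max()).astype(np.int64)
  n = x.size
  # score marked positions with a reversed arange so earlier positions score higher;
  # max() then selects the first occurrence of the maximum
  scores = mask * np.arange(n - 1, -1, -1).reshape(x.shape)
  # undo the reversal to recover the forward flat index
  return n - int(scores.max()) - 1

x = np.array([3, 7, 7, 1])
assert argmax_flat(x) == int(np.argmax(x))  # both return 1, the first max
```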