use dtypes.int32 as return dtype for functions that return indices (#3827)

behavior matches jax. It's fine to have a tensor greater than max int8 size even if we set default int to int8
This commit is contained in:
chenyu 2024-03-19 17:06:57 -04:00 committed by GitHub
parent fa1921ec7d
commit 99cbc24390
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 8 deletions

View File

@ -405,13 +405,12 @@ class TestTypeSpec(unittest.TestCase):
def test_bool_ops(self, dtype, op):
assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool
@given(strat.sampled_from(core_dtypes),
strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_functions_return_index(self, dtype, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float
assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.default_int
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.default_int
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.default_int
assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.int32
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32
class TestTypePromotion(unittest.TestCase):
@given(strat.sampled_from(core_dtypes))

View File

@ -325,7 +325,7 @@ class Tensor:
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.default_int)
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
# ***** toposort and backward pass *****
@ -635,11 +635,11 @@ class Tensor:
def argmax(self, axis=None, keepdim=False):
if axis is None:
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
return (prod(self.shape) - idx.max() - 1).cast(dtypes.default_int)
return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
axis = axis + len(self.shape) if axis < 0 else axis
m = self == self.max(axis=axis, keepdim=True)
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.default_int)
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
@staticmethod