mirror of https://github.com/commaai/tinygrad.git
use dtypes.int32 as return dtype for functions that return indices (#3827)
behavior matches jax. It's fine to have a tensor greater than max int8 size even if we set default int to int8
This commit is contained in:
parent
fa1921ec7d
commit
99cbc24390
|
@ -405,13 +405,12 @@ class TestTypeSpec(unittest.TestCase):
|
|||
def test_bool_ops(self, dtype, op):
|
||||
assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool
|
||||
|
||||
@given(strat.sampled_from(core_dtypes),
|
||||
strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_functions_return_index(self, dtype, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.default_int
|
||||
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.default_int
|
||||
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.default_int
|
||||
assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.int32
|
||||
assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.int32
|
||||
assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.int32
|
||||
|
||||
class TestTypePromotion(unittest.TestCase):
|
||||
@given(strat.sampled_from(core_dtypes))
|
||||
|
|
|
@ -325,7 +325,7 @@ class Tensor:
|
|||
cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
|
||||
unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
|
||||
indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
|
||||
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.default_int)
|
||||
return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32)
|
||||
|
||||
# ***** toposort and backward pass *****
|
||||
|
||||
|
@ -635,11 +635,11 @@ class Tensor:
|
|||
def argmax(self, axis=None, keepdim=False):
|
||||
if axis is None:
|
||||
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape)
|
||||
return (prod(self.shape) - idx.max() - 1).cast(dtypes.default_int)
|
||||
return (prod(self.shape) - idx.max() - 1).cast(dtypes.int32)
|
||||
axis = axis + len(self.shape) if axis < 0 else axis
|
||||
m = self == self.max(axis=axis, keepdim=True)
|
||||
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
|
||||
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.default_int)
|
||||
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
|
||||
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue