mirror of https://github.com/commaai/tinygrad.git
fix Tensor.arange if (stop-start) and step have different signs (#5775)
This commit is contained in:
parent
d0fd84e617
commit
600a39771d
|
@ -453,6 +453,9 @@ class TestTypeSpec(unittest.TestCase):
|
|||
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
|
||||
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7))
|
||||
_assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3))
|
||||
# stop-start and step have different signs
|
||||
_assert_eq(Tensor.arange(3, 5, -2), dtypes.default_int, np.arange(3, 5, -2))
|
||||
_assert_eq(Tensor.arange(5.0, 3.0), dtypes.default_float, np.arange(5.0, 3.0))
|
||||
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
|
||||
def test_bool_ops(self, dtype, op):
|
||||
|
|
|
@ -505,6 +505,8 @@ class Tensor:
|
|||
if stop is None: stop, start = start, 0
|
||||
assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
|
||||
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
|
||||
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
|
||||
if (stop-start)/step <= 0: return Tensor([], dtype=dtype, **kwargs)
|
||||
return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue