fix Tensor.arange if (stop-start) and step have different signs (#5775)

This commit is contained in:
chenyu 2024-07-28 14:34:10 -04:00 committed by GitHub
parent d0fd84e617
commit 600a39771d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 0 deletions

View File

@ -453,6 +453,9 @@ class TestTypeSpec(unittest.TestCase):
_assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
_assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7))
_assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3))
# stop-start and step have different signs
_assert_eq(Tensor.arange(3, 5, -2), dtypes.default_int, np.arange(3, 5, -2))
_assert_eq(Tensor.arange(5.0, 3.0), dtypes.default_float, np.arange(5.0, 3.0))
@given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
def test_bool_ops(self, dtype, op):

View File

@ -505,6 +505,8 @@ class Tensor:
if stop is None: stop, start = start, 0
assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
# NOTE: this matches numpy, torch raises RuntimeError if stop-start and step have different signs
if (stop-start)/step <= 0: return Tensor([], dtype=dtype, **kwargs)
return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
@staticmethod