use old cumsum optimization for arange (#3813)

revert to old cumsum opt while phi simplification is disabled.

added a flops complexity test for this
This commit is contained in:
chenyu 2024-03-18 20:01:03 -04:00 committed by GitHub
parent ac866eaf5a
commit a6ed2ae3c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 1 deletions

14
test/test_arange.py Normal file
View File

@ -0,0 +1,14 @@
import unittest
from tinygrad import Tensor, GlobalCounters
class TestArange(unittest.TestCase):
def _get_flops(self, N):
GlobalCounters.reset()
Tensor.arange(N).realize()
return GlobalCounters.global_ops
def test_complexity(self):
f1 = self._get_flops(256)
f2 = self._get_flops(2560)
print(f"{f1=}, {f2=}")
assert f2 / f1 < 15, f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"

View File

@ -268,7 +268,7 @@ class Tensor:
if stop is None: stop, start = start, 0
assert all(isinstance(s, (int, float)) for s in (start, stop, step)), "symbolic arange not supported"
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)).cast(dtype)
@staticmethod
def eye(dim:int, **kwargs):