mirror of https://github.com/commaai/tinygrad.git
use old cumsum optimization for arange (#3813)
revert to old cumsum opt while phi simplification is disabled. added a flops complexity test for this
This commit is contained in:
parent
ac866eaf5a
commit
a6ed2ae3c6
|
@ -0,0 +1,14 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, GlobalCounters
|
||||
|
||||
class TestArange(unittest.TestCase):
|
||||
def _get_flops(self, N):
|
||||
GlobalCounters.reset()
|
||||
Tensor.arange(N).realize()
|
||||
return GlobalCounters.global_ops
|
||||
|
||||
def test_complexity(self):
|
||||
f1 = self._get_flops(256)
|
||||
f2 = self._get_flops(2560)
|
||||
print(f"{f1=}, {f2=}")
|
||||
assert f2 / f1 < 15, f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
|
|
@ -268,7 +268,7 @@ class Tensor:
|
|||
if stop is None: stop, start = start, 0
|
||||
assert all(isinstance(s, (int, float)) for s in (start, stop, step)), "symbolic arange not supported"
|
||||
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
|
||||
return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
|
||||
return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)).cast(dtype)
|
||||
|
||||
@staticmethod
|
||||
def eye(dim:int, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue