mirror of https://github.com/commaai/tinygrad.git
fast compare for lazyop (#2893)
This commit is contained in:
parent
1500aca43d
commit
f6c7833f9f
|
@ -27,12 +27,26 @@ class TestFusionOp(unittest.TestCase):
|
|||
def test_recursive_add(self):
  """Schedule and lower a 24-step chain of ``a = a + a`` within a time budget.

  Every step feeds the previous result in as both operands of the add, so
  the same sub-expression is referenced twice at each level. The test
  requires that building the chain, scheduling it and lowering the last
  schedule item together take less than one second, and that the lowered
  runner is either an InterpretedASTRunner or has a ``prg`` shorter than
  5000 (presumably the generated program text -- confirm against the runner).
  """
  st = time.perf_counter()
  a = Tensor([1,2,3,4])
  # 24 self-additions; only the post-change depth and the 1.0s bound are kept.
  for _ in range(24): a = a + a
  sched = create_schedule([a.lazydata], None)
  ji = lower_schedule_item(sched[-1])
  self.assertLess(time.perf_counter()-st, 1.0)
  assert isinstance(ji, InterpretedASTRunner) or len(ji.prg) < 5000
|
||||
|
||||
def test_recursive_add_cmp(self):
  """Compare the final ASTs of three independently built ``t = t + t`` chains.

  Two chains use 24 self-additions and one uses 23. The ASTs of the two
  24-step chains must compare equal, the 23-step one must compare unequal
  to them, and building, scheduling and comparing must all finish in under
  one second.
  """
  st = time.perf_counter()
  tensors, schedules = [], []
  for depth in (24, 24, 23):
    t = Tensor([1,2,3,4])
    for _ in range(depth): t = t + t
    # Hold on to both the tensor and its schedule until the test ends.
    tensors.append(t)
    schedules.append(create_schedule([t.lazydata], None))
  first, second, shallower = (s[-1].ast for s in schedules)
  # Bare asserts on purpose: they evaluate ==/!= without formatting the ASTs.
  assert first == second
  assert first != shallower
  self.assertLess(time.perf_counter()-st, 1.0)
|
||||
|
||||
# Run the tests with per-test output when this file is executed as a script.
if __name__ == "__main__":
  unittest.main(verbosity=2)
|
||||
|
|
|
@ -45,11 +45,18 @@ class ScheduleItem:
|
|||
inputs: Tuple[LazyBuffer, ...]
|
||||
var_vals: Dict[Variable, int]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class LazyOp:
|
||||
op: Op
|
||||
src: Tuple[LazyOp, ...] = ()
|
||||
arg: Any = None
|
||||
def cached_compare(self, x, context):
|
||||
if id(self) == id(x): return True
|
||||
if self.op != x.op or self.arg != x.arg or len(self.src) != len(x.src): return False
|
||||
if (self,x) in context: return context[(self,x)]
|
||||
ret = context[self,x] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
|
||||
return ret
|
||||
def __eq__(self, x):
  """Return whether ``x`` is structurally equal to this node.

  Delegates to ``cached_compare`` with a fresh, empty memo dict for each
  top-level comparison. When ``x`` is not of the same class, returns
  ``NotImplemented`` instead of reading ``x.op``, so comparing against an
  unrelated object (e.g. ``None``) evaluates to unequal rather than raising
  AttributeError.
  """
  if x.__class__ is not self.__class__: return NotImplemented
  return self.cached_compare(x, context={})
|
||||
def __repr__(self):
  """Return the ``LazyOp(op=..., src=..., arg=...)`` debug string for this node."""
  op, src, arg = self.op, self.src, self.arg
  return f"LazyOp(op={op}, src={src}, arg={arg})"
|
||||
@functools.cached_property
def hash(self):
  """Hash of the ``(op, src, arg)`` triple; cached on the instance after first use."""
  # Inside a method body the name `hash` still resolves to the builtin,
  # not to this property, because class-level names are not in method scope.
  fields = (self.op, self.src, self.arg)
  return hash(fields)
|
||||
|
|
Loading…
Reference in New Issue