fast compare for lazyop (#2893)

This commit is contained in:
George Hotz 2023-12-20 23:32:27 -08:00 committed by GitHub
parent 1500aca43d
commit f6c7833f9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 4 deletions

View File

@ -27,12 +27,26 @@ class TestFusionOp(unittest.TestCase):
def test_recursive_add(self): def test_recursive_add(self):
st = time.perf_counter() st = time.perf_counter()
a = Tensor([1,2,3,4]) a = Tensor([1,2,3,4])
for _ in range(20): a = a + a for _ in range(24): a = a + a
sched = create_schedule([a.lazydata], None) sched = create_schedule([a.lazydata], None)
ji = lower_schedule_item(sched[-1]) ji = lower_schedule_item(sched[-1])
et = time.perf_counter() self.assertLess(time.perf_counter()-st, 1.0)
self.assertLess(et-st, 10.0)
assert isinstance(ji, InterpretedASTRunner) or len(ji.prg) < 5000 assert isinstance(ji, InterpretedASTRunner) or len(ji.prg) < 5000
def test_recursive_add_cmp(self):
st = time.perf_counter()
a = Tensor([1,2,3,4])
for _ in range(24): a = a + a
sched1 = create_schedule([a.lazydata], None)
b = Tensor([1,2,3,4])
for _ in range(24): b = b + b
sched2 = create_schedule([b.lazydata], None)
c = Tensor([1,2,3,4])
for _ in range(23): c = c + c
sched3 = create_schedule([c.lazydata], None)
assert sched1[-1].ast == sched2[-1].ast
assert sched1[-1].ast != sched3[-1].ast
self.assertLess(time.perf_counter()-st, 1.0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(verbosity=2) unittest.main(verbosity=2)

View File

@ -45,11 +45,18 @@ class ScheduleItem:
inputs: Tuple[LazyBuffer, ...] inputs: Tuple[LazyBuffer, ...]
var_vals: Dict[Variable, int] var_vals: Dict[Variable, int]
@dataclass(frozen=True) @dataclass(frozen=True, eq=False)
class LazyOp: class LazyOp:
op: Op op: Op
src: Tuple[LazyOp, ...] = () src: Tuple[LazyOp, ...] = ()
arg: Any = None arg: Any = None
def cached_compare(self, x, context):
if id(self) == id(x): return True
if self.op != x.op or self.arg != x.arg or len(self.src) != len(x.src): return False
if (self,x) in context: return context[(self,x)]
ret = context[self,x] = all(a.cached_compare(b, context) for a,b in zip(self.src, x.src))
return ret
def __eq__(self, x): return self.cached_compare(x, context={})
def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})" def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
@functools.cached_property @functools.cached_property
def hash(self): return hash((self.op, self.src, self.arg)) def hash(self): return hash((self.op, self.src, self.arg))