mirror of https://github.com/commaai/tinygrad.git
revert UOps eq, this needs to be isolated in realize.py (#6063)
This reverts commit dccca7f227
.
This commit is contained in:
parent
fa84e6ec48
commit
9145ad52ff
|
@ -1,6 +1,8 @@
|
|||
import sys
|
||||
import sys, unittest
|
||||
from typing import Optional, Set, Tuple
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.codegen.uops import UOp
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.engine.realize import Runner
|
||||
from tinygrad.dtype import DType
|
||||
|
@ -51,3 +53,20 @@ def rand_for_dtype(dt:DType, size:int):
|
|||
elif dt == dtypes.bool:
|
||||
return np.random.choice([True, False], size=size)
|
||||
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
|
||||
|
||||
class TestUOps(unittest.TestCase):
|
||||
def assert_equiv_uops(self, uop1:UOp, uop2:UOp, cache:Optional[Set[Tuple[UOp, UOp]]]=None):
|
||||
if cache is None: cache = set()
|
||||
if (uop1, uop2) in cache: return
|
||||
cache.add((uop1, uop2))
|
||||
# NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops
|
||||
try:
|
||||
self.assertIs(uop1.op, uop2.op)
|
||||
self.assertEqual(uop1.dtype, uop2.dtype)
|
||||
self.assertEqual(uop1.arg, uop2.arg)
|
||||
self.assertEqual(len(uop1.src), len(uop2.src))
|
||||
for s1, s2 in zip(uop1.src, uop2.src): self.assert_equiv_uops(s1, s2)
|
||||
except AssertionError as e:
|
||||
print(f"{uop1=}")
|
||||
print(f"{uop2=}")
|
||||
raise e
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import unittest, itertools
|
||||
from test.helpers import TestUOps
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
|
||||
from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
|
||||
from tinygrad.codegen.uopgraph import constant_folder
|
||||
|
||||
class TestPatternMatcher(unittest.TestCase):
|
||||
class TestPatternMatcher(TestUOps):
|
||||
def test_simple_match(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import unittest
|
||||
from test.helpers import TestUOps
|
||||
from tinygrad import dtypes, Variable
|
||||
from tinygrad.dtype import PtrDType
|
||||
from tinygrad.helpers import DEBUG
|
||||
|
@ -89,7 +90,7 @@ class TestGraphRewrite(unittest.TestCase):
|
|||
self.assertEqual(sink.src[1].op, UOps.CONST)
|
||||
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)
|
||||
|
||||
class TestUOpGraph(unittest.TestCase):
|
||||
class TestUOpGraph(TestUOps):
|
||||
def test_add_constant_fold(self):
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
|
@ -160,7 +161,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
# possible
|
||||
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
|
||||
xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), i) for i in range(4))
|
||||
self.assertEqual(_test_vec(xyzw), val)
|
||||
self.assert_equiv_uops(_test_vec(xyzw), val)
|
||||
|
||||
# unaligned
|
||||
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
|
||||
|
@ -189,7 +190,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)]
|
||||
g = UOpGraph(geps)
|
||||
for uop, const in zip(g.uops, consts):
|
||||
self.assertEqual(uop, const)
|
||||
self.assert_equiv_uops(uop, const)
|
||||
|
||||
def test_wmma_vectorize_fold(self):
|
||||
for i in [2, 4, 8]:
|
||||
|
@ -198,7 +199,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assertEqual(g.uops[0], acc)
|
||||
self.assert_equiv_uops(g.uops[0], acc)
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
|
@ -207,7 +208,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assertEqual(g.uops[0], acc)
|
||||
self.assert_equiv_uops(g.uops[0], acc)
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
|
||||
def test_wmma_vectorize_no_fold(self):
|
||||
|
@ -219,7 +220,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assertEqual(g.uops[-1], wmma)
|
||||
self.assert_equiv_uops(g.uops[-1], wmma)
|
||||
|
||||
for i in [4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||
|
@ -229,7 +230,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assertEqual(g.uops[-1], wmma)
|
||||
self.assert_equiv_uops(g.uops[-1], wmma)
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
|
@ -238,7 +239,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assertEqual(g.uops[-1], wmma)
|
||||
self.assert_equiv_uops(g.uops[-1], wmma)
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||
|
@ -247,7 +248,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
g = UOpGraph([wmma])
|
||||
self.assertEqual(g.uops[-1], wmma)
|
||||
self.assert_equiv_uops(g.uops[-1], wmma)
|
||||
|
||||
def test_cast_alu_fold(self):
|
||||
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0)
|
||||
|
@ -293,9 +294,9 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[2].src
|
||||
# ld0 becomes the invalid value
|
||||
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
|
||||
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
|
||||
self.assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
||||
|
@ -308,9 +309,9 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
|
||||
ld0, ld1 = uops[-1].src[2].src
|
||||
# ld0 becomes the invalid value
|
||||
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
|
||||
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
|
||||
self.assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
|
||||
|
||||
def test_fold_gated_store(self):
|
||||
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
||||
|
@ -322,7 +323,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
uops = UOpGraph([st0, st1])
|
||||
# only the second store happens
|
||||
self.assertEqual(len(uops.uops), 4)
|
||||
self.assertEqual(uops[-1], UOp.store(glbl, idx1, val))
|
||||
self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
|
||||
|
||||
def test_asserts_bad_gate(self):
|
||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
||||
|
@ -559,7 +560,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
|||
|
||||
def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer)
|
||||
|
||||
class TestIFUOps(unittest.TestCase):
|
||||
class TestIFUOps(TestUOps):
|
||||
def test_create_ifs(self):
|
||||
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
|
||||
|
@ -575,7 +576,7 @@ class TestIFUOps(unittest.TestCase):
|
|||
sink = gate_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
self.assert_equiv_uops(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
|
||||
|
@ -593,7 +594,7 @@ class TestIFUOps(unittest.TestCase):
|
|||
sink = gate_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
self.assert_equiv_uops(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
|
||||
|
@ -609,19 +610,19 @@ class TestIFUOps(unittest.TestCase):
|
|||
sink = gate_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
self.assert_equiv_uops(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 3)
|
||||
|
||||
class TestDivMod(unittest.TestCase):
|
||||
class TestDivMod(TestUOps):
|
||||
def c(self, c:int): return UOp.const(dtypes.int, c)
|
||||
def x(self, expr:str, nmin:int, nmax:int): return UOp(UOps.DEFINE_VAR, dtypes.int, (self.c(nmin), self.c(nmax)), Variable(expr, nmin, nmax))
|
||||
|
||||
# NOTE: does not simplify to the end
|
||||
def test_const_mod(self):
|
||||
self.assertEqual(mod_folding(self.c(6), 3), self.c(1)*self.c(0))
|
||||
self.assertEqual(mod_folding(self.c(7), 3), self.c(1)*self.c(1))
|
||||
self.assertEqual(mod_folding(self.c(8), 3), self.c(1)*self.c(2))
|
||||
self.assert_equiv_uops(mod_folding(self.c(6), 3), self.c(1)*self.c(0))
|
||||
self.assert_equiv_uops(mod_folding(self.c(7), 3), self.c(1)*self.c(1))
|
||||
self.assert_equiv_uops(mod_folding(self.c(8), 3), self.c(1)*self.c(2))
|
||||
|
||||
def test_var_mod(self):
|
||||
self.assertIsNone(mod_folding(self.x("x", 0, 6), 3))
|
||||
|
@ -629,32 +630,32 @@ class TestDivMod(unittest.TestCase):
|
|||
|
||||
@unittest.skip("does not simplify to the end")
|
||||
def test_add_mod(self):
|
||||
self.assertEqual(mod_folding(self.x("x", 0, 6)+40, 5), self.x("x", 0, 6))
|
||||
self.assertEqual(mod_folding(self.x("x", 0, 6)-40, 5), self.x("x", 0, 6))
|
||||
self.assertEqual(mod_folding(self.x("x", 0, 6)+42, 5), (self.x("x", 0, 6)+2))
|
||||
self.assertEqual(mod_folding(self.x("x", 0, 6)-42, 5), (self.x("x", 0, 6)+3))
|
||||
self.assertEqual(mod_folding(40+self.x("x", 0, 6), 5), self.x("x", 0, 6))
|
||||
self.assertEqual(mod_folding(-40+self.x("x", 0, 6), 5), self.x("x", 0, 6))
|
||||
self.assertEqual(mod_folding(42+self.x("x", 0, 6), 5), (2+self.x("x", 0, 6)))
|
||||
self.assertEqual(mod_folding(-42+self.x("x", 0, 6), 5), (3+self.x("x", 0, 6)))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)+40, 5), self.x("x", 0, 6))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)-40, 5), self.x("x", 0, 6))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)+42, 5), (self.x("x", 0, 6)+2))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)-42, 5), (self.x("x", 0, 6)+3))
|
||||
self.assert_equiv_uops(mod_folding(40+self.x("x", 0, 6), 5), self.x("x", 0, 6))
|
||||
self.assert_equiv_uops(mod_folding(-40+self.x("x", 0, 6), 5), self.x("x", 0, 6))
|
||||
self.assert_equiv_uops(mod_folding(42+self.x("x", 0, 6), 5), (2+self.x("x", 0, 6)))
|
||||
self.assert_equiv_uops(mod_folding(-42+self.x("x", 0, 6), 5), (3+self.x("x", 0, 6)))
|
||||
|
||||
@unittest.skip("does not simplify to the end")
|
||||
def test_mul_mod(self):
|
||||
self.assertEqual(mod_folding(self.x("x", 0, 6)*40, 5), self.c(0))
|
||||
self.assertEqual(mod_folding(self.x("x", 0, 6)*-40, 5), self.c(0))
|
||||
self.assertEqual(mod_folding(self.x("x", 0, 6)*42, 5), (self.x("x", 0, 6)*2))
|
||||
self.assertEqual(mod_folding(self.x("x", 0, 6)*-42, 5), (self.x("x", 0, 6)*3))
|
||||
self.assertEqual(mod_folding(40*self.x("x", 0, 6), 5), self.c(0))
|
||||
self.assertEqual(mod_folding(-40*self.x("x", 0, 6), 5), self.c(0))
|
||||
self.assertEqual(mod_folding(42*self.x("x", 0, 6), 5), (2*self.x("x", 0, 6)))
|
||||
self.assertEqual(mod_folding(-42*self.x("x", 0, 6), 5), (3*self.x("x", 0, 6)))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*40, 5), self.c(0))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*-40, 5), self.c(0))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*42, 5), (self.x("x", 0, 6)*2))
|
||||
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*-42, 5), (self.x("x", 0, 6)*3))
|
||||
self.assert_equiv_uops(mod_folding(40*self.x("x", 0, 6), 5), self.c(0))
|
||||
self.assert_equiv_uops(mod_folding(-40*self.x("x", 0, 6), 5), self.c(0))
|
||||
self.assert_equiv_uops(mod_folding(42*self.x("x", 0, 6), 5), (2*self.x("x", 0, 6)))
|
||||
self.assert_equiv_uops(mod_folding(-42*self.x("x", 0, 6), 5), (3*self.x("x", 0, 6)))
|
||||
|
||||
@unittest.skip("does not simplify to the end now")
|
||||
def test_mul_add_mod(self):
|
||||
x = self.x("x", 0, 10)
|
||||
y = self.x("y", 0, 10)
|
||||
z = self.x("z", 0, 10)
|
||||
self.assertEqual(mod_folding(x*40+y*12+z, 5), (y*2+z))
|
||||
self.assert_equiv_uops(mod_folding(x*40+y*12+z, 5), (y*2+z))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Optional, Tuple, Any, List
|
||||
import unittest, math, time
|
||||
import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Context
|
||||
|
@ -11,7 +11,7 @@ from tinygrad.engine.schedule import create_schedule
|
|||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||
from tinygrad.codegen.uops import UOps, NOp, UOp
|
||||
from tinygrad.codegen.uopgraph import UOpGraph
|
||||
from test.helpers import is_dtype_supported
|
||||
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
|
||||
|
||||
def _uops_to_prg(uops_list, print_uops=False):
|
||||
uops = UOpGraph(uops_list)
|
||||
|
@ -357,39 +357,7 @@ class TestUOpCompare(unittest.TestCase):
|
|||
mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
|
||||
assert (add < mul) or (mul < add), "add and mul with same src should have an order"
|
||||
|
||||
def test_uop_eq_fields(self):
|
||||
a = UOp(UOps.CONST, dtypes.float, (), 2.0)
|
||||
b = UOp(UOps.CONST, dtypes.float, (), 2.0)
|
||||
self.assertEqual(a, b)
|
||||
|
||||
def test_uop_ne_fields(self):
|
||||
a = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), UOp.const(dtypes.pyint, 1)), (1, False))
|
||||
b = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), UOp.const(dtypes.pyint, 2)), (1, False))
|
||||
self.assertNotEqual(a, b)
|
||||
|
||||
def test_recursive_eq_src(self):
|
||||
st = time.perf_counter()
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
a = UOp(UOps.LOAD, dtypes.float, (buf, idx))
|
||||
for _ in range(24): a += a
|
||||
b = UOp(UOps.LOAD, dtypes.float, (buf, idx))
|
||||
for _ in range(24): b += b
|
||||
self.assertEqual(a, b)
|
||||
self.assertLess(time.perf_counter()-st, 1e-2)
|
||||
|
||||
# NOTE: NOp uses the dataclass compare, this is fine
|
||||
def test_nop_ne(self):
|
||||
a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="a")
|
||||
b = NOp(UOps.CONST, dtypes.float, (), 2.0, name="b")
|
||||
self.assertNotEqual(a, b)
|
||||
|
||||
def test_nop_eq(self):
|
||||
a1 = NOp(UOps.CONST, dtypes.float, (), 2.0, name="a")
|
||||
a2 = NOp(UOps.CONST, dtypes.float, (), 2.0, name="a")
|
||||
self.assertEqual(a1, a2)
|
||||
|
||||
class TestUOpStr(unittest.TestCase):
|
||||
class TestUOpStr(TestEqUOps):
|
||||
def test_uop_str(self):
|
||||
a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0)
|
||||
for _ in range(20): a = a + a
|
||||
|
@ -401,7 +369,7 @@ class TestUOpStr(unittest.TestCase):
|
|||
# nice big complicated uop
|
||||
with Context(NOOPT=1):
|
||||
sink = get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops.sink
|
||||
self.assertEqual(sink, eval(str(sink)))
|
||||
self.assert_equiv_uops(sink, eval(str(sink)))
|
||||
|
||||
def test_nop_str(self):
|
||||
a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1")
|
||||
|
|
|
@ -43,15 +43,6 @@ class UOp:
|
|||
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
|
||||
self.arg.value, self.dtype, self.src)
|
||||
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
|
||||
def cached_eq(self, x:UOp, context:Dict[Tuple[int, int], bool]) -> bool:
|
||||
if id(self) == id(x): return True
|
||||
if self.op != x.op or self.dtype != x.dtype or self.arg != x.arg or len(self.src) != len(x.src): return False
|
||||
if (key := (id(self), id(x))) in context: return context[key]
|
||||
return context.setdefault(key, all(a.cached_eq(b, context) for a,b in zip(self.src, x.src)))
|
||||
def __eq__(self, x): return self.cached_eq(x, context={})
|
||||
@functools.cached_property
|
||||
def hash(self): return hash((self.op, self.dtype, self.src, self.arg))
|
||||
def __hash__(self): return self.hash
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.arg}, src=(%s))")
|
||||
# *** uop syntactic sugar
|
||||
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
|
||||
|
|
Loading…
Reference in New Issue