revert UOps eq, this needs to be isolated in realize.py (#6063)

This reverts commit dccca7f227.
This commit is contained in:
qazal 2024-08-13 23:02:34 +08:00 committed by GitHub
parent fa84e6ec48
commit 9145ad52ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 66 additions and 86 deletions

View File

@ -1,6 +1,8 @@
import sys
import sys, unittest
from typing import Optional, Set, Tuple
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.codegen.uops import UOp
from tinygrad.tensor import _to_np_dtype
from tinygrad.engine.realize import Runner
from tinygrad.dtype import DType
@ -51,3 +53,20 @@ def rand_for_dtype(dt:DType, size:int):
elif dt == dtypes.bool:
return np.random.choice([True, False], size=size)
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
class TestUOps(unittest.TestCase):
def assert_equiv_uops(self, uop1:UOp, uop2:UOp, cache:Optional[Set[Tuple[UOp, UOp]]]=None):
if cache is None: cache = set()
if (uop1, uop2) in cache: return
cache.add((uop1, uop2))
# NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops
try:
self.assertIs(uop1.op, uop2.op)
self.assertEqual(uop1.dtype, uop2.dtype)
self.assertEqual(uop1.arg, uop2.arg)
self.assertEqual(len(uop1.src), len(uop2.src))
for s1, s2 in zip(uop1.src, uop2.src): self.assert_equiv_uops(s1, s2)
except AssertionError as e:
print(f"{uop1=}")
print(f"{uop2=}")
raise e

View File

@ -1,10 +1,11 @@
import unittest, itertools
from test.helpers import TestUOps
from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
from tinygrad.codegen.uopgraph import constant_folder
class TestPatternMatcher(unittest.TestCase):
class TestPatternMatcher(TestUOps):
def test_simple_match(self):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)

View File

@ -1,4 +1,5 @@
import unittest
from test.helpers import TestUOps
from tinygrad import dtypes, Variable
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
@ -89,7 +90,7 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(sink.src[1].op, UOps.CONST)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)
class TestUOpGraph(unittest.TestCase):
class TestUOpGraph(TestUOps):
def test_add_constant_fold(self):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
@ -160,7 +161,7 @@ class TestUOpGraph(unittest.TestCase):
# possible
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), i) for i in range(4))
self.assertEqual(_test_vec(xyzw), val)
self.assert_equiv_uops(_test_vec(xyzw), val)
# unaligned
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
@ -189,7 +190,7 @@ class TestUOpGraph(unittest.TestCase):
geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)]
g = UOpGraph(geps)
for uop, const in zip(g.uops, consts):
self.assertEqual(uop, const)
self.assert_equiv_uops(uop, const)
def test_wmma_vectorize_fold(self):
for i in [2, 4, 8]:
@ -198,7 +199,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
g = UOpGraph([wmma])
self.assertEqual(g.uops[0], acc)
self.assert_equiv_uops(g.uops[0], acc)
self.assertEqual(len(g.uops), 1)
for i in [2, 4, 8]:
@ -207,7 +208,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
g = UOpGraph([wmma])
self.assertEqual(g.uops[0], acc)
self.assert_equiv_uops(g.uops[0], acc)
self.assertEqual(len(g.uops), 1)
def test_wmma_vectorize_no_fold(self):
@ -219,7 +220,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
g = UOpGraph([wmma])
self.assertEqual(g.uops[-1], wmma)
self.assert_equiv_uops(g.uops[-1], wmma)
for i in [4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
@ -229,7 +230,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
g = UOpGraph([wmma])
self.assertEqual(g.uops[-1], wmma)
self.assert_equiv_uops(g.uops[-1], wmma)
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
@ -238,7 +239,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
g = UOpGraph([wmma])
self.assertEqual(g.uops[-1], wmma)
self.assert_equiv_uops(g.uops[-1], wmma)
for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
@ -247,7 +248,7 @@ class TestUOpGraph(unittest.TestCase):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
g = UOpGraph([wmma])
self.assertEqual(g.uops[-1], wmma)
self.assert_equiv_uops(g.uops[-1], wmma)
def test_cast_alu_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0)
@ -293,9 +294,9 @@ class TestUOpGraph(unittest.TestCase):
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
self.assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
def test_fold_gated_load_local(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@ -308,9 +309,9 @@ class TestUOpGraph(unittest.TestCase):
uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
self.assertEqual(ld1, UOp.const(dtypes.int, 2))
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
self.assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
def test_fold_gated_store(self):
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@ -322,7 +323,7 @@ class TestUOpGraph(unittest.TestCase):
uops = UOpGraph([st0, st1])
# only the second store happens
self.assertEqual(len(uops.uops), 4)
self.assertEqual(uops[-1], UOp.store(glbl, idx1, val))
self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
def test_asserts_bad_gate(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@ -559,7 +560,7 @@ class TestLoadStoreFolder(unittest.TestCase):
def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer)
class TestIFUOps(unittest.TestCase):
class TestIFUOps(TestUOps):
def test_create_ifs(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
@ -575,7 +576,7 @@ class TestIFUOps(unittest.TestCase):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
self.assert_equiv_uops(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)
@ -593,7 +594,7 @@ class TestIFUOps(unittest.TestCase):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
self.assert_equiv_uops(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)
@ -609,19 +610,19 @@ class TestIFUOps(unittest.TestCase):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
self.assert_equiv_uops(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)
class TestDivMod(unittest.TestCase):
class TestDivMod(TestUOps):
def c(self, c:int): return UOp.const(dtypes.int, c)
def x(self, expr:str, nmin:int, nmax:int): return UOp(UOps.DEFINE_VAR, dtypes.int, (self.c(nmin), self.c(nmax)), Variable(expr, nmin, nmax))
# NOTE: does not simplify to the end
def test_const_mod(self):
self.assertEqual(mod_folding(self.c(6), 3), self.c(1)*self.c(0))
self.assertEqual(mod_folding(self.c(7), 3), self.c(1)*self.c(1))
self.assertEqual(mod_folding(self.c(8), 3), self.c(1)*self.c(2))
self.assert_equiv_uops(mod_folding(self.c(6), 3), self.c(1)*self.c(0))
self.assert_equiv_uops(mod_folding(self.c(7), 3), self.c(1)*self.c(1))
self.assert_equiv_uops(mod_folding(self.c(8), 3), self.c(1)*self.c(2))
def test_var_mod(self):
self.assertIsNone(mod_folding(self.x("x", 0, 6), 3))
@ -629,32 +630,32 @@ class TestDivMod(unittest.TestCase):
@unittest.skip("does not simplify to the end")
def test_add_mod(self):
self.assertEqual(mod_folding(self.x("x", 0, 6)+40, 5), self.x("x", 0, 6))
self.assertEqual(mod_folding(self.x("x", 0, 6)-40, 5), self.x("x", 0, 6))
self.assertEqual(mod_folding(self.x("x", 0, 6)+42, 5), (self.x("x", 0, 6)+2))
self.assertEqual(mod_folding(self.x("x", 0, 6)-42, 5), (self.x("x", 0, 6)+3))
self.assertEqual(mod_folding(40+self.x("x", 0, 6), 5), self.x("x", 0, 6))
self.assertEqual(mod_folding(-40+self.x("x", 0, 6), 5), self.x("x", 0, 6))
self.assertEqual(mod_folding(42+self.x("x", 0, 6), 5), (2+self.x("x", 0, 6)))
self.assertEqual(mod_folding(-42+self.x("x", 0, 6), 5), (3+self.x("x", 0, 6)))
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)+40, 5), self.x("x", 0, 6))
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)-40, 5), self.x("x", 0, 6))
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)+42, 5), (self.x("x", 0, 6)+2))
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)-42, 5), (self.x("x", 0, 6)+3))
self.assert_equiv_uops(mod_folding(40+self.x("x", 0, 6), 5), self.x("x", 0, 6))
self.assert_equiv_uops(mod_folding(-40+self.x("x", 0, 6), 5), self.x("x", 0, 6))
self.assert_equiv_uops(mod_folding(42+self.x("x", 0, 6), 5), (2+self.x("x", 0, 6)))
self.assert_equiv_uops(mod_folding(-42+self.x("x", 0, 6), 5), (3+self.x("x", 0, 6)))
@unittest.skip("does not simplify to the end")
def test_mul_mod(self):
self.assertEqual(mod_folding(self.x("x", 0, 6)*40, 5), self.c(0))
self.assertEqual(mod_folding(self.x("x", 0, 6)*-40, 5), self.c(0))
self.assertEqual(mod_folding(self.x("x", 0, 6)*42, 5), (self.x("x", 0, 6)*2))
self.assertEqual(mod_folding(self.x("x", 0, 6)*-42, 5), (self.x("x", 0, 6)*3))
self.assertEqual(mod_folding(40*self.x("x", 0, 6), 5), self.c(0))
self.assertEqual(mod_folding(-40*self.x("x", 0, 6), 5), self.c(0))
self.assertEqual(mod_folding(42*self.x("x", 0, 6), 5), (2*self.x("x", 0, 6)))
self.assertEqual(mod_folding(-42*self.x("x", 0, 6), 5), (3*self.x("x", 0, 6)))
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*40, 5), self.c(0))
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*-40, 5), self.c(0))
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*42, 5), (self.x("x", 0, 6)*2))
self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*-42, 5), (self.x("x", 0, 6)*3))
self.assert_equiv_uops(mod_folding(40*self.x("x", 0, 6), 5), self.c(0))
self.assert_equiv_uops(mod_folding(-40*self.x("x", 0, 6), 5), self.c(0))
self.assert_equiv_uops(mod_folding(42*self.x("x", 0, 6), 5), (2*self.x("x", 0, 6)))
self.assert_equiv_uops(mod_folding(-42*self.x("x", 0, 6), 5), (3*self.x("x", 0, 6)))
@unittest.skip("does not simplify to the end now")
def test_mul_add_mod(self):
x = self.x("x", 0, 10)
y = self.x("y", 0, 10)
z = self.x("z", 0, 10)
self.assertEqual(mod_folding(x*40+y*12+z, 5), (y*2+z))
self.assert_equiv_uops(mod_folding(x*40+y*12+z, 5), (y*2+z))
if __name__ == '__main__':

View File

@ -1,5 +1,5 @@
from typing import Optional, Tuple, Any, List
import unittest, math, time
import unittest, math
import numpy as np
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context
@ -11,7 +11,7 @@ from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.uops import UOps, NOp, UOp
from tinygrad.codegen.uopgraph import UOpGraph
from test.helpers import is_dtype_supported
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
def _uops_to_prg(uops_list, print_uops=False):
uops = UOpGraph(uops_list)
@ -357,39 +357,7 @@ class TestUOpCompare(unittest.TestCase):
mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
assert (add < mul) or (mul < add), "add and mul with same src should have an order"
def test_uop_eq_fields(self):
a = UOp(UOps.CONST, dtypes.float, (), 2.0)
b = UOp(UOps.CONST, dtypes.float, (), 2.0)
self.assertEqual(a, b)
def test_uop_ne_fields(self):
a = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), UOp.const(dtypes.pyint, 1)), (1, False))
b = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), UOp.const(dtypes.pyint, 2)), (1, False))
self.assertNotEqual(a, b)
def test_recursive_eq_src(self):
st = time.perf_counter()
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
idx = UOp.const(dtypes.int, 0)
a = UOp(UOps.LOAD, dtypes.float, (buf, idx))
for _ in range(24): a += a
b = UOp(UOps.LOAD, dtypes.float, (buf, idx))
for _ in range(24): b += b
self.assertEqual(a, b)
self.assertLess(time.perf_counter()-st, 1e-2)
# NOTE: NOp uses the dataclass compare, this is fine
def test_nop_ne(self):
a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="a")
b = NOp(UOps.CONST, dtypes.float, (), 2.0, name="b")
self.assertNotEqual(a, b)
def test_nop_eq(self):
a1 = NOp(UOps.CONST, dtypes.float, (), 2.0, name="a")
a2 = NOp(UOps.CONST, dtypes.float, (), 2.0, name="a")
self.assertEqual(a1, a2)
class TestUOpStr(unittest.TestCase):
class TestUOpStr(TestEqUOps):
def test_uop_str(self):
a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0)
for _ in range(20): a = a + a
@ -401,7 +369,7 @@ class TestUOpStr(unittest.TestCase):
# nice big complicated uop
with Context(NOOPT=1):
sink = get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops.sink
self.assertEqual(sink, eval(str(sink)))
self.assert_equiv_uops(sink, eval(str(sink)))
def test_nop_str(self):
a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1")

View File

@ -43,15 +43,6 @@ class UOp:
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
self.arg.value, self.dtype, self.src)
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
def cached_eq(self, x:UOp, context:Dict[Tuple[int, int], bool]) -> bool:
if id(self) == id(x): return True
if self.op != x.op or self.dtype != x.dtype or self.arg != x.arg or len(self.src) != len(x.src): return False
if (key := (id(self), id(x))) in context: return context[key]
return context.setdefault(key, all(a.cached_eq(b, context) for a,b in zip(self.src, x.src)))
def __eq__(self, x): return self.cached_eq(x, context={})
@functools.cached_property
def hash(self): return hash((self.op, self.dtype, self.src, self.arg))
def __hash__(self): return self.hash
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.arg}, src=(%s))")
# *** uop syntactic sugar
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x