diff --git a/test/helpers.py b/test/helpers.py index c77affee..f4e49316 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,6 +1,8 @@ -import sys +import sys, unittest +from typing import Optional, Set, Tuple import numpy as np from tinygrad import Tensor, Device, dtypes +from tinygrad.codegen.uops import UOp from tinygrad.tensor import _to_np_dtype from tinygrad.engine.realize import Runner from tinygrad.dtype import DType @@ -51,3 +53,20 @@ def rand_for_dtype(dt:DType, size:int): elif dt == dtypes.bool: return np.random.choice([True, False], size=size) return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt)) + +class TestUOps(unittest.TestCase): + def assert_equiv_uops(self, uop1:UOp, uop2:UOp, cache:Optional[Set[Tuple[UOp, UOp]]]=None): + if cache is None: cache = set() + if (uop1, uop2) in cache: return + cache.add((uop1, uop2)) + # NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops + try: + self.assertIs(uop1.op, uop2.op) + self.assertEqual(uop1.dtype, uop2.dtype) + self.assertEqual(uop1.arg, uop2.arg) + self.assertEqual(len(uop1.src), len(uop2.src)) + for s1, s2 in zip(uop1.src, uop2.src): self.assert_equiv_uops(s1, s2) + except AssertionError as e: + print(f"{uop1=}") + print(f"{uop2=}") + raise e diff --git a/test/test_pattern_matcher.py b/test/test_pattern_matcher.py index 78552081..216cf38c 100644 --- a/test/test_pattern_matcher.py +++ b/test/test_pattern_matcher.py @@ -1,10 +1,11 @@ import unittest, itertools +from test.helpers import TestUOps from tinygrad.dtype import dtypes from tinygrad.ops import BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401 from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat from tinygrad.codegen.uopgraph import constant_folder -class TestPatternMatcher(unittest.TestCase): +class TestPatternMatcher(TestUOps): def test_simple_match(self): matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index d41b9dc3..bbd3553f 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -1,4 +1,5 @@ import unittest +from test.helpers import TestUOps from tinygrad import dtypes, Variable from tinygrad.dtype import PtrDType from tinygrad.helpers import DEBUG @@ -89,7 +90,7 @@ class TestGraphRewrite(unittest.TestCase): self.assertEqual(sink.src[1].op, UOps.CONST) self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3) -class TestUOpGraph(unittest.TestCase): +class TestUOpGraph(TestUOps): def test_add_constant_fold(self): c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) @@ -160,7 +161,7 @@ class TestUOpGraph(unittest.TestCase): # possible val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx)) xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), i) for i in range(4)) - self.assertEqual(_test_vec(xyzw), val) + self.assert_equiv_uops(_test_vec(xyzw), val) # unaligned val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx)) @@ -189,7 +190,7 @@ class TestUOpGraph(unittest.TestCase): geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)] g = UOpGraph(geps) for uop, const in zip(g.uops, consts): - self.assertEqual(uop, const) + self.assert_equiv_uops(uop, const) def test_wmma_vectorize_fold(self): for i in [2, 4, 8]: @@ -198,7 +199,7 @@ class TestUOpGraph(unittest.TestCase): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) g = UOpGraph([wmma]) - self.assertEqual(g.uops[0], acc) + self.assert_equiv_uops(g.uops[0], acc) self.assertEqual(len(g.uops), 1) for i in [2, 4, 8]: @@ -207,7 +208,7 @@ class TestUOpGraph(unittest.TestCase): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) g = UOpGraph([wmma]) - self.assertEqual(g.uops[0], acc) + self.assert_equiv_uops(g.uops[0], acc) self.assertEqual(len(g.uops), 1) def test_wmma_vectorize_no_fold(self): @@ -219,7 +220,7 @@ class TestUOpGraph(unittest.TestCase): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) g = UOpGraph([wmma]) - self.assertEqual(g.uops[-1], wmma) + self.assert_equiv_uops(g.uops[-1], wmma) for i in [4, 8]: var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) @@ -229,7 +230,7 @@ class TestUOpGraph(unittest.TestCase): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) g = UOpGraph([wmma]) - self.assertEqual(g.uops[-1], wmma) + self.assert_equiv_uops(g.uops[-1], wmma) for i in [2, 4, 8]: vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), @@ -238,7 +239,7 @@ class TestUOpGraph(unittest.TestCase): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) g = UOpGraph([wmma]) - self.assertEqual(g.uops[-1], wmma) + self.assert_equiv_uops(g.uops[-1], wmma) for i in [2, 4, 8]: var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) @@ -247,7 +248,7 @@ class TestUOpGraph(unittest.TestCase): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) g = UOpGraph([wmma]) - self.assertEqual(g.uops[-1], wmma) + self.assert_equiv_uops(g.uops[-1], wmma) def test_cast_alu_fold(self): d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0) @@ -293,9 +294,9 @@ class TestUOpGraph(unittest.TestCase): uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))]) ld0, ld1 = uops[-1].src[2].src # ld0 becomes the invalid value - self.assertEqual(ld1, UOp.const(dtypes.int, 2)) + self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2)) # the gate and invalid value are deleted from ld1 - self.assertEqual(ld0, UOp.load(glbl2, idx, dtype=dtypes.int)) + self.assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int)) def test_fold_gated_load_local(self): glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) @@ -308,9 +309,9 @@ class TestUOpGraph(unittest.TestCase): uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))]) ld0, ld1 = uops[-1].src[2].src # ld0 becomes the invalid value - self.assertEqual(ld1, UOp.const(dtypes.int, 2)) + self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2)) # the gate and invalid value are deleted from ld1 - self.assertEqual(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int)) + self.assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int)) def test_fold_gated_store(self): glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) @@ -322,7 +323,7 @@ class TestUOpGraph(unittest.TestCase): uops = UOpGraph([st0, st1]) # only the second store happens self.assertEqual(len(uops.uops), 4) - self.assertEqual(uops[-1], UOp.store(glbl, idx1, val)) + self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val)) def test_asserts_bad_gate(self): glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) @@ -559,7 +560,7 @@ class TestLoadStoreFolder(unittest.TestCase): def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer) -class TestIFUOps(unittest.TestCase): +class TestIFUOps(TestUOps): def test_create_ifs(self): gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4)) @@ -575,7 +576,7 @@ class TestIFUOps(unittest.TestCase): sink = gate_rewrite(sink) if_uops = [u for u in sink.parents if u.op is UOps.IF] self.assertEqual(len(if_uops), 1) - self.assertEqual(if_uops[0].src[0], gate) + self.assert_equiv_uops(if_uops[0].src[0], gate) for st in sink.src: self.assertEqual(len(st.src), 3) @@ -593,7 +594,7 @@ class TestIFUOps(unittest.TestCase): sink = gate_rewrite(sink) if_uops = [u for u in sink.parents if u.op is UOps.IF] self.assertEqual(len(if_uops), 1) - self.assertEqual(if_uops[0].src[0], gate) + self.assert_equiv_uops(if_uops[0].src[0], gate) for st in sink.src: self.assertEqual(len(st.src), 3) @@ -609,19 +610,19 @@ class TestIFUOps(unittest.TestCase): sink = gate_rewrite(sink) if_uops = [u for u in sink.parents if u.op is UOps.IF] self.assertEqual(len(if_uops), 1) - self.assertEqual(if_uops[0].src[0], gate) + self.assert_equiv_uops(if_uops[0].src[0], gate) for st in sink.src: self.assertEqual(len(st.src), 3) -class TestDivMod(unittest.TestCase): +class TestDivMod(TestUOps): def c(self, c:int): return UOp.const(dtypes.int, c) def x(self, expr:str, nmin:int, nmax:int): return UOp(UOps.DEFINE_VAR, dtypes.int, (self.c(nmin), self.c(nmax)), Variable(expr, nmin, nmax)) # NOTE: does not simplify to the end def test_const_mod(self): - self.assertEqual(mod_folding(self.c(6), 3), self.c(1)*self.c(0)) - self.assertEqual(mod_folding(self.c(7), 3), self.c(1)*self.c(1)) - self.assertEqual(mod_folding(self.c(8), 3), self.c(1)*self.c(2)) + self.assert_equiv_uops(mod_folding(self.c(6), 3), self.c(1)*self.c(0)) + self.assert_equiv_uops(mod_folding(self.c(7), 3), self.c(1)*self.c(1)) + self.assert_equiv_uops(mod_folding(self.c(8), 3), self.c(1)*self.c(2)) def test_var_mod(self): self.assertIsNone(mod_folding(self.x("x", 0, 6), 3)) @@ -629,32 +630,32 @@ class TestDivMod(unittest.TestCase): @unittest.skip("does not simplify to the end") def test_add_mod(self): - self.assertEqual(mod_folding(self.x("x", 0, 6)+40, 5), self.x("x", 0, 6)) - self.assertEqual(mod_folding(self.x("x", 0, 6)-40, 5), self.x("x", 0, 6)) - self.assertEqual(mod_folding(self.x("x", 0, 6)+42, 5), (self.x("x", 0, 6)+2)) - self.assertEqual(mod_folding(self.x("x", 0, 6)-42, 5), (self.x("x", 0, 6)+3)) - self.assertEqual(mod_folding(40+self.x("x", 0, 6), 5), self.x("x", 0, 6)) - self.assertEqual(mod_folding(-40+self.x("x", 0, 6), 5), self.x("x", 0, 6)) - self.assertEqual(mod_folding(42+self.x("x", 0, 6), 5), (2+self.x("x", 0, 6))) - self.assertEqual(mod_folding(-42+self.x("x", 0, 6), 5), (3+self.x("x", 0, 6))) + self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)+40, 5), self.x("x", 0, 6)) + self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)-40, 5), self.x("x", 0, 6)) + self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)+42, 5), (self.x("x", 0, 6)+2)) + self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)-42, 5), (self.x("x", 0, 6)+3)) + self.assert_equiv_uops(mod_folding(40+self.x("x", 0, 6), 5), self.x("x", 0, 6)) + self.assert_equiv_uops(mod_folding(-40+self.x("x", 0, 6), 5), self.x("x", 0, 6)) + self.assert_equiv_uops(mod_folding(42+self.x("x", 0, 6), 5), (2+self.x("x", 0, 6))) + self.assert_equiv_uops(mod_folding(-42+self.x("x", 0, 6), 5), (3+self.x("x", 0, 6))) @unittest.skip("does not simplify to the end") def test_mul_mod(self): - self.assertEqual(mod_folding(self.x("x", 0, 6)*40, 5), self.c(0)) - self.assertEqual(mod_folding(self.x("x", 0, 6)*-40, 5), self.c(0)) - self.assertEqual(mod_folding(self.x("x", 0, 6)*42, 5), (self.x("x", 0, 6)*2)) - self.assertEqual(mod_folding(self.x("x", 0, 6)*-42, 5), (self.x("x", 0, 6)*3)) - self.assertEqual(mod_folding(40*self.x("x", 0, 6), 5), self.c(0)) - self.assertEqual(mod_folding(-40*self.x("x", 0, 6), 5), self.c(0)) - self.assertEqual(mod_folding(42*self.x("x", 0, 6), 5), (2*self.x("x", 0, 6))) - self.assertEqual(mod_folding(-42*self.x("x", 0, 6), 5), (3*self.x("x", 0, 6))) + self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*40, 5), self.c(0)) + self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*-40, 5), self.c(0)) + self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*42, 5), (self.x("x", 0, 6)*2)) + self.assert_equiv_uops(mod_folding(self.x("x", 0, 6)*-42, 5), (self.x("x", 0, 6)*3)) + self.assert_equiv_uops(mod_folding(40*self.x("x", 0, 6), 5), self.c(0)) + self.assert_equiv_uops(mod_folding(-40*self.x("x", 0, 6), 5), self.c(0)) + self.assert_equiv_uops(mod_folding(42*self.x("x", 0, 6), 5), (2*self.x("x", 0, 6))) + self.assert_equiv_uops(mod_folding(-42*self.x("x", 0, 6), 5), (3*self.x("x", 0, 6))) @unittest.skip("does not simplify to the end now") def test_mul_add_mod(self): x = self.x("x", 0, 10) y = self.x("y", 0, 10) z = self.x("z", 0, 10) - self.assertEqual(mod_folding(x*40+y*12+z, 5), (y*2+z)) + self.assert_equiv_uops(mod_folding(x*40+y*12+z, 5), (y*2+z)) if __name__ == '__main__': diff --git a/test/test_uops.py b/test/test_uops.py index 69873c6a..443f2c50 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -1,5 +1,5 @@ from typing import Optional, Tuple, Any, List -import unittest, math, time +import unittest, math import numpy as np from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.helpers import CI, DEBUG, getenv, Context @@ -11,7 +11,7 @@ from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel from tinygrad.codegen.uops import UOps, NOp, UOp from tinygrad.codegen.uopgraph import UOpGraph -from test.helpers import is_dtype_supported +from test.helpers import is_dtype_supported, TestUOps as TestEqUOps def _uops_to_prg(uops_list, print_uops=False): uops = UOpGraph(uops_list) @@ -357,39 +357,7 @@ class TestUOpCompare(unittest.TestCase): mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL) assert (add < mul) or (mul < add), "add and mul with same src should have an order" - def test_uop_eq_fields(self): - a = UOp(UOps.CONST, dtypes.float, (), 2.0) - b = UOp(UOps.CONST, dtypes.float, (), 2.0) - self.assertEqual(a, b) - - def test_uop_ne_fields(self): - a = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), UOp.const(dtypes.pyint, 1)), (1, False)) - b = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), UOp.const(dtypes.pyint, 2)), (1, False)) - self.assertNotEqual(a, b) - - def test_recursive_eq_src(self): - st = time.perf_counter() - buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) - idx = UOp.const(dtypes.int, 0) - a = UOp(UOps.LOAD, dtypes.float, (buf, idx)) - for _ in range(24): a += a - b = UOp(UOps.LOAD, dtypes.float, (buf, idx)) - for _ in range(24): b += b - self.assertEqual(a, b) - self.assertLess(time.perf_counter()-st, 1e-2) - - # NOTE: NOp uses the dataclass compare, this is fine - def test_nop_ne(self): - a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="a") - b = NOp(UOps.CONST, dtypes.float, (), 2.0, name="b") - self.assertNotEqual(a, b) - - def test_nop_eq(self): - a1 = NOp(UOps.CONST, dtypes.float, (), 2.0, name="a") - a2 = NOp(UOps.CONST, dtypes.float, (), 2.0, name="a") - self.assertEqual(a1, a2) - -class TestUOpStr(unittest.TestCase): +class TestUOpStr(TestEqUOps): def test_uop_str(self): a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0) for _ in range(20): a = a + a @@ -401,7 +369,7 @@ class TestUOpStr(unittest.TestCase): # nice big complicated uop with Context(NOOPT=1): sink = get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops.sink - self.assertEqual(sink, eval(str(sink))) + self.assert_equiv_uops(sink, eval(str(sink))) def test_nop_str(self): a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1") diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 3ed1c16e..37761798 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -43,15 +43,6 @@ class UOp: return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \ self.arg.value, self.dtype, self.src) def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple - def cached_eq(self, x:UOp, context:Dict[Tuple[int, int], bool]) -> bool: - if id(self) == id(x): return True - if self.op != x.op or self.dtype != x.dtype or self.arg != x.arg or len(self.src) != len(x.src): return False - if (key := (id(self), id(x))) in context: return context[key] - return context.setdefault(key, all(a.cached_eq(b, context) for a,b in zip(self.src, x.src))) - def __eq__(self, x): return self.cached_eq(x, context={}) - @functools.cached_property - def hash(self): return hash((self.op, self.dtype, self.src, self.arg)) - def __hash__(self): return self.hash def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.arg}, src=(%s))") # *** uop syntactic sugar def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x