From a2239c812e43646ade0547841f6974ea6b55253a Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 17 Sep 2024 13:02:41 +0800
Subject: [PATCH] minimum new style expand (#6534)

* minimum new style expand

[run_process_replay]

* float4 folding works

* fix uop graph

* if means or

* dype.count idx overload

* fix test arange

* expand nope

* fix expand contract

* fix amd tensor core

* oh, that's a good test with a real failure

* remove prints

* early reduce

* tomorrow, we remove sorted on expand args

* fix wmma issue

* that makes test_arange pass

* vectorized folding

* no check

* broadcast

* fix clang with self assign rule
---
 test/test_uop_graph.py       |  85 +++++++++----------
 test/test_uops_stats.py      |   2 +-
 tinygrad/codegen/lowerer.py  |   4 +-
 tinygrad/codegen/uopgraph.py | 157 +++++++++++++++++++++++++++--------
 tinygrad/ops.py              |   7 ++
 5 files changed, 175 insertions(+), 80 deletions(-)

diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index ea0ba740..f4fcf482 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -295,6 +295,7 @@ class TestUOpGraph(unittest.TestCase):
     assert_equiv_uops(uops[0], acc)
     self.assertEqual(len(uops), 1)
 
+  @unittest.skip("wmma is wrong here, it needs an arg")
   def test_wmma_vectorize_no_fold(self):
     for i in [4, 8]:
       vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
@@ -437,63 +438,60 @@ def float4_rewrite(sink): return graph_rewrite(sink, constant_folder + expander
 
 class TestExpander(unittest.TestCase):
   def test_expand_add_broadcast(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
+    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
     sink = expander_rewrite(e1+3)
-    assert sink.op is UOps.EXPAND and len(sink.src) == 4
-    self.assertListEqual([x.arg for x in sink.src], [3,4,5,6])
+    assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
+    self.assertListEqual([x.arg for x in sink.src[0].src], [3,4,5,6])
 
   def test_contract_simple(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
+    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
     con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
     sink = expander_rewrite(con)
     assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
     self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
 
   def test_contract_axis_1(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
+    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
     con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
     sink = expander_rewrite(con)
-    assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((2,4),)
-    assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
-    self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,8,12])
-    self.assertListEqual([x.arg for x in sink.src[3].src], [3,7,11,15])
+    assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((2,4),)
+    assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
+    self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,4,8,12])
+    self.assertListEqual([x.arg for x in sink.src[0].src][12:], [3,7,11,15])
 
   def test_contract_axis_2(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
+    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
     con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
     sink = expander_rewrite(con)
-    assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,4),)
-    assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
-    self.assertListEqual([x.arg for x in sink.src[0].src], [0,1,2,3])
-    self.assertListEqual([x.arg for x in sink.src[3].src], [12,13,14,15])
+    assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((1,4),)
+    assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
+    self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,1,2,3])
+    self.assertListEqual([x.arg for x in sink.src[0].src][12:], [12,13,14,15])
 
   def test_contract_axis_2_big(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
+    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
     con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
     sink = expander_rewrite(con)
     assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
-    self.assertListEqual([x.arg for x in sink.src[0].src], [0,4])
-    self.assertListEqual([x.arg for x in sink.src[6].src], [10,14])
+    self.assertListEqual([x.arg for x in sink.src[0].src][0:2], [0,4])
+    self.assertListEqual([x.arg for x in sink.src[0].src][12:14], [10,14])
 
   def test_contract_multi_axis(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
-    sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3,2),(2,2))))
+    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
+    sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
     assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
-    self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,2,6])
-    sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,2),(3,2))))
+    self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 4, 2, 6])
+    sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
     assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
-    self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,4,6])
+    self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 2, 4, 6])
 
   def test_contract_mid(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
+    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
     con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
     sink = expander_rewrite(con)
-    assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,2),(3,2))
-    assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 2
-    self.assertListEqual([x.arg for x in sink.src[0].src], [0,2])
-    self.assertListEqual([x.arg for x in sink.src[1].src], [1,3])
-    self.assertListEqual([x.arg for x in sink.src[2].src], [4,6])
-    self.assertListEqual([x.arg for x in sink.src[3].src], [5,7])
+    assert sink.op is UOps.EXPAND and sink.arg == ((1,2),(3,2))
+    assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 8
+    self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,1,3,4,6,5,7])
 
   def test_contract_no_expand(self):
     e1 = UOp(UOps.DEFINE_VAR, dtypes.int)
@@ -503,7 +501,7 @@ class TestExpander(unittest.TestCase):
     assert sink.src[0] == sink.src[1]
 
   def test_contract_half_expand(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
+    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
     con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
     sink = expander_rewrite(con)
     assert sink.op is UOps.VECTORIZE and len(sink.src) == 8
@@ -512,19 +510,19 @@ class TestExpander(unittest.TestCase):
     assert sink.src[6] == sink.src[7]
 
   def test_expand_same_axis(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
-    e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
+    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
+    e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
     sink = expander_rewrite(e1+e2)
-    assert sink.op is UOps.EXPAND and len(sink.src) == 4
-    self.assertListEqual([x.arg for x in sink.src], [0,5,10,15])
+    assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
+    self.assertListEqual([x.arg for x in sink.src[0].src], [0,5,10,15])
 
   def test_expand_different_axis(self, flip=False):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
-    e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
+    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
+    e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
     sink = expander_rewrite((e2+e1) if flip else (e1+e2))
-    assert sink.op is UOps.EXPAND and len(sink.src) == 16
+    assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16
     assert sink.arg == ((1, 4), (2, 4))
-    self.assertListEqual([x.arg for x in sink.src], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
+    self.assertListEqual([x.arg for x in sink.src[0].src], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
 
   def test_expand_different_axis_flip(self):
     self.test_expand_different_axis(True)
@@ -544,6 +542,7 @@ class TestExpander(unittest.TestCase):
     assert sink.op is UOps.CONST
     self.assertEqual(sink.arg, 3*4)
 
+  @unittest.skip("no longer supported")
   def test_double_expand(self):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
     e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
@@ -553,6 +552,7 @@ class TestExpander(unittest.TestCase):
     assert sink.arg == ((1, 2), (2, 4))
     self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
 
+  @unittest.skip("no longer supported")
   def test_double_expand_reverse(self):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
     e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
@@ -562,6 +562,7 @@ class TestExpander(unittest.TestCase):
     assert sink.arg == ((1, 4), (2, 2))
     self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
 
+  @unittest.skip("no longer supported")
   def test_double_expand_middle(self):
     e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
     e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
@@ -585,14 +586,14 @@ class TestLoadStoreFolder(unittest.TestCase):
   def test_simple_load_fold(self):
     buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
     load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
-    sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
+    sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
     sink = float4_rewrite(sink)
     assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
 
   def test_two_load_fold(self):
     buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
     load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
-    sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,8),))
+    sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
     sink = float4_rewrite(sink)
     assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2
 
@@ -600,7 +601,7 @@ class TestLoadStoreFolder(unittest.TestCase):
     buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
     gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
     load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
-    sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
+    sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
     sink = float4_rewrite(sink)
     assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
     single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
@@ -611,7 +612,7 @@ class TestLoadStoreFolder(unittest.TestCase):
     gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
     gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
     load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
-    sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
+    sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
     sink = float4_rewrite(sink)
     assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3
 
diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py
index f5ebb5e8..62f20c98 100644
--- a/test/test_uops_stats.py
+++ b/test/test_uops_stats.py
@@ -103,7 +103,7 @@ class TestUOpsStats(unittest.TestCase):
     c = a.matmul(b)
     c.realize()
     expected_ops = N ** 3 * 2
-    assert expected_ops == GlobalCounters.global_ops
+    self.assertEqual(expected_ops, GlobalCounters.global_ops)
 
   #MULACC should have the same stats as MUL + ADD
   def test_mulacc(self):
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index 6c55ddca..97e3b2ad 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -72,7 +72,7 @@ class IndependentLowerer:
     # upcast loops
     for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
       assert isinstance(g, int), "needs to be int to upcast/unroll"
-      self.idxs.append(UOp(UOps.EXPAND, dtypes.pyint, tuple(UOp.const(dtypes.pyint, j) for j in range(0, g)), ((i,g),)))
+      self.idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
 
     # late indexes (group for reduce)
     self.ridxs = self.idxs[:]
@@ -117,7 +117,7 @@ class IndependentLowerer:
         UOp(UOps.CONTRACT, dtype=in_uops[0].dtype.vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
         UOp(UOps.CONTRACT, dtype=in_uops[1].dtype.vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
         UOp.const(x.dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
-      return UOp(UOps.EXPAND, x.dtype, tuple(UOp(UOps.GEP, x.dtype, (ret,), (i,)) for i in range(wmma_sz[2])), arg=upcast_axes[2])
+      return UOp(UOps.EXPAND, x.dtype, (ret,), arg=upcast_axes[2])
     if x.op is UOps.REDUCE_AXIS:
       # NOTE: always using ridxs is fine here
       reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 514a248c..bf774c11 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -5,7 +5,7 @@ from collections import defaultdict
 from tinygrad.dtype import dtypes, PtrDType, ImageDType
 from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
 from tinygrad.ops import UPat, PatternMatcher, graph_rewrite
-from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition
+from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
 if TYPE_CHECKING: from tinygrad.renderer import Renderer
 
@@ -71,7 +71,7 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
   return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load.dtype, (vec_load,), (i,))), range(4), load.const_like(float('nan')))
 
 float4_folding = PatternMatcher([
-  (UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
+  (UPat(UOps.VECTORIZE, src=UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
   (UPat((UOps.BARRIER, UOps.SINK), src=UPat(UOps.STORE, src=(UPat.var("buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded),
 ])
 
@@ -201,7 +201,7 @@ def reduce_before_expand(reduce, expand, x):
   red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce.src[1:], reduce.arg)
   return UOp(expand.op, expand.dtype, tuple(UOp(UOps.GEP, reduce.dtype, (red,), (i,)) for i in range(x.dtype.count)), expand.arg)
 
-def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None):
+def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None, vec=None):
   if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE
   if mval.arg >= 0 or loop_start.arg != 0:
     # TODO: support and test this with other mvals and loop_starts
@@ -209,6 +209,10 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, redu
     return None
   if idx2 is not None: idx = idx + idx2
   if idx3 is not None: idx = idx + idx3
+  if vec is not None:
+    # idx, mval, loop_start, loop_end
+    def dvec(x): return UOp(UOps.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
+    idx, mval, loop_start, loop_end = dvec(idx), dvec(mval), dvec(loop_start), dvec(loop_end)
   comprange = UOp.min(loop_end, UOp.max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start))
   new_reduce_op = comprange.cast(multconst.dtype) * multconst
   ret = UOp(UOps.REDUCE, reduce.dtype, (new_reduce_op,) + tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
@@ -225,6 +229,8 @@ constant_folder = PatternMatcher([
   # bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly
   (UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.ADD, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
   (UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
+  # self ASSIGN is just self
+  (UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
   # VECTORIZE/GEP: the expander rule allows tuple GEP creation, this is just for removal
   (UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"),
    lambda vec,x: x if x.dtype == vec.dtype and tuple(y.arg[0] for y in vec.src) == tuple(range(len(vec.src))) else None),
@@ -234,6 +240,7 @@ constant_folder = PatternMatcher([
   # VECTORIZE of a single element is just that element
   (UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
   # VECTORIZE void is SINK
+  (UPat(UOps.VECTORIZE, dtype=dtypes.void, src=UPat(UOps.BARRIER, name='b')), lambda b: b),
   (UPat(UOps.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(UOps.SINK, dtypes.void, x.src)),
   # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
   (UPat(UOps.GEP, src=(UPat(UOps.GEP, name='g2'),), name='g1'),
@@ -242,9 +249,12 @@ constant_folder = PatternMatcher([
    lambda gep, vec: UOp(UOps.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
   (UPat(UOps.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
   (UPat(UOps.GEP, src=(UPat(UOps.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
+  # push all GEPs through ALUs (fix arange stuff)
+  (UPat(UOps.GEP, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name='alu'),), name='gep'),
+   lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
   # GEP add push (non-shrinking only)
-  (UPat(UOps.GEP, None, (UPat.var('x') + UPat.cvar('c1'),), name="gep"),
-   lambda x,c1,gep: x.gep(gep.arg) + c1.gep(gep.arg) if len(gep.arg) >= x.dtype.count else None),
+  #(UPat(UOps.GEP, None, (UPat.var('x') + UPat.cvar('c1'),), name="gep"),
+  # lambda x,c1,gep: x.gep(gep.arg) + c1.gep(gep.arg) if len(gep.arg) >= x.dtype.count else None),
   # tensor core with a 0 input is acc
   *[(UPat(UOps.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc) for i in [2, 4, 8]],
   *[(UPat(UOps.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc) for i in [2, 4, 8]],
@@ -265,12 +275,22 @@ constant_folder = PatternMatcher([
    .where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
   # arange loop folding (reduce)
   (UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
-    .lt(UPat.cvar("compval"))
-    .where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
+    .lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),),
+    arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
   # arange loop folding (unrolled)
   (UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
     .lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
     arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
+  # arange loop folding (vectorized)
+  (UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) *
+    UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng")))
+    .lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),),
+    arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
+  # arange loop folding (unrolled, vectorized)
+  (UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) *
+    UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng")))
+    .lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
+    arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
   # unrolled arange div folding
   (UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs),
   # indexing (with a multiply offset)!
@@ -406,22 +426,47 @@ def do_expand(root:UOp):
   if len(expands) == 0: return None
   # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
   exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is UOps.WMMA else ()
-  expand_args = tuple(x for x in sorted(dedup(flatten([x.arg for x in expands]))) if x[0] not in exclude_args)
-  esrcs = [[src.src[x] for x in _swizzle_args(expand_args, src.arg, exclude_args)] \
-    if src.op is UOps.EXPAND else itertools.repeat(src) for src in root.src]
-  new_srcs = [UOp(root.op, root.dtype, new_src, root.arg) for new_src in zip(*esrcs)]
-  if root.op is UOps.EXPAND:
-    # merge two expands
-    expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args
-    assert len(expand_args) == (len(old_args) + len(root.arg))
-    new_srcs = [new_srcs[_expand_arg_to_idx(old_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)] for rpk in _choices_from_args(expand_args)]
-  if root.op is UOps.IF:
-    # merge ifs into an or
-    conditions = functools.reduce(lambda x,y: x|y, dedup(x.src[0] for x in new_srcs if x.src[0].op is not UOps.CONST))
-    barriers = tuple(set(x.src[1] for x in new_srcs))
-    new_srcs = [UOp(UOps.IF, src=(conditions,)+barriers) for _ in new_srcs]
-  assert prod([x[1] for x in expand_args]) == len(new_srcs)
-  return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)
+  if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
+    # if there's only one expand arg, it's okay to use it (optimization)
+    expand_args = expands[0].arg
+  else:
+    # otherwise, we sort them and GEP
+    expand_args = tuple(x for x in sorted(dedup(flatten(expands_args))) if x[0] not in exclude_args)
+  expand_sz = prod([x[1] for x in expand_args])
+  new_srcs = []
+  for i,src in enumerate(root.src):
+    if src.op is UOps.EXPAND:
+      if root.op is UOps.IF and i == 0:
+        # IF means OR on first arg to IF
+        new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
+      elif expand_args == src.arg:
+        # just remove the expand
+        new_srcs.append(src.src[0])
+      else:
+        lst = _swizzle_args(expand_args, src.arg, exclude_args)
+        # if the base dtype is > 1, put those at the end
+        if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
+        new_srcs.append(src.src[0].gep(tuple(lst)))
+    else:
+      # non-EXPAND input
+      if (root.op in {UOps.LOAD, UOps.STORE} and i == 0) or (root.op is UOps.REDUCE and i != 0):
+        # for the first arg of LOAD/STORE and the RANGE args of REDUCE, just pass them through ignoring EXPANDS
+        new_srcs.append(src)
+      elif src.dtype.count > 1:
+        # put any input dtype > 1 grouped together
+        new_srcs.append(UOp(UOps.VECTORIZE,
+          src.dtype.scalar().vec(expand_sz*src.dtype.count), tuple(src.gep(i) for i in range(src.dtype.count))*expand_sz))
+      else:
+        # repeat the arg
+        new_srcs.append(src.broadcast(expand_sz))
+
+  new_arg = root.arg
+  if root.op is UOps.GEP:
+    assert root.dtype.count == 1
+    # is this right?
+    new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
+  nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
+  return UOp(UOps.EXPAND, root.dtype, (nsrc,), expand_args)
 
 acc_number = 0
 def do_reduce(root:UOp):
@@ -435,7 +480,7 @@ def do_reduce(root:UOp):
     ret = UOp(UOps.ASSIGN, root.dtype, (acc, acc.alu(root.arg, ret)))
   # for MAX, we can just ignore the unparented
   if root.arg is BinaryOps.ADD:
-    for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype)
+    for r in reduce_unparented:ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
   return ret
 
 def do_contract(con:UOp):
@@ -444,16 +489,16 @@ def do_contract(con:UOp):
   if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
   # CONTRACT may remove several axes from EXPAND
   assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
-  srcs = []
+  idxs = []
   for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
-    lsrcs = [ex.src[_expand_arg_to_idx(ex.arg, {**rpk, **lrpk})] for lrpk in _choices_from_args(con.arg)]
-    srcs.append(UOp(UOps.VECTORIZE, con.dtype, tuple(lsrcs)))
-  return srcs[0] if len(srcs) == 1 else UOp(UOps.EXPAND, con.dtype, tuple(srcs), new_ex_args)
+    idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
+  return UOp(UOps.EXPAND, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
 
 def no_vectorized_alu(alu):
   if alu.dtype.count == 1: return None
   alus = tuple(UOp(alu.op, alu.dtype.scalar(),
-    tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), (i,)) for s in alu.src), alu.arg) for i in range(alu.dtype.count))
+    tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), (i,)) if alu.op is not UOps.REDUCE or j == 0 else s for j,s in enumerate(alu.src)),
+    alu.arg) for i in range(alu.dtype.count))
   return UOp(UOps.VECTORIZE, alu.dtype, alus)
 
 def create_gate(root:UOp) -> Optional[UOp]:
@@ -470,9 +515,12 @@ expander = PatternMatcher([
   (UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
   # create gate MUST BE BEFORE expander
   (UPat(UOps.STORE, name="root"), create_gate),
+  # double expand
+  (UPat(UOps.EXPAND, name="outer", src=(UPat(UOps.EXPAND, name="inner"),)),
+   lambda outer, inner: UOp(UOps.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
   # do expansion
   (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
-         UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
+         UOps.VECTORIZE, UOps.REDUCE, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
   (UPat(UOps.CONTRACT, name="con"), do_contract),
   # remove EXPANDs from SINK
   (UPat(UOps.SINK, name="root"),
@@ -488,6 +536,37 @@ expander = PatternMatcher([
    lambda ex,x,y: UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
 ])
 
+just_reduce = PatternMatcher([
+  # do reduce (in expander now)
+  (UPat(UOps.REDUCE, name="root"), do_reduce),
+])
+
+def no_vectorized_load_store(ls:UOp):
+  idx = ls.src[1]
+  if idx.dtype.count == 1: return None
+  # ugh, the meaning of a dtype.count idx is overloaded
+  if ls.op is UOps.LOAD and idx.dtype.count != ls.dtype.count: return None
+  if ls.op is UOps.STORE and idx.dtype.count != ls.src[2].dtype.count: return None
+  tv = [UOp(ls.op, ls.dtype.scalar(), (ls.src[0],) + tuple(j.gep(i) for j in ls.src[1:])) for i in range(idx.dtype.count)]
+  return UOp(UOps.VECTORIZE, ls.dtype, tuple(tv))
+
+def no_vectorized_acc(acc:UOp):
+  if acc.dtype.count == 1: return None
+  alus = tuple(UOp(acc.op, acc.dtype.scalar(),
+    tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), (i,)) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
+  return UOp(UOps.VECTORIZE, acc.dtype, alus)
+
+def no_vectorized_wmma(wmma:UOp):
+  out_sz = prod(x[1] for x in wmma.arg[6][-1])
+  if wmma.dtype.count == out_sz: return None
+  tsrcs = []
+  for s,sz in zip(wmma.src, wmma.arg[6]):
+    ssz = prod(x[1] for x in sz)
+    tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
+  wmmas = [UOp(UOps.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
+  wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
+  return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
+
 def delete_redundant_gates(root:UOp) -> Optional[UOp]:
   @functools.lru_cache(None)
   def find_gate(x:UOp) -> Optional[UOp]:
@@ -496,14 +575,19 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
   if len(root.src) == 3 or (gate:=find_gate(root)) is None or gate.src[0] is not root.src[3]: return None
   return UOp(UOps.STORE, root.dtype, root.src[:3], root.arg)
 
+devectorize = PatternMatcher([
+  # no ALU on vectorized dtypes
+  (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.REDUCE), name="alu"), no_vectorized_alu),
+  (UPat(UOps.WMMA, name="wmma"), no_vectorized_wmma),
+  (UPat(UOps.DEFINE_ACC, name="acc"), no_vectorized_acc),
+  (UPat((UOps.LOAD, UOps.STORE), name="ls"), no_vectorized_load_store),
+])
+
 reducer = PatternMatcher([
-  (UPat(UOps.REDUCE, name="root"), do_reduce),
   (UPat(UOps.CONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.count) if c.dtype.count > 1 else None),
   (UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
   (UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
-  # no ALU on vectorized dtypes
-  (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name="alu"), no_vectorized_alu),
   # delete_redundant_gates (after expand, is this still needed?)
(UPat(UOps.STORE, name="root"), delete_redundant_gates), # late fixup of unfoldable image loads @@ -544,8 +628,11 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: # expand linearize_cnt += 1 if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1: - sink = graph_rewrite(sink, folder+(expander+float4_folding if opts is not None and opts.supports_float4 else expander)) - if getenv("DO_REDUCE", 1): sink = graph_rewrite(sink, folder+reducer) + sink = graph_rewrite(sink, folder+expander) + if getenv("DO_REDUCE", 1): + sink = graph_rewrite(sink, folder+just_reduce) + sink = graph_rewrite(sink, folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)) + sink = graph_rewrite(sink, folder+reducer) # for PTX only if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 7231c8ef..122399a2 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -349,6 +349,9 @@ class UOp(MathTrait): def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None): # TODO: instant check rules here make debugging easier #if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool + #if op is UOps.VECTORIZE and dtype != dtypes.void: assert len(src) == dtype.count, f"{len(src)} invalid for {dtype}" + #if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src]) + #if op is UOps.CAST: assert dtype.count == src[0].dtype.count, f"cast can't change vectorization {src[0].dtype} --> {dtype}" self.op, self.dtype, self.src, self.arg = op, dtype, src, arg def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None): return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg) @@ -387,6 +390,10 @@ class UOp(MathTrait): def sink(self, *srcs): return UOp(UOps.SINK, dtypes.void, (self,)+srcs) def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st) def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return type(self).const(self.dtype, b) + def broadcast(self, count:int): + assert self.dtype.count == 1 + if count == 1: return self + return UOp(UOps.VECTORIZE, self.dtype.vec(count), (self,)*count) def cast(self, dtype:DType): return type(self)(UOps.CAST, dtype, (self,)) def bitcast(self, dtype:DType): return type(self)(UOps.BITCAST, dtype, (self,)) def gep(self, i:Union[Tuple[int, ...], int]):