From 912f01ed4b5adbd197f83b7dfe754a20d1a627dd Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 16 Aug 2024 19:48:39 -0700 Subject: [PATCH] UOpGraph -> linearize_uop [run_process_replay] (#6119) --- test/external/fuzz_uops.py | 10 +- test/test_uop_graph.py | 45 +++++---- test/test_uops.py | 20 ++-- test/test_uops_stats.py | 8 +- test/unit/test_uop_symbolic.py | 5 +- tinygrad/codegen/kernel.py | 5 +- tinygrad/codegen/uopgraph.py | 170 ++++++++++++++++----------------- 7 files changed, 127 insertions(+), 136 deletions(-) diff --git a/test/external/fuzz_uops.py b/test/external/fuzz_uops.py index 9e780bed..af9f8d7f 100644 --- a/test/external/fuzz_uops.py +++ b/test/external/fuzz_uops.py @@ -3,8 +3,7 @@ from collections import defaultdict import numpy as np from dataclasses import replace from typing import DefaultDict, Dict, List, Tuple -from tinygrad.ops import END_FOR_UOP, UOp -from tinygrad.codegen.uopgraph import UOpGraph +from tinygrad.ops import END_FOR_UOP, UOp, print_uops from tinygrad.device import Buffer, Device from tinygrad.engine.realize import CompiledRunner from tinygrad.helpers import DEBUG, colored @@ -49,10 +48,9 @@ class UOpsFuzzerRunner(CompiledRunner): for i, path in enumerate(fuzz_paths): # setup prg - uops = UOpGraph([]) - uops._uops = list(path) - if DEBUG >= 5: uops.print() - self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops.uops), uops=uops.uops) + uops = list(path) + if DEBUG >= 5: print_uops(uops) + self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops) if DEBUG >= 4: print(self.p.src) self.lib = Device[self.p.dname].compiler.compile_cached(self.p.src) self.clprg = Device[self.p.dname].runtime(name, self.lib) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index d6679429..bc0c93b3 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -4,7 +4,7 @@ from tinygrad import dtypes, Variable from tinygrad.dtype import PtrDType from tinygrad.helpers import DEBUG from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps, UOps, UOp, NOp, PatternMatcher -from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding +from tinygrad.codegen.uopgraph import linearize_uop, graph_rewrite, expander, reducer, constant_folder, float4_folding simple_pm = PatternMatcher([ (NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)), @@ -94,7 +94,7 @@ class TestUOpGraph(TestUOps): c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD) - uops = UOpGraph([out]).linearize() + uops = linearize_uop([out]) self.assertEqual(len(uops), 1) out = uops[-1] self.assertEqual(out.op, UOps.CONST) @@ -106,7 +106,7 @@ class TestUOpGraph(TestUOps): vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE) - uops = UOpGraph([out]).linearize() + uops = linearize_uop([out]) self.assertEqual(len(uops), 1) out = uops[-1] self.assertEqual(out.op, UOps.CONST) @@ -117,7 +117,7 @@ class TestUOpGraph(TestUOps): c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE) - uops = UOpGraph([out]).linearize() + uops = linearize_uop([out]) self.assertEqual(len(uops), 1) out = uops[-1] self.assertEqual(out.op, UOps.CONST) @@ -126,7 +126,7 @@ class TestUOpGraph(TestUOps): def test_const_cast(self): bf = UOp(UOps.CONST, dtypes.bool, arg=False) out = UOp(UOps.CAST, dtypes.int, (bf,)) - uops = UOpGraph([out]).linearize() + uops = linearize_uop([out]) self.assertEqual(len(uops), 1) out = uops[-1] self.assertEqual(out.op, UOps.CONST) @@ -140,7 +140,7 @@ class TestUOpGraph(TestUOps): x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0) alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT) out = UOp(UOps.STORE, None, (d0, idx, alu)) - uops = UOpGraph([out]).linearize() + uops = linearize_uop([out]) self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0) def test_gep_vec_fold(self): @@ -151,7 +151,7 @@ class TestUOpGraph(TestUOps): def _test_vec(geps, count=4): vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps) out = UOp(UOps.STORE, None, (d0, idx, vec)) - uops = UOpGraph([out]).linearize() + uops = linearize_uop([out]) if DEBUG >= 4: from tinygrad import Device print(Device[Device.DEFAULT].renderer.render("test", uops)) @@ -187,7 +187,7 @@ class TestUOpGraph(TestUOps): consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)] vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts)) geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)] - uops = UOpGraph(geps).linearize() + uops = linearize_uop(geps) for uop, const in zip(uops, consts): self.assert_equiv_uops(uop, const) @@ -197,7 +197,7 @@ class TestUOpGraph(TestUOps): var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) - uops = UOpGraph([wmma]).linearize() + uops = linearize_uop([wmma]) self.assert_equiv_uops(uops[0], acc) self.assertEqual(len(uops), 1) @@ -206,7 +206,7 @@ class TestUOpGraph(TestUOps): vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) - uops = UOpGraph([wmma]).linearize() + uops = linearize_uop([wmma]) self.assert_equiv_uops(uops[0], acc) self.assertEqual(len(uops), 1) @@ -218,7 +218,7 @@ class TestUOpGraph(TestUOps): var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) - uops = UOpGraph([wmma]).linearize() + uops = linearize_uop([wmma]) self.assert_equiv_uops(uops[-1], wmma) for i in [4, 8]: @@ -228,7 +228,7 @@ class TestUOpGraph(TestUOps): tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2))) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) - uops = UOpGraph([wmma]).linearize() + uops = linearize_uop([wmma]) self.assert_equiv_uops(uops[-1], wmma) for i in [2, 4, 8]: @@ -237,7 +237,7 @@ class TestUOpGraph(TestUOps): var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) - uops = UOpGraph([wmma]).linearize() + uops = linearize_uop([wmma]) self.assert_equiv_uops(uops[-1], wmma) for i in [2, 4, 8]: @@ -246,7 +246,7 @@ class TestUOpGraph(TestUOps): tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i))) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) - uops = UOpGraph([wmma]).linearize() + uops = linearize_uop([wmma]) self.assert_equiv_uops(uops[-1], wmma) def test_cast_alu_fold(self): @@ -256,7 +256,7 @@ class TestUOpGraph(TestUOps): ld = UOp(UOps.LOAD, dtypes.int, (d1, idx)) alu = ld.lt(1).cast(dtypes.bool) out = UOp(UOps.STORE, None, (d0, idx, alu)) - uops = UOpGraph([out]).linearize() + uops = linearize_uop([out]) self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0) def test_double_cast_fold(self): @@ -266,7 +266,7 @@ class TestUOpGraph(TestUOps): ld = UOp(UOps.LOAD, dtypes.int, (d1, idx)) alu = ld.cast(dtypes.float).cast(dtypes.float) out = UOp(UOps.STORE, None, (d0, idx, alu)) - uops = UOpGraph([out]).linearize() + uops = linearize_uop([out]) self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1) def test_depth_2_const_fold(self): @@ -275,7 +275,7 @@ class TestUOpGraph(TestUOps): c4 = UOp(UOps.CONST, dtypes.int, arg=4) vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD) out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD) - uops = UOpGraph([out]).linearize() + uops = linearize_uop([out]) self.assertEqual(len(uops), 5) out = uops[-1] self.assertEqual(out.op, UOps.ALU) @@ -290,7 +290,7 @@ class TestUOpGraph(TestUOps): idx = UOp.const(dtypes.int, 0) ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False))) ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True))) - uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))]).linearize() + uops = linearize_uop([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))]) ld0, ld1 = uops[-1].src[2].src # ld0 becomes the invalid value self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2)) @@ -305,7 +305,7 @@ class TestUOpGraph(TestUOps): barrier = UOp(UOps.BARRIER, None, (st, )) ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier)) ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier)) - uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))]).linearize() + uops = linearize_uop([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))]) ld0, ld1 = uops[-1].src[2].src # ld0 becomes the invalid value self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2)) @@ -319,7 +319,7 @@ class TestUOpGraph(TestUOps): val = UOp.const(dtypes.int, 42) st0 = UOp(UOps.STORE, None, (glbl, idx0, val, UOp.const(dtypes.bool, False))) st1 = UOp(UOps.STORE, None, (glbl, idx1, val, UOp.const(dtypes.bool, True))) - uops = UOpGraph([st0, st1]).linearize() + uops = linearize_uop([st0, st1]) # only the second store happens self.assertEqual(len(uops), 4) self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val)) @@ -328,8 +328,7 @@ class TestUOpGraph(TestUOps): glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) idx = UOp.const(dtypes.int, 0) bad_gate = UOp.const(dtypes.int, 1) - uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))]) - with self.assertRaises(AssertionError): uops.linearize() + with self.assertRaises(AssertionError): linearize_uop([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))]) def test_switched_range_order(self): glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) @@ -340,7 +339,7 @@ class TestUOpGraph(TestUOps): r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False)) alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL) store = UOp(UOps.STORE, None, (glbl, alu, cf)) - uops = UOpGraph([store]).linearize() + uops = linearize_uop([store]) ranges = [x for x in uops if x.op is UOps.RANGE] endranges = [x for x in uops if x.op is UOps.ENDRANGE] # ranges are closed in the right order diff --git a/test/test_uops.py b/test/test_uops.py index 55542dbb..4f04909b 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -9,11 +9,11 @@ from tinygrad.ops import UOps, NOp, UOp, UnaryOps, BinaryOps, TernaryOps, Reduce from tinygrad.renderer import Program from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel -from tinygrad.codegen.uopgraph import UOpGraph +from tinygrad.codegen.uopgraph import linearize_uop from test.helpers import is_dtype_supported, TestUOps as TestEqUOps def _uops_to_prg(uops_list): - uops = UOpGraph(uops_list).linearize(Device[Device.DEFAULT].renderer.extra_matcher) + uops = linearize_uop(uops_list, extra_pm=Device[Device.DEFAULT].renderer.extra_matcher) src = Device[Device.DEFAULT].renderer.render("test", uops) has_local = Device[Device.DEFAULT].renderer.has_local return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, @@ -247,7 +247,7 @@ class TestGatedStoreRewrite(unittest.TestCase): val = UOp.const(dtypes.float, 42.0) gate = gidx0.lt(UOp.const(dtypes.int, 1)) store = UOp(UOps.STORE, None, (gmem, idx, val, gate)) - uops = UOpGraph([store]) + uops = linearize_uop([store]) if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) if_uop = next(u for u in uops if u.op is UOps.IF) endif = next(u for u in uops if u.op is UOps.ENDIF) @@ -265,7 +265,7 @@ class TestGatedStoreRewrite(unittest.TestCase): val = UOp.const(dtypes.float, 42.0) gate = gidx0.lt(UOp.const(dtypes.int, 1)) stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val)] - uops = UOpGraph(stores) + uops = linearize_uop(stores) if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) if_uop = next(u for u in uops if u.op is UOps.IF) endif = next(u for u in uops if u.op is UOps.ENDIF) @@ -284,7 +284,7 @@ class TestGatedStoreRewrite(unittest.TestCase): val = UOp.const(dtypes.float, 42.0) gate = gidx0.lt(UOp.const(dtypes.int, 1)) stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val, gate)] - uops = UOpGraph(stores) + uops = linearize_uop(stores) if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) ifs = [u for u in uops if u.op is UOps.IF] endifs = [u for u in uops if u.op is UOps.ENDIF] @@ -326,7 +326,7 @@ class TestAssembly(unittest.TestCase): l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1)) a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL) a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL) - uops = UOpGraph([a1,a2]).linearize(Device[Device.DEFAULT].renderer.extra_matcher) + uops = linearize_uop([a1,a2], extra_pm=Device[Device.DEFAULT].renderer.extra_matcher) Device[Device.DEFAULT].renderer.render("test", uops) self.assertEqual(uops[-1].arg, BinaryOps.SHL) self.assertEqual(uops[-2].arg, BinaryOps.MUL) @@ -338,7 +338,7 @@ class TestAssembly(unittest.TestCase): l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1)) a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV) a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV) - uops = UOpGraph([a1,a2]).linearize(Device[Device.DEFAULT].renderer.extra_matcher) + uops = linearize_uop([a1,a2], extra_pm=Device[Device.DEFAULT].renderer.extra_matcher) Device[Device.DEFAULT].renderer.render("test", uops) self.assertEqual(uops[-1].arg, BinaryOps.SHR) self.assertEqual(uops[-2].arg, BinaryOps.IDIV) @@ -377,7 +377,7 @@ class TestIndexingOrdering(unittest.TestCase): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) - uops = UOpGraph([st1, st0]).linearize(skip_check=True) + uops = linearize_uop([st1, st0], skip_check=True) stores = [st for st in uops if st.op is UOps.STORE] assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}" @@ -389,7 +389,7 @@ class TestIndexingOrdering(unittest.TestCase): st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st1_1 = UOp(UOps.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) - uops = UOpGraph([st0_0, st1_0, st0_1, st1_1]).linearize(skip_check=True) + uops = linearize_uop([st0_0, st1_0, st0_1, st1_1], skip_check=True) stores = [st for st in uops if st.op is UOps.STORE] print("\n".join(map(str, stores))) # buf0 stores come first @@ -405,7 +405,7 @@ class TestIndexingOrdering(unittest.TestCase): gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4)) st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) - uops = UOpGraph([st1, st0]).linearize(skip_check=True) + uops = linearize_uop([st1, st0], skip_check=True) stores = [st for st in uops if st.op is UOps.STORE] assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}" diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index efa356b5..5f1dfcfa 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -3,7 +3,7 @@ from tinygrad import Tensor from tinygrad.helpers import getenv from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import lower_schedule_item -from tinygrad.codegen.uopgraph import UOpGraph +from tinygrad.codegen.uopgraph import linearize_uop from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp from tinygrad.dtype import dtypes from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError @@ -105,7 +105,7 @@ class TestUOpsStats(unittest.TestCase): u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3) u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL) u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD) - uops = UOpGraph([u5]) + uops = linearize_uop([u5]) globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple()) o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1) @@ -114,9 +114,9 @@ class TestUOpsStats(unittest.TestCase): u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2)) u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3) u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC) - uops_fma = UOpGraph([u4]) + uops_fma = linearize_uop([u4]) - self.assertEqual(flops_mem(uops.linearize()), flops_mem(uops_fma.linearize())) + self.assertEqual(flops_mem(uops), flops_mem(uops_fma)) N = 100 @unittest.skipIf(getenv("PTX"), "wrong in PTX") # maybe? diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 4c9470b0..23e25ba2 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -9,15 +9,14 @@ from typing import Tuple from tinygrad.helpers import DEBUG from tinygrad.dtype import dtypes, PtrDType, ConstType -from tinygrad.codegen.uopgraph import UOpGraph +from tinygrad.codegen.uopgraph import linearize_uop from tinygrad.ops import BinaryOps, UOp, UOps, print_uops import functools def render(self) -> Tuple[str, ConstType, ConstType]: # NOTE: we need STORE so the ALU op has children glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0) - graph = UOpGraph([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))]) - uops = graph.linearize() + uops = linearize_uop([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))]) if DEBUG>=5: print_uops(uops) from tinygrad.renderer.cstyle import CStyleLanguage class TestRenderer(CStyleLanguage): diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 0a95b3b3..819375b3 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -13,7 +13,7 @@ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DE from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, sint from tinygrad.shape.view import strides_for_shape -from tinygrad.codegen.uopgraph import UOpGraph +from tinygrad.codegen.uopgraph import linearize_uop from tinygrad.codegen.lowerer import ast_to_uop from enum import Enum, auto @@ -737,8 +737,7 @@ class Kernel: print(self.applied_opts) verify_ast(modified_ast) - # generate the UOpGraph - self.uops:List[UOp] = UOpGraph(ast_to_uop(modified_ast, self.opts), self.opts).linearize(self.opts.extra_matcher) + self.uops:List[UOp] = linearize_uop(ast_to_uop(modified_ast, self.opts), self.opts, extra_pm=self.opts.extra_matcher) if DEBUG >= 5: print_uops(self.uops) if getenv("GRAPHUOPS"): from tinygrad.engine.graph import graph_uops diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index b2c322a0..e5e67695 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -517,104 +517,100 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: return found return __inner_rewrite(sink) -class UOpGraph: - def __init__(self, sink:Union[UOp, List[UOp]], opts:Optional[Renderer]=None): - self.sink: UOp = sink if isinstance(sink, UOp) else UOp(UOps.SINK, None, tuple(sink)) - assert self.sink.op is UOps.SINK, f"sink isn't sink, it's {self.sink.op}" - self.opts = opts - self.folder = constant_folder + transcendental_folding({} if TRANSCENDENTAL >= 2 or opts is None else opts.code_for_op.keys()) +linearize_cnt = 0 +def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, extra_pm:Optional[PatternMatcher]=None, skip_check=False) -> List[UOp]: + global linearize_cnt, acc_number + sink: UOp = sink_in if isinstance(sink_in, UOp) else UOp(UOps.SINK, None, tuple(sink_in)) + assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}" + folder = constant_folder + transcendental_folding({} if TRANSCENDENTAL >= 2 or opts is None else opts.code_for_op.keys()) - cnt = 0 - def linearize(self, extra_pm:Optional[PatternMatcher]=None, skip_check=False) -> List[UOp]: - global acc_number - acc_number = 0 + # do graph rewrite + acc_number = 0 + sink = graph_rewrite(sink, folder) - # do graph rewrite - sink = graph_rewrite(self.sink, self.folder) + # rewrite pyint to int32 + sink = graph_rewrite(sink, PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"), + lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))])) - # rewrite pyint to int32 - sink = graph_rewrite(sink, PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"), - lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))])) + # expand + linearize_cnt += 1 + if linearize_cnt != getenv("DEBUG_EXPAND", 0): + sink = graph_rewrite(sink, folder+expander+float4_folding if opts is not None and opts.supports_float4 else folder+expander) + sink = graph_rewrite(sink, folder+expander+reducer) - # expand - UOpGraph.cnt += 1 - if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0): - sink = graph_rewrite(sink, self.folder+expander+float4_folding if self.opts is not None and self.opts.supports_float4 else self.folder+expander) - sink = graph_rewrite(sink, self.folder+expander+reducer) + # for PTX only + if extra_pm: sink = graph_rewrite(sink, folder+extra_pm) - # for PTX only - if extra_pm: sink = graph_rewrite(sink, self.folder+extra_pm) + # filter nodes that don't link to a sink + # BFS toposort + children: Dict[UOp, List[UOp]] = {} + range_srcs: Dict[UOp, Dict[UOp, None]] = {} + in_degree: Dict[UOp, int] = {} + get_children_dfs(sink, children, range_srcs, in_degree) - # filter nodes that don't link to a sink - # BFS toposort - children: Dict[UOp, List[UOp]] = {} - range_srcs: Dict[UOp, Dict[UOp, None]] = {} - in_degree: Dict[UOp, int] = {} - get_children_dfs(sink, children, range_srcs, in_degree) + @functools.lru_cache(None) + def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]: + if x.op is UOps.SINK: return set() + return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end])) - @functools.lru_cache(None) - def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]: - if x.op is UOps.SINK: return set() - return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end])) + # scope children impact the toposort and END* insertion + scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP} + range_phi = {r:[p for p in scope_children[r] if p.op is UOps.PHI] for r in scope_children if r.op is UOps.RANGE} - # scope children impact the toposort and END* insertion - scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP} - range_phi = {r:[p for p in scope_children[r] if p.op is UOps.PHI] for r in scope_children if r.op is UOps.RANGE} + queue:List[Tuple[int, UOp]] = [] + def push(u:UOp): + priority = 0 + # prefer ranges that depend on the least number of independent ranges + if u.op is UOps.RANGE and u.arg[1]: + priority += u.arg[0] + for p in range_phi[u]: + priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])]) + # prefer uops that are loop children + else: + priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss]) + heapq.heappush(queue, (priority, u)) - queue:List[Tuple[int, UOp]] = [] - def push(u:UOp): - priority = 0 - # prefer ranges that depend on the least number of independent ranges - if u.op is UOps.RANGE and u.arg[1]: - priority += u.arg[0] - for p in range_phi[u]: - priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])]) - # prefer uops that are loop children - else: - priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss]) - heapq.heappush(queue, (priority, u)) + for u in children: + if in_degree[u] == 0: push(u) - for u in children: + scope_end: Dict[UOp, UOp] = {} + _uops: List[UOp] = [] + while queue: + p,x = heapq.heappop(queue) + if DEBUG >= 7: print(f"{p:5d}",x) + if x in scope_children: scope_end[x] = x + if x.op is UOps.DEFINE_ACC: + idx = min([_uops.index(l) for l in x.src if l.op is UOps.RANGE]) + _uops.insert(idx, x) + else: _uops.append(x) + for u, ss in scope_children.items(): + if x in ss: + ss.remove(x) + if len(ss) == 0: scope_end[u] = x + for u in children[x]: + in_degree[u] -= 1 if in_degree[u] == 0: push(u) - scope_end: Dict[UOp, UOp] = {} - _uops: List[UOp] = [] - while queue: - p,x = heapq.heappop(queue) - if DEBUG >= 7: print(f"{p:5d}",x) - if x in scope_children: scope_end[x] = x - if x.op is UOps.DEFINE_ACC: - idx = min([_uops.index(l) for l in x.src if l.op is UOps.RANGE]) - _uops.insert(idx, x) - else: _uops.append(x) - for u, ss in scope_children.items(): - if x in ss: - ss.remove(x) - if len(ss) == 0: scope_end[u] = x - for u in children[x]: - in_degree[u] -= 1 - if in_degree[u] == 0: push(u) + # end scopes in toposort order + for u, x in scope_end.items(): _uops.insert(_uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], None, (u,))) - # end scopes in toposort order - for u, x in scope_end.items(): _uops.insert(_uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], None, (u,))) + # sanity checks (NOTE: these can cause things to be skipped in BEAM) + if not skip_check: + bad_ops = dedup([x.op for x in _uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE}]) + try: + type_verify(_uops) + assert _uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {_uops[-1]}" + assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}" + # TODO: this should be enabled, and the valid clause should be removed + # NOTE: multiple identical stores to DEFINE_LOCAL is okay + assert len(all_stores := [x.src[0:2]+x.src[3:] for x in _uops if x.op is UOps.STORE and x.src[0].op is not UOps.DEFINE_LOCAL]) \ + == len(dedup(all_stores)), "repeated stores in uops" + except AssertionError as e: + print_uops(_uops) + if not CI: + from tinygrad.engine.graph import graph_uops + graph_uops(_uops) + raise e - # sanity checks (NOTE: these can cause things to be skipped in BEAM) - if not skip_check: - bad_ops = dedup([x.op for x in _uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE}]) - try: - type_verify(_uops) - assert _uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {_uops[-1]}" - assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}" - # TODO: this should be enabled, and the valid clause should be removed - # NOTE: multiple identical stores to DEFINE_LOCAL is okay - assert len(all_stores := [x.src[0:2]+x.src[3:] for x in _uops if x.op is UOps.STORE and x.src[0].op is not UOps.DEFINE_LOCAL]) \ - == len(dedup(all_stores)), "repeated stores in uops" - except AssertionError as e: - print_uops(_uops) - if not CI: - from tinygrad.engine.graph import graph_uops - graph_uops(_uops) - raise e - - # strip the SINK - return _uops[:-1] + # strip the SINK + return _uops[:-1]