mirror of https://github.com/commaai/tinygrad.git
UOpGraph -> linearize_uop [run_process_replay] (#6119)
parent 7cae152aa2
commit 912f01ed4b
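
In short, this commit replaces the UOpGraph class in tinygrad/codegen/uopgraph.py with a standalone linearize_uop function and updates every call site (the uop fuzzer, the tests, and Kernel.linearize) to match. Below is a minimal sketch of the call-site migration, using only names that appear in the diff; it assumes a tinygrad checkout at this revision:

# old API (removed here): build a UOpGraph, then call .linearize()
#   uops = UOpGraph([out]).linearize()
# new API (added here): one function call returns the linearized List[UOp]
from tinygrad import dtypes
from tinygrad.ops import UOps, UOp, BinaryOps
from tinygrad.codegen.uopgraph import linearize_uop

c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
uops = linearize_uop([out])  # constant folding leaves a single CONST uop, as the first TestUOpGraph hunk below asserts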
@@ -3,8 +3,7 @@ from collections import defaultdict
 import numpy as np
 from dataclasses import replace
 from typing import DefaultDict, Dict, List, Tuple
-from tinygrad.ops import END_FOR_UOP, UOp
-from tinygrad.codegen.uopgraph import UOpGraph
+from tinygrad.ops import END_FOR_UOP, UOp, print_uops
 from tinygrad.device import Buffer, Device
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import DEBUG, colored
@@ -49,10 +48,9 @@ class UOpsFuzzerRunner(CompiledRunner):

 for i, path in enumerate(fuzz_paths):
 # setup prg
-uops = UOpGraph([])
-uops._uops = list(path)
-if DEBUG >= 5: uops.print()
-self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops.uops), uops=uops.uops)
+uops = list(path)
+if DEBUG >= 5: print_uops(uops)
+self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops)
 if DEBUG >= 4: print(self.p.src)
 self.lib = Device[self.p.dname].compiler.compile_cached(self.p.src)
 self.clprg = Device[self.p.dname].runtime(name, self.lib)

@@ -4,7 +4,7 @@ from tinygrad import dtypes, Variable
 from tinygrad.dtype import PtrDType
 from tinygrad.helpers import DEBUG
 from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps, UOps, UOp, NOp, PatternMatcher
-from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite, expander, reducer, constant_folder, float4_folding
+from tinygrad.codegen.uopgraph import linearize_uop, graph_rewrite, expander, reducer, constant_folder, float4_folding

 simple_pm = PatternMatcher([
 (NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
@@ -94,7 +94,7 @@ class TestUOpGraph(TestUOps):
 c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
 c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
 out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
-uops = UOpGraph([out]).linearize()
+uops = linearize_uop([out])
 self.assertEqual(len(uops), 1)
 out = uops[-1]
 self.assertEqual(out.op, UOps.CONST)
@@ -106,7 +106,7 @@ class TestUOpGraph(TestUOps):
 vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
 c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
 out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
-uops = UOpGraph([out]).linearize()
+uops = linearize_uop([out])
 self.assertEqual(len(uops), 1)
 out = uops[-1]
 self.assertEqual(out.op, UOps.CONST)
@@ -117,7 +117,7 @@ class TestUOpGraph(TestUOps):
 c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
 c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
 out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
-uops = UOpGraph([out]).linearize()
+uops = linearize_uop([out])
 self.assertEqual(len(uops), 1)
 out = uops[-1]
 self.assertEqual(out.op, UOps.CONST)
@@ -126,7 +126,7 @@ class TestUOpGraph(TestUOps):
 def test_const_cast(self):
 bf = UOp(UOps.CONST, dtypes.bool, arg=False)
 out = UOp(UOps.CAST, dtypes.int, (bf,))
-uops = UOpGraph([out]).linearize()
+uops = linearize_uop([out])
 self.assertEqual(len(uops), 1)
 out = uops[-1]
 self.assertEqual(out.op, UOps.CONST)
@@ -140,7 +140,7 @@ class TestUOpGraph(TestUOps):
 x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
 alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
 out = UOp(UOps.STORE, None, (d0, idx, alu))
-uops = UOpGraph([out]).linearize()
+uops = linearize_uop([out])
 self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)

 def test_gep_vec_fold(self):
@@ -151,7 +151,7 @@ class TestUOpGraph(TestUOps):
 def _test_vec(geps, count=4):
 vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
 out = UOp(UOps.STORE, None, (d0, idx, vec))
-uops = UOpGraph([out]).linearize()
+uops = linearize_uop([out])
 if DEBUG >= 4:
 from tinygrad import Device
 print(Device[Device.DEFAULT].renderer.render("test", uops))
@@ -187,7 +187,7 @@ class TestUOpGraph(TestUOps):
 consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
 vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
 geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)]
-uops = UOpGraph(geps).linearize()
+uops = linearize_uop(geps)
 for uop, const in zip(uops, consts):
 self.assert_equiv_uops(uop, const)

@@ -197,7 +197,7 @@ class TestUOpGraph(TestUOps):
 var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
 acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
 wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
-uops = UOpGraph([wmma]).linearize()
+uops = linearize_uop([wmma])
 self.assert_equiv_uops(uops[0], acc)
 self.assertEqual(len(uops), 1)

@@ -206,7 +206,7 @@ class TestUOpGraph(TestUOps):
 vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
 acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
 wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
-uops = UOpGraph([wmma]).linearize()
+uops = linearize_uop([wmma])
 self.assert_equiv_uops(uops[0], acc)
 self.assertEqual(len(uops), 1)

@@ -218,7 +218,7 @@ class TestUOpGraph(TestUOps):
 var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
 acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
 wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
-uops = UOpGraph([wmma]).linearize()
+uops = linearize_uop([wmma])
 self.assert_equiv_uops(uops[-1], wmma)

 for i in [4, 8]:
@@ -228,7 +228,7 @@ class TestUOpGraph(TestUOps):
 tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
 acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
 wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
-uops = UOpGraph([wmma]).linearize()
+uops = linearize_uop([wmma])
 self.assert_equiv_uops(uops[-1], wmma)

 for i in [2, 4, 8]:
@@ -237,7 +237,7 @@ class TestUOpGraph(TestUOps):
 var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
 acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
 wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
-uops = UOpGraph([wmma]).linearize()
+uops = linearize_uop([wmma])
 self.assert_equiv_uops(uops[-1], wmma)

 for i in [2, 4, 8]:
@@ -246,7 +246,7 @@ class TestUOpGraph(TestUOps):
 tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
 acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
 wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
-uops = UOpGraph([wmma]).linearize()
+uops = linearize_uop([wmma])
 self.assert_equiv_uops(uops[-1], wmma)

 def test_cast_alu_fold(self):
@@ -256,7 +256,7 @@ class TestUOpGraph(TestUOps):
 ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
 alu = ld.lt(1).cast(dtypes.bool)
 out = UOp(UOps.STORE, None, (d0, idx, alu))
-uops = UOpGraph([out]).linearize()
+uops = linearize_uop([out])
 self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)

 def test_double_cast_fold(self):
@@ -266,7 +266,7 @@ class TestUOpGraph(TestUOps):
 ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
 alu = ld.cast(dtypes.float).cast(dtypes.float)
 out = UOp(UOps.STORE, None, (d0, idx, alu))
-uops = UOpGraph([out]).linearize()
+uops = linearize_uop([out])
 self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)

 def test_depth_2_const_fold(self):
@@ -275,7 +275,7 @@ class TestUOpGraph(TestUOps):
 c4 = UOp(UOps.CONST, dtypes.int, arg=4)
 vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
 out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
-uops = UOpGraph([out]).linearize()
+uops = linearize_uop([out])
 self.assertEqual(len(uops), 5)
 out = uops[-1]
 self.assertEqual(out.op, UOps.ALU)
@@ -290,7 +290,7 @@ class TestUOpGraph(TestUOps):
 idx = UOp.const(dtypes.int, 0)
 ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
 ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
-uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))]).linearize()
+uops = linearize_uop([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
 ld0, ld1 = uops[-1].src[2].src
 # ld0 becomes the invalid value
 self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
@@ -305,7 +305,7 @@ class TestUOpGraph(TestUOps):
 barrier = UOp(UOps.BARRIER, None, (st, ))
 ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier))
 ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier))
-uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))]).linearize()
+uops = linearize_uop([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
 ld0, ld1 = uops[-1].src[2].src
 # ld0 becomes the invalid value
 self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
@@ -319,7 +319,7 @@ class TestUOpGraph(TestUOps):
 val = UOp.const(dtypes.int, 42)
 st0 = UOp(UOps.STORE, None, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
 st1 = UOp(UOps.STORE, None, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
-uops = UOpGraph([st0, st1]).linearize()
+uops = linearize_uop([st0, st1])
 # only the second store happens
 self.assertEqual(len(uops), 4)
 self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
@@ -328,8 +328,7 @@ class TestUOpGraph(TestUOps):
 glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
 idx = UOp.const(dtypes.int, 0)
 bad_gate = UOp.const(dtypes.int, 1)
-uops = UOpGraph([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
-with self.assertRaises(AssertionError): uops.linearize()
+with self.assertRaises(AssertionError): linearize_uop([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])

 def test_switched_range_order(self):
 glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@@ -340,7 +339,7 @@ class TestUOpGraph(TestUOps):
 r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False))
 alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
 store = UOp(UOps.STORE, None, (glbl, alu, cf))
-uops = UOpGraph([store]).linearize()
+uops = linearize_uop([store])
 ranges = [x for x in uops if x.op is UOps.RANGE]
 endranges = [x for x in uops if x.op is UOps.ENDRANGE]
 # ranges are closed in the right order

@@ -9,11 +9,11 @@ from tinygrad.ops import UOps, NOp, UOp, UnaryOps, BinaryOps, TernaryOps, Reduce
 from tinygrad.renderer import Program
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
-from tinygrad.codegen.uopgraph import UOpGraph
+from tinygrad.codegen.uopgraph import linearize_uop
 from test.helpers import is_dtype_supported, TestUOps as TestEqUOps

 def _uops_to_prg(uops_list):
-uops = UOpGraph(uops_list).linearize(Device[Device.DEFAULT].renderer.extra_matcher)
+uops = linearize_uop(uops_list, extra_pm=Device[Device.DEFAULT].renderer.extra_matcher)
 src = Device[Device.DEFAULT].renderer.render("test", uops)
 has_local = Device[Device.DEFAULT].renderer.has_local
 return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops,
@@ -247,7 +247,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
 val = UOp.const(dtypes.float, 42.0)
 gate = gidx0.lt(UOp.const(dtypes.int, 1))
 store = UOp(UOps.STORE, None, (gmem, idx, val, gate))
-uops = UOpGraph([store])
+uops = linearize_uop([store])
 if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
 if_uop = next(u for u in uops if u.op is UOps.IF)
 endif = next(u for u in uops if u.op is UOps.ENDIF)
@@ -265,7 +265,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
 val = UOp.const(dtypes.float, 42.0)
 gate = gidx0.lt(UOp.const(dtypes.int, 1))
 stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val)]
-uops = UOpGraph(stores)
+uops = linearize_uop(stores)
 if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
 if_uop = next(u for u in uops if u.op is UOps.IF)
 endif = next(u for u in uops if u.op is UOps.ENDIF)
@@ -284,7 +284,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
 val = UOp.const(dtypes.float, 42.0)
 gate = gidx0.lt(UOp.const(dtypes.int, 1))
 stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val, gate)]
-uops = UOpGraph(stores)
+uops = linearize_uop(stores)
 if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
 ifs = [u for u in uops if u.op is UOps.IF]
 endifs = [u for u in uops if u.op is UOps.ENDIF]
@@ -326,7 +326,7 @@ class TestAssembly(unittest.TestCase):
 l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
 a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
 a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
-uops = UOpGraph([a1,a2]).linearize(Device[Device.DEFAULT].renderer.extra_matcher)
+uops = linearize_uop([a1,a2], extra_pm=Device[Device.DEFAULT].renderer.extra_matcher)
 Device[Device.DEFAULT].renderer.render("test", uops)
 self.assertEqual(uops[-1].arg, BinaryOps.SHL)
 self.assertEqual(uops[-2].arg, BinaryOps.MUL)
@@ -338,7 +338,7 @@ class TestAssembly(unittest.TestCase):
 l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
 a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
 a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
-uops = UOpGraph([a1,a2]).linearize(Device[Device.DEFAULT].renderer.extra_matcher)
+uops = linearize_uop([a1,a2], extra_pm=Device[Device.DEFAULT].renderer.extra_matcher)
 Device[Device.DEFAULT].renderer.render("test", uops)
 self.assertEqual(uops[-1].arg, BinaryOps.SHR)
 self.assertEqual(uops[-2].arg, BinaryOps.IDIV)
@@ -377,7 +377,7 @@ class TestIndexingOrdering(unittest.TestCase):
 buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
 st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
 st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
-uops = UOpGraph([st1, st0]).linearize(skip_check=True)
+uops = linearize_uop([st1, st0], skip_check=True)
 stores = [st for st in uops if st.op is UOps.STORE]
 assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"

@@ -389,7 +389,7 @@ class TestIndexingOrdering(unittest.TestCase):
 st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
 st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
 st1_1 = UOp(UOps.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
-uops = UOpGraph([st0_0, st1_0, st0_1, st1_1]).linearize(skip_check=True)
+uops = linearize_uop([st0_0, st1_0, st0_1, st1_1], skip_check=True)
 stores = [st for st in uops if st.op is UOps.STORE]
 print("\n".join(map(str, stores)))
 # buf0 stores come first
@@ -405,7 +405,7 @@ class TestIndexingOrdering(unittest.TestCase):
 gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
 st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
 st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
-uops = UOpGraph([st1, st0]).linearize(skip_check=True)
+uops = linearize_uop([st1, st0], skip_check=True)
 stores = [st for st in uops if st.op is UOps.STORE]
 assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"

@@ -3,7 +3,7 @@ from tinygrad import Tensor
 from tinygrad.helpers import getenv
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import lower_schedule_item
-from tinygrad.codegen.uopgraph import UOpGraph
+from tinygrad.codegen.uopgraph import linearize_uop
 from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
@@ -105,7 +105,7 @@ class TestUOpsStats(unittest.TestCase):
 u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
 u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
 u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
-uops = UOpGraph([u5])
+uops = linearize_uop([u5])

 globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
 o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
@@ -114,9 +114,9 @@ class TestUOpsStats(unittest.TestCase):
 u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
 u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
 u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
-uops_fma = UOpGraph([u4])
+uops_fma = linearize_uop([u4])

-self.assertEqual(flops_mem(uops.linearize()), flops_mem(uops_fma.linearize()))
+self.assertEqual(flops_mem(uops), flops_mem(uops_fma))

 N = 100
 @unittest.skipIf(getenv("PTX"), "wrong in PTX") # maybe?

@@ -9,15 +9,14 @@ from typing import Tuple

 from tinygrad.helpers import DEBUG
 from tinygrad.dtype import dtypes, PtrDType, ConstType
-from tinygrad.codegen.uopgraph import UOpGraph
+from tinygrad.codegen.uopgraph import linearize_uop
 from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
 import functools

 def render(self) -> Tuple[str, ConstType, ConstType]:
 # NOTE: we need STORE so the ALU op has children
 glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0)
-graph = UOpGraph([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))])
-uops = graph.linearize()
+uops = linearize_uop([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))])
 if DEBUG>=5: print_uops(uops)
 from tinygrad.renderer.cstyle import CStyleLanguage
 class TestRenderer(CStyleLanguage):

@@ -13,7 +13,7 @@ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DE
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.view import strides_for_shape
-from tinygrad.codegen.uopgraph import UOpGraph
+from tinygrad.codegen.uopgraph import linearize_uop
 from tinygrad.codegen.lowerer import ast_to_uop
 from enum import Enum, auto

@@ -737,8 +737,7 @@ class Kernel:
 print(self.applied_opts)
 verify_ast(modified_ast)

-# generate the UOpGraph
-self.uops:List[UOp] = UOpGraph(ast_to_uop(modified_ast, self.opts), self.opts).linearize(self.opts.extra_matcher)
+self.uops:List[UOp] = linearize_uop(ast_to_uop(modified_ast, self.opts), self.opts, extra_pm=self.opts.extra_matcher)
 if DEBUG >= 5: print_uops(self.uops)
 if getenv("GRAPHUOPS"):
 from tinygrad.engine.graph import graph_uops

@@ -517,104 +517,100 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
     return found
   return __inner_rewrite(sink)

-class UOpGraph:
-  def __init__(self, sink:Union[UOp, List[UOp]], opts:Optional[Renderer]=None):
-    self.sink: UOp = sink if isinstance(sink, UOp) else UOp(UOps.SINK, None, tuple(sink))
-    assert self.sink.op is UOps.SINK, f"sink isn't sink, it's {self.sink.op}"
-    self.opts = opts
-    self.folder = constant_folder + transcendental_folding({} if TRANSCENDENTAL >= 2 or opts is None else opts.code_for_op.keys())
-
-  cnt = 0
-  def linearize(self, extra_pm:Optional[PatternMatcher]=None, skip_check=False) -> List[UOp]:
-    global acc_number
-    acc_number = 0
-
-    # do graph rewrite
-    sink = graph_rewrite(self.sink, self.folder)
-
-    # rewrite pyint to int32
-    sink = graph_rewrite(sink, PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"),
-                                                lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))]))
-
-    # expand
-    UOpGraph.cnt += 1
-    if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0):
-      sink = graph_rewrite(sink, self.folder+expander+float4_folding if self.opts is not None and self.opts.supports_float4 else self.folder+expander)
-      sink = graph_rewrite(sink, self.folder+expander+reducer)
-
-    # for PTX only
-    if extra_pm: sink = graph_rewrite(sink, self.folder+extra_pm)
-
-    # filter nodes that don't link to a sink
-    # BFS toposort
-    children: Dict[UOp, List[UOp]] = {}
-    range_srcs: Dict[UOp, Dict[UOp, None]] = {}
-    in_degree: Dict[UOp, int] = {}
-    get_children_dfs(sink, children, range_srcs, in_degree)
-
-    @functools.lru_cache(None)
-    def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
-      if x.op is UOps.SINK: return set()
-      return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
-
-    # scope children impact the toposort and END* insertion
-    scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
-    range_phi = {r:[p for p in scope_children[r] if p.op is UOps.PHI] for r in scope_children if r.op is UOps.RANGE}
-
-    queue:List[Tuple[int, UOp]] = []
-    def push(u:UOp):
-      priority = 0
-      # prefer ranges that depend on the least number of independent ranges
-      if u.op is UOps.RANGE and u.arg[1]:
-        priority += u.arg[0]
-        for p in range_phi[u]:
-          priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
-      # prefer uops that are loop children
-      else:
-        priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss])
-      heapq.heappush(queue, (priority, u))
-
-    for u in children:
-      if in_degree[u] == 0: push(u)
-
-    scope_end: Dict[UOp, UOp] = {}
-    _uops: List[UOp] = []
-    while queue:
-      p,x = heapq.heappop(queue)
-      if DEBUG >= 7: print(f"{p:5d}",x)
-      if x in scope_children: scope_end[x] = x
-      if x.op is UOps.DEFINE_ACC:
-        idx = min([_uops.index(l) for l in x.src if l.op is UOps.RANGE])
-        _uops.insert(idx, x)
-      else: _uops.append(x)
-      for u, ss in scope_children.items():
-        if x in ss:
-          ss.remove(x)
-          if len(ss) == 0: scope_end[u] = x
-      for u in children[x]:
-        in_degree[u] -= 1
-        if in_degree[u] == 0: push(u)
-
-    # end scopes in toposort order
-    for u, x in scope_end.items(): _uops.insert(_uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], None, (u,)))
-
-    # sanity checks (NOTE: these can cause things to be skipped in BEAM)
-    if not skip_check:
-      bad_ops = dedup([x.op for x in _uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE}])
-      try:
-        type_verify(_uops)
-        assert _uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {_uops[-1]}"
-        assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}"
-        # TODO: this should be enabled, and the valid clause should be removed
-        # NOTE: multiple identical stores to DEFINE_LOCAL is okay
-        assert len(all_stores := [x.src[0:2]+x.src[3:] for x in _uops if x.op is UOps.STORE and x.src[0].op is not UOps.DEFINE_LOCAL]) \
-          == len(dedup(all_stores)), "repeated stores in uops"
-      except AssertionError as e:
-        print_uops(_uops)
-        if not CI:
-          from tinygrad.engine.graph import graph_uops
-          graph_uops(_uops)
-        raise e
-
-    # strip the SINK
-    return _uops[:-1]
+linearize_cnt = 0
+def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, extra_pm:Optional[PatternMatcher]=None, skip_check=False) -> List[UOp]:
+  global linearize_cnt, acc_number
+  sink: UOp = sink_in if isinstance(sink_in, UOp) else UOp(UOps.SINK, None, tuple(sink_in))
+  assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
+  folder = constant_folder + transcendental_folding({} if TRANSCENDENTAL >= 2 or opts is None else opts.code_for_op.keys())
+
+  # do graph rewrite
+  acc_number = 0
+  sink = graph_rewrite(sink, folder)
+
+  # rewrite pyint to int32
+  sink = graph_rewrite(sink, PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"),
+                                              lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))]))
+
+  # expand
+  linearize_cnt += 1
+  if linearize_cnt != getenv("DEBUG_EXPAND", 0):
+    sink = graph_rewrite(sink, folder+expander+float4_folding if opts is not None and opts.supports_float4 else folder+expander)
+    sink = graph_rewrite(sink, folder+expander+reducer)
+
+  # for PTX only
+  if extra_pm: sink = graph_rewrite(sink, folder+extra_pm)
+
+  # filter nodes that don't link to a sink
+  # BFS toposort
+  children: Dict[UOp, List[UOp]] = {}
+  range_srcs: Dict[UOp, Dict[UOp, None]] = {}
+  in_degree: Dict[UOp, int] = {}
+  get_children_dfs(sink, children, range_srcs, in_degree)
+
+  @functools.lru_cache(None)
+  def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
+    if x.op is UOps.SINK: return set()
+    return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
+
+  # scope children impact the toposort and END* insertion
+  scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
+  range_phi = {r:[p for p in scope_children[r] if p.op is UOps.PHI] for r in scope_children if r.op is UOps.RANGE}
+
+  queue:List[Tuple[int, UOp]] = []
+  def push(u:UOp):
+    priority = 0
+    # prefer ranges that depend on the least number of independent ranges
+    if u.op is UOps.RANGE and u.arg[1]:
+      priority += u.arg[0]
+      for p in range_phi[u]:
+        priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
+    # prefer uops that are loop children
+    else:
+      priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss])
+    heapq.heappush(queue, (priority, u))
+
+  for u in children:
+    if in_degree[u] == 0: push(u)
+
+  scope_end: Dict[UOp, UOp] = {}
+  _uops: List[UOp] = []
+  while queue:
+    p,x = heapq.heappop(queue)
+    if DEBUG >= 7: print(f"{p:5d}",x)
+    if x in scope_children: scope_end[x] = x
+    if x.op is UOps.DEFINE_ACC:
+      idx = min([_uops.index(l) for l in x.src if l.op is UOps.RANGE])
+      _uops.insert(idx, x)
+    else: _uops.append(x)
+    for u, ss in scope_children.items():
+      if x in ss:
+        ss.remove(x)
+        if len(ss) == 0: scope_end[u] = x
+    for u in children[x]:
+      in_degree[u] -= 1
+      if in_degree[u] == 0: push(u)
+
+  # end scopes in toposort order
+  for u, x in scope_end.items(): _uops.insert(_uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], None, (u,)))
+
+  # sanity checks (NOTE: these can cause things to be skipped in BEAM)
+  if not skip_check:
+    bad_ops = dedup([x.op for x in _uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE}])
+    try:
+      type_verify(_uops)
+      assert _uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {_uops[-1]}"
+      assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}"
+      # TODO: this should be enabled, and the valid clause should be removed
+      # NOTE: multiple identical stores to DEFINE_LOCAL is okay
+      assert len(all_stores := [x.src[0:2]+x.src[3:] for x in _uops if x.op is UOps.STORE and x.src[0].op is not UOps.DEFINE_LOCAL]) \
+        == len(dedup(all_stores)), "repeated stores in uops"
+    except AssertionError as e:
+      print_uops(_uops)
+      if not CI:
+        from tinygrad.engine.graph import graph_uops
+        graph_uops(_uops)
+      raise e
+
+  # strip the SINK
+  return _uops[:-1]
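
For reference, the new entry point reads linearize_uop(sink_in, opts=None, extra_pm=None, skip_check=False) and returns the linearized List[UOp] with the trailing SINK stripped, as the hunk above shows. A hedged usage sketch that renders a linearized program with the default device renderer, mirroring the _uops_to_prg helper and the store tests in this diff (assumes a tinygrad checkout at this revision; the generated source depends on the backend):

from tinygrad import Device, dtypes
from tinygrad.dtype import PtrDType
from tinygrad.ops import UOps, UOp
from tinygrad.codegen.uopgraph import linearize_uop

buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)  # global output buffer
store = UOp(UOps.STORE, None, (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 42)))  # buf[0] = 42
uops = linearize_uop([store], extra_pm=Device[Device.DEFAULT].renderer.extra_matcher)
src = Device[Device.DEFAULT].renderer.render("test", uops)
print(src)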