split full_graph_rewrite and linearize_uop [run_process_replay] (#6215)

* split full_graph_rewrite and linearize_uop

* fix tests

* graph rewrite in test uops

* add types
This commit is contained in:
George Hotz 2024-08-20 20:12:33 -07:00 committed by GitHub
parent 9faf205601
commit 16f420f7a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 48 additions and 41 deletions

View File

@ -148,7 +148,7 @@ class TestIndexing(unittest.TestCase):
def test_index_mnist_opt(self): self.test_index_mnist(0) def test_index_mnist_opt(self): self.test_index_mnist(0)
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason") @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_llama_embedding(self, noopt=1, op_limit=100): def test_llama_embedding(self, noopt=1, op_limit=65536):
# llama3 is 128256 # llama3 is 128256
vocab_size, embed_size = (10, 3) if CI else (32000, 4096) vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
emb = nn.Embedding(vocab_size, embed_size) emb = nn.Embedding(vocab_size, embed_size)

View File

@ -1,10 +1,11 @@
from typing import List
import unittest import unittest
from test.helpers import TestUOps from test.helpers import TestUOps
from tinygrad import dtypes, Variable from tinygrad import dtypes, Variable
from tinygrad.dtype import PtrDType from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps, UOps, UOp, NOp, PatternMatcher from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps, UOps, UOp, NOp, PatternMatcher
from tinygrad.codegen.uopgraph import linearize_uop, graph_rewrite, expander, reducer, constant_folder, float4_folding from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, constant_folder, float4_folding
simple_pm = PatternMatcher([ simple_pm = PatternMatcher([
(NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)), (NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
@ -13,6 +14,8 @@ simple_pm = PatternMatcher([
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + x.const(c1.arg+c2.arg)), ((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + x.const(c1.arg+c2.arg)),
]) ])
def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u)))
class TestGraphRewrite(unittest.TestCase): class TestGraphRewrite(unittest.TestCase):
def test_dedup(self): def test_dedup(self):
v1 = UOp(UOps.DEFINE_VAR, dtypes.float) v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
@ -94,7 +97,7 @@ class TestUOpGraph(TestUOps):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD) out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
uops = linearize_uop([out]) uops = to_uops_list([out])
self.assertEqual(len(uops), 1) self.assertEqual(len(uops), 1)
out = uops[-1] out = uops[-1]
self.assertEqual(out.op, UOps.CONST) self.assertEqual(out.op, UOps.CONST)
@ -106,7 +109,7 @@ class TestUOpGraph(TestUOps):
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE) vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE) out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
uops = linearize_uop([out]) uops = to_uops_list([out])
self.assertEqual(len(uops), 1) self.assertEqual(len(uops), 1)
out = uops[-1] out = uops[-1]
self.assertEqual(out.op, UOps.CONST) self.assertEqual(out.op, UOps.CONST)
@ -117,7 +120,7 @@ class TestUOpGraph(TestUOps):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE) out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
uops = linearize_uop([out]) uops = to_uops_list([out])
self.assertEqual(len(uops), 1) self.assertEqual(len(uops), 1)
out = uops[-1] out = uops[-1]
self.assertEqual(out.op, UOps.CONST) self.assertEqual(out.op, UOps.CONST)
@ -126,7 +129,7 @@ class TestUOpGraph(TestUOps):
def test_const_cast(self): def test_const_cast(self):
bf = UOp(UOps.CONST, dtypes.bool, arg=False) bf = UOp(UOps.CONST, dtypes.bool, arg=False)
out = UOp(UOps.CAST, dtypes.int, (bf,)) out = UOp(UOps.CAST, dtypes.int, (bf,))
uops = linearize_uop([out]) uops = to_uops_list([out])
self.assertEqual(len(uops), 1) self.assertEqual(len(uops), 1)
out = uops[-1] out = uops[-1]
self.assertEqual(out.op, UOps.CONST) self.assertEqual(out.op, UOps.CONST)
@ -140,7 +143,7 @@ class TestUOpGraph(TestUOps):
x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0) x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT) alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
out = UOp(UOps.STORE, None, (d0, idx, alu)) out = UOp(UOps.STORE, None, (d0, idx, alu))
uops = linearize_uop([out]) uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0) self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)
def test_gep_vec_fold(self): def test_gep_vec_fold(self):
@ -151,7 +154,7 @@ class TestUOpGraph(TestUOps):
def _test_vec(geps, count=4): def _test_vec(geps, count=4):
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps) vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
out = UOp(UOps.STORE, None, (d0, idx, vec)) out = UOp(UOps.STORE, None, (d0, idx, vec))
uops = linearize_uop([out]) uops = to_uops_list([out])
if DEBUG >= 4: if DEBUG >= 4:
from tinygrad import Device from tinygrad import Device
print(Device[Device.DEFAULT].renderer.render("test", uops)) print(Device[Device.DEFAULT].renderer.render("test", uops))
@ -186,8 +189,7 @@ class TestUOpGraph(TestUOps):
for vec_size in [2, 4, 8]: for vec_size in [2, 4, 8]:
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)] consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts)) vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)] uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)])
uops = linearize_uop(geps)
for uop, const in zip(uops, consts): for uop, const in zip(uops, consts):
self.assert_equiv_uops(uop, const) self.assert_equiv_uops(uop, const)
@ -197,7 +199,7 @@ class TestUOpGraph(TestUOps):
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = linearize_uop([wmma]) uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[0], acc) self.assert_equiv_uops(uops[0], acc)
self.assertEqual(len(uops), 1) self.assertEqual(len(uops), 1)
@ -206,7 +208,7 @@ class TestUOpGraph(TestUOps):
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = linearize_uop([wmma]) uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[0], acc) self.assert_equiv_uops(uops[0], acc)
self.assertEqual(len(uops), 1) self.assertEqual(len(uops), 1)
@ -218,7 +220,7 @@ class TestUOpGraph(TestUOps):
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = linearize_uop([wmma]) uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma) self.assert_equiv_uops(uops[-1], wmma)
for i in [4, 8]: for i in [4, 8]:
@ -228,7 +230,7 @@ class TestUOpGraph(TestUOps):
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2))) tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = linearize_uop([wmma]) uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma) self.assert_equiv_uops(uops[-1], wmma)
for i in [2, 4, 8]: for i in [2, 4, 8]:
@ -237,7 +239,7 @@ class TestUOpGraph(TestUOps):
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = linearize_uop([wmma]) uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma) self.assert_equiv_uops(uops[-1], wmma)
for i in [2, 4, 8]: for i in [2, 4, 8]:
@ -246,7 +248,7 @@ class TestUOpGraph(TestUOps):
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i))) tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = linearize_uop([wmma]) uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma) self.assert_equiv_uops(uops[-1], wmma)
def test_cast_alu_fold(self): def test_cast_alu_fold(self):
@ -256,7 +258,7 @@ class TestUOpGraph(TestUOps):
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx)) ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
alu = ld.lt(1).cast(dtypes.bool) alu = ld.lt(1).cast(dtypes.bool)
out = UOp(UOps.STORE, None, (d0, idx, alu)) out = UOp(UOps.STORE, None, (d0, idx, alu))
uops = linearize_uop([out]) uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0) self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)
def test_double_cast_fold(self): def test_double_cast_fold(self):
@ -266,7 +268,7 @@ class TestUOpGraph(TestUOps):
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx)) ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
alu = ld.cast(dtypes.float).cast(dtypes.float) alu = ld.cast(dtypes.float).cast(dtypes.float)
out = UOp(UOps.STORE, None, (d0, idx, alu)) out = UOp(UOps.STORE, None, (d0, idx, alu))
uops = linearize_uop([out]) uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1) self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
def test_depth_2_const_fold(self): def test_depth_2_const_fold(self):
@ -275,7 +277,7 @@ class TestUOpGraph(TestUOps):
c4 = UOp(UOps.CONST, dtypes.int, arg=4) c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD) vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD) out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
uops = linearize_uop([out]) uops = to_uops_list([out])
self.assertEqual(len(uops), 5) self.assertEqual(len(uops), 5)
out = uops[-1] out = uops[-1]
self.assertEqual(out.op, UOps.ALU) self.assertEqual(out.op, UOps.ALU)
@ -290,7 +292,7 @@ class TestUOpGraph(TestUOps):
idx = UOp.const(dtypes.int, 0) idx = UOp.const(dtypes.int, 0)
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False))) ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True))) ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
uops = linearize_uop([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))]) uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value # ld0 becomes the invalid value
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2)) self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
@ -305,7 +307,7 @@ class TestUOpGraph(TestUOps):
barrier = UOp(UOps.BARRIER, None, (st, )) barrier = UOp(UOps.BARRIER, None, (st, ))
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier)) ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier))
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier)) ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier))
uops = linearize_uop([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))]) uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value # ld0 becomes the invalid value
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2)) self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
@ -319,7 +321,7 @@ class TestUOpGraph(TestUOps):
val = UOp.const(dtypes.int, 42) val = UOp.const(dtypes.int, 42)
st0 = UOp(UOps.STORE, None, (glbl, idx0, val, UOp.const(dtypes.bool, False))) st0 = UOp(UOps.STORE, None, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
st1 = UOp(UOps.STORE, None, (glbl, idx1, val, UOp.const(dtypes.bool, True))) st1 = UOp(UOps.STORE, None, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
uops = linearize_uop([st0, st1]) uops = to_uops_list([st0, st1])
# only the second store happens # only the second store happens
self.assertEqual(len(uops), 4) self.assertEqual(len(uops), 4)
self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val)) self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
@ -328,7 +330,7 @@ class TestUOpGraph(TestUOps):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
idx = UOp.const(dtypes.int, 0) idx = UOp.const(dtypes.int, 0)
bad_gate = UOp.const(dtypes.int, 1) bad_gate = UOp.const(dtypes.int, 1)
with self.assertRaises(AssertionError): linearize_uop([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))]) with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
def test_switched_range_order(self): def test_switched_range_order(self):
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@ -339,7 +341,7 @@ class TestUOpGraph(TestUOps):
r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False)) r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False))
alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL) alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
store = UOp(UOps.STORE, None, (glbl, alu, cf)) store = UOp(UOps.STORE, None, (glbl, alu, cf))
uops = linearize_uop([store]) uops = to_uops_list([store])
ranges = [x for x in uops if x.op is UOps.RANGE] ranges = [x for x in uops if x.op is UOps.RANGE]
endranges = [x for x in uops if x.op is UOps.ENDRANGE] endranges = [x for x in uops if x.op is UOps.ENDRANGE]
# ranges are closed in the right order # ranges are closed in the right order

View File

@ -9,11 +9,13 @@ from tinygrad.ops import UOps, NOp, UOp, UnaryOps, BinaryOps, TernaryOps, Reduce
from tinygrad.renderer import Program from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.uopgraph import linearize_uop from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
def _uops_to_prg(uops_list): def _uops_to_prg(uops_list):
uops = linearize_uop(uops_list, opts=Device[Device.DEFAULT].renderer) uops = linearize_uop(full_graph_rewrite(UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer))
src = Device[Device.DEFAULT].renderer.render("test", uops) src = Device[Device.DEFAULT].renderer.render("test", uops)
has_local = Device[Device.DEFAULT].renderer.has_local has_local = Device[Device.DEFAULT].renderer.has_local
return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops,
@ -255,7 +257,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
val = UOp.const(dtypes.float, 42.0) val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1)) gate = gidx0.lt(UOp.const(dtypes.int, 1))
store = UOp(UOps.STORE, None, (gmem, idx, val, gate)) store = UOp(UOps.STORE, None, (gmem, idx, val, gate))
uops = linearize_uop([store]) uops = to_uops_list([store])
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops)) if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.op is UOps.IF) if_uop = next(u for u in uops if u.op is UOps.IF)
endif = next(u for u in uops if u.op is UOps.ENDIF) endif = next(u for u in uops if u.op is UOps.ENDIF)
@ -334,7 +336,7 @@ class TestAssembly(unittest.TestCase):
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1)) l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL) a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL) a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
uops = linearize_uop([a1,a2], opts=Device[Device.DEFAULT].renderer) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops) Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].arg, BinaryOps.SHL) self.assertEqual(uops[-1].arg, BinaryOps.SHL)
self.assertEqual(uops[-2].arg, BinaryOps.MUL) self.assertEqual(uops[-2].arg, BinaryOps.MUL)
@ -346,7 +348,7 @@ class TestAssembly(unittest.TestCase):
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1)) l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV) a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV) a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
uops = linearize_uop([a1,a2], opts=Device[Device.DEFAULT].renderer) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops) Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].arg, BinaryOps.SHR) self.assertEqual(uops[-1].arg, BinaryOps.SHR)
self.assertEqual(uops[-2].arg, BinaryOps.IDIV) self.assertEqual(uops[-2].arg, BinaryOps.IDIV)
@ -385,7 +387,7 @@ class TestIndexingOrdering(unittest.TestCase):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = linearize_uop([st1, st0], skip_check=True) uops = to_uops_list([st1, st0], skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE] stores = [st for st in uops if st.op is UOps.STORE]
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}" assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
@ -397,7 +399,7 @@ class TestIndexingOrdering(unittest.TestCase):
st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1_1 = UOp(UOps.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) st1_1 = UOp(UOps.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = linearize_uop([st0_0, st1_0, st0_1, st1_1], skip_check=True) uops = to_uops_list([st0_0, st1_0, st0_1, st1_1], skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE] stores = [st for st in uops if st.op is UOps.STORE]
print("\n".join(map(str, stores))) print("\n".join(map(str, stores)))
# buf0 stores come first # buf0 stores come first
@ -413,7 +415,7 @@ class TestIndexingOrdering(unittest.TestCase):
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4)) gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = linearize_uop([st1, st0], skip_check=True) uops = linearize_uop(UOp.sink(st1, st0), skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE] stores = [st for st in uops if st.op is UOps.STORE]
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}" assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"

View File

@ -105,7 +105,7 @@ class TestUOpsStats(unittest.TestCase):
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3) u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL) u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD) u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
uops = linearize_uop([u5]) uops = linearize_uop(u5.sink())
globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple()) globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1) o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
@ -114,7 +114,7 @@ class TestUOpsStats(unittest.TestCase):
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2)) u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3) u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC) u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
uops_fma = linearize_uop([u4]) uops_fma = linearize_uop(u4.sink())
self.assertEqual(flops_mem(uops), flops_mem(uops_fma)) self.assertEqual(flops_mem(uops), flops_mem(uops_fma))

View File

@ -9,14 +9,14 @@ from typing import Tuple
from tinygrad.helpers import DEBUG from tinygrad.helpers import DEBUG
from tinygrad.dtype import dtypes, PtrDType, ConstType from tinygrad.dtype import dtypes, PtrDType, ConstType
from tinygrad.codegen.uopgraph import linearize_uop from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
import functools import functools
def render(self) -> Tuple[str, ConstType, ConstType]: def render(self) -> Tuple[str, ConstType, ConstType]:
# NOTE: we need STORE so the ALU op has children # NOTE: we need STORE so the ALU op has children
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0) glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0)
uops = linearize_uop([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))]) uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
if DEBUG>=5: print_uops(uops) if DEBUG>=5: print_uops(uops)
from tinygrad.renderer.cstyle import CStyleLanguage from tinygrad.renderer.cstyle import CStyleLanguage
class TestRenderer(CStyleLanguage): class TestRenderer(CStyleLanguage):

View File

@ -13,7 +13,7 @@ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DE
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.view import strides_for_shape from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.uopgraph import linearize_uop from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
from tinygrad.codegen.lowerer import ast_to_uop from tinygrad.codegen.lowerer import ast_to_uop
from enum import Enum, auto from enum import Enum, auto
@ -745,7 +745,7 @@ class Kernel:
print(self.applied_opts) print(self.applied_opts)
verify_ast(modified_ast) verify_ast(modified_ast)
self.uops:List[UOp] = linearize_uop(ast_to_uop(modified_ast, self.opts), self.opts) self.uops:List[UOp] = linearize_uop(full_graph_rewrite(ast_to_uop(modified_ast, self.opts), self.opts))
if DEBUG >= 5: print_uops(self.uops) if DEBUG >= 5: print_uops(self.uops)
if getenv("GRAPHUOPS"): if getenv("GRAPHUOPS"):
from tinygrad.engine.graph import graph_uops from tinygrad.engine.graph import graph_uops

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from typing import Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict, Callable from typing import Optional, Tuple, Dict, List, Set, cast, TYPE_CHECKING, Any, DefaultDict, Callable
import functools, itertools, heapq, math, operator import functools, itertools, heapq, math, operator
from collections import defaultdict from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
@ -512,9 +512,8 @@ def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[U
return srcs[u] return srcs[u]
linearize_cnt = 0 linearize_cnt = 0
def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, skip_check=False) -> List[UOp]: def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
global linearize_cnt, acc_number global linearize_cnt, acc_number
sink: UOp = sink_in if isinstance(sink_in, UOp) else UOp(UOps.SINK, None, tuple(sink_in))
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}" assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
folder = constant_folder + transcendental_folding(tuple() if TRANSCENDENTAL >= 2 or opts is None else tuple(opts.code_for_op.keys())) folder = constant_folder + transcendental_folding(tuple() if TRANSCENDENTAL >= 2 or opts is None else tuple(opts.code_for_op.keys()))
@ -533,7 +532,10 @@ def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, s
# for PTX only # for PTX only
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher) if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher)
return sink
def linearize_uop(sink:UOp, skip_check:bool=False) -> List[UOp]:
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
# filter nodes that don't link to a sink # filter nodes that don't link to a sink
# BFS toposort # BFS toposort
children: Dict[UOp, List[UOp]] = {} children: Dict[UOp, List[UOp]] = {}

View File

@ -123,6 +123,7 @@ class UOp:
ret = self.src[0 if self.op is UOps.CONST else 1] ret = self.src[0 if self.op is UOps.CONST else 1]
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}" assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
return ret.arg return ret.arg
def sink(self, *srcs): return UOp(UOps.SINK, None, (self,)+srcs)
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,)) def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,)) def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))