mirror of https://github.com/commaai/tinygrad.git
split full_graph_rewrite and linearize_uop [run_process_replay] (#6215)
* split full_graph_rewrite and linearize_uop * fix tests * graph rewrite in test uops * add types
This commit is contained in:
parent
9faf205601
commit
16f420f7a7
|
@ -148,7 +148,7 @@ class TestIndexing(unittest.TestCase):
|
||||||
def test_index_mnist_opt(self): self.test_index_mnist(0)
|
def test_index_mnist_opt(self): self.test_index_mnist(0)
|
||||||
|
|
||||||
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
|
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
|
||||||
def test_llama_embedding(self, noopt=1, op_limit=100):
|
def test_llama_embedding(self, noopt=1, op_limit=65536):
|
||||||
# llama3 is 128256
|
# llama3 is 128256
|
||||||
vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
|
vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
|
||||||
emb = nn.Embedding(vocab_size, embed_size)
|
emb = nn.Embedding(vocab_size, embed_size)
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
|
from typing import List
|
||||||
import unittest
|
import unittest
|
||||||
from test.helpers import TestUOps
|
from test.helpers import TestUOps
|
||||||
from tinygrad import dtypes, Variable
|
from tinygrad import dtypes, Variable
|
||||||
from tinygrad.dtype import PtrDType
|
from tinygrad.dtype import PtrDType
|
||||||
from tinygrad.helpers import DEBUG
|
from tinygrad.helpers import DEBUG
|
||||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps, UOps, UOp, NOp, PatternMatcher
|
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps, UOps, UOp, NOp, PatternMatcher
|
||||||
from tinygrad.codegen.uopgraph import linearize_uop, graph_rewrite, expander, reducer, constant_folder, float4_folding
|
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, constant_folder, float4_folding
|
||||||
|
|
||||||
simple_pm = PatternMatcher([
|
simple_pm = PatternMatcher([
|
||||||
(NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
(NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
|
||||||
|
@ -13,6 +14,8 @@ simple_pm = PatternMatcher([
|
||||||
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + x.const(c1.arg+c2.arg)),
|
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + x.const(c1.arg+c2.arg)),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u)))
|
||||||
|
|
||||||
class TestGraphRewrite(unittest.TestCase):
|
class TestGraphRewrite(unittest.TestCase):
|
||||||
def test_dedup(self):
|
def test_dedup(self):
|
||||||
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
|
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
|
||||||
|
@ -94,7 +97,7 @@ class TestUOpGraph(TestUOps):
|
||||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||||
out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
|
out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
|
||||||
uops = linearize_uop([out])
|
uops = to_uops_list([out])
|
||||||
self.assertEqual(len(uops), 1)
|
self.assertEqual(len(uops), 1)
|
||||||
out = uops[-1]
|
out = uops[-1]
|
||||||
self.assertEqual(out.op, UOps.CONST)
|
self.assertEqual(out.op, UOps.CONST)
|
||||||
|
@ -106,7 +109,7 @@ class TestUOpGraph(TestUOps):
|
||||||
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
|
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
|
||||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||||
out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
|
out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
|
||||||
uops = linearize_uop([out])
|
uops = to_uops_list([out])
|
||||||
self.assertEqual(len(uops), 1)
|
self.assertEqual(len(uops), 1)
|
||||||
out = uops[-1]
|
out = uops[-1]
|
||||||
self.assertEqual(out.op, UOps.CONST)
|
self.assertEqual(out.op, UOps.CONST)
|
||||||
|
@ -117,7 +120,7 @@ class TestUOpGraph(TestUOps):
|
||||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||||
out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
|
out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
|
||||||
uops = linearize_uop([out])
|
uops = to_uops_list([out])
|
||||||
self.assertEqual(len(uops), 1)
|
self.assertEqual(len(uops), 1)
|
||||||
out = uops[-1]
|
out = uops[-1]
|
||||||
self.assertEqual(out.op, UOps.CONST)
|
self.assertEqual(out.op, UOps.CONST)
|
||||||
|
@ -126,7 +129,7 @@ class TestUOpGraph(TestUOps):
|
||||||
def test_const_cast(self):
|
def test_const_cast(self):
|
||||||
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
|
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||||
out = UOp(UOps.CAST, dtypes.int, (bf,))
|
out = UOp(UOps.CAST, dtypes.int, (bf,))
|
||||||
uops = linearize_uop([out])
|
uops = to_uops_list([out])
|
||||||
self.assertEqual(len(uops), 1)
|
self.assertEqual(len(uops), 1)
|
||||||
out = uops[-1]
|
out = uops[-1]
|
||||||
self.assertEqual(out.op, UOps.CONST)
|
self.assertEqual(out.op, UOps.CONST)
|
||||||
|
@ -140,7 +143,7 @@ class TestUOpGraph(TestUOps):
|
||||||
x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
|
x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
|
||||||
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
|
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
|
||||||
out = UOp(UOps.STORE, None, (d0, idx, alu))
|
out = UOp(UOps.STORE, None, (d0, idx, alu))
|
||||||
uops = linearize_uop([out])
|
uops = to_uops_list([out])
|
||||||
self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)
|
self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)
|
||||||
|
|
||||||
def test_gep_vec_fold(self):
|
def test_gep_vec_fold(self):
|
||||||
|
@ -151,7 +154,7 @@ class TestUOpGraph(TestUOps):
|
||||||
def _test_vec(geps, count=4):
|
def _test_vec(geps, count=4):
|
||||||
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
|
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
|
||||||
out = UOp(UOps.STORE, None, (d0, idx, vec))
|
out = UOp(UOps.STORE, None, (d0, idx, vec))
|
||||||
uops = linearize_uop([out])
|
uops = to_uops_list([out])
|
||||||
if DEBUG >= 4:
|
if DEBUG >= 4:
|
||||||
from tinygrad import Device
|
from tinygrad import Device
|
||||||
print(Device[Device.DEFAULT].renderer.render("test", uops))
|
print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||||
|
@ -186,8 +189,7 @@ class TestUOpGraph(TestUOps):
|
||||||
for vec_size in [2, 4, 8]:
|
for vec_size in [2, 4, 8]:
|
||||||
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
|
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
|
||||||
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
|
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
|
||||||
geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)]
|
uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)])
|
||||||
uops = linearize_uop(geps)
|
|
||||||
for uop, const in zip(uops, consts):
|
for uop, const in zip(uops, consts):
|
||||||
self.assert_equiv_uops(uop, const)
|
self.assert_equiv_uops(uop, const)
|
||||||
|
|
||||||
|
@ -197,7 +199,7 @@ class TestUOpGraph(TestUOps):
|
||||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
|
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
|
||||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||||
uops = linearize_uop([wmma])
|
uops = to_uops_list([wmma])
|
||||||
self.assert_equiv_uops(uops[0], acc)
|
self.assert_equiv_uops(uops[0], acc)
|
||||||
self.assertEqual(len(uops), 1)
|
self.assertEqual(len(uops), 1)
|
||||||
|
|
||||||
|
@ -206,7 +208,7 @@ class TestUOpGraph(TestUOps):
|
||||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
||||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||||
uops = linearize_uop([wmma])
|
uops = to_uops_list([wmma])
|
||||||
self.assert_equiv_uops(uops[0], acc)
|
self.assert_equiv_uops(uops[0], acc)
|
||||||
self.assertEqual(len(uops), 1)
|
self.assertEqual(len(uops), 1)
|
||||||
|
|
||||||
|
@ -218,7 +220,7 @@ class TestUOpGraph(TestUOps):
|
||||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||||
uops = linearize_uop([wmma])
|
uops = to_uops_list([wmma])
|
||||||
self.assert_equiv_uops(uops[-1], wmma)
|
self.assert_equiv_uops(uops[-1], wmma)
|
||||||
|
|
||||||
for i in [4, 8]:
|
for i in [4, 8]:
|
||||||
|
@ -228,7 +230,7 @@ class TestUOpGraph(TestUOps):
|
||||||
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
|
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
|
||||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||||
uops = linearize_uop([wmma])
|
uops = to_uops_list([wmma])
|
||||||
self.assert_equiv_uops(uops[-1], wmma)
|
self.assert_equiv_uops(uops[-1], wmma)
|
||||||
|
|
||||||
for i in [2, 4, 8]:
|
for i in [2, 4, 8]:
|
||||||
|
@ -237,7 +239,7 @@ class TestUOpGraph(TestUOps):
|
||||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||||
uops = linearize_uop([wmma])
|
uops = to_uops_list([wmma])
|
||||||
self.assert_equiv_uops(uops[-1], wmma)
|
self.assert_equiv_uops(uops[-1], wmma)
|
||||||
|
|
||||||
for i in [2, 4, 8]:
|
for i in [2, 4, 8]:
|
||||||
|
@ -246,7 +248,7 @@ class TestUOpGraph(TestUOps):
|
||||||
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
|
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
|
||||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||||
uops = linearize_uop([wmma])
|
uops = to_uops_list([wmma])
|
||||||
self.assert_equiv_uops(uops[-1], wmma)
|
self.assert_equiv_uops(uops[-1], wmma)
|
||||||
|
|
||||||
def test_cast_alu_fold(self):
|
def test_cast_alu_fold(self):
|
||||||
|
@ -256,7 +258,7 @@ class TestUOpGraph(TestUOps):
|
||||||
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
|
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
|
||||||
alu = ld.lt(1).cast(dtypes.bool)
|
alu = ld.lt(1).cast(dtypes.bool)
|
||||||
out = UOp(UOps.STORE, None, (d0, idx, alu))
|
out = UOp(UOps.STORE, None, (d0, idx, alu))
|
||||||
uops = linearize_uop([out])
|
uops = to_uops_list([out])
|
||||||
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)
|
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)
|
||||||
|
|
||||||
def test_double_cast_fold(self):
|
def test_double_cast_fold(self):
|
||||||
|
@ -266,7 +268,7 @@ class TestUOpGraph(TestUOps):
|
||||||
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
|
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
|
||||||
alu = ld.cast(dtypes.float).cast(dtypes.float)
|
alu = ld.cast(dtypes.float).cast(dtypes.float)
|
||||||
out = UOp(UOps.STORE, None, (d0, idx, alu))
|
out = UOp(UOps.STORE, None, (d0, idx, alu))
|
||||||
uops = linearize_uop([out])
|
uops = to_uops_list([out])
|
||||||
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
|
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
|
||||||
|
|
||||||
def test_depth_2_const_fold(self):
|
def test_depth_2_const_fold(self):
|
||||||
|
@ -275,7 +277,7 @@ class TestUOpGraph(TestUOps):
|
||||||
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
|
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
|
||||||
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
|
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
|
||||||
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
|
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
|
||||||
uops = linearize_uop([out])
|
uops = to_uops_list([out])
|
||||||
self.assertEqual(len(uops), 5)
|
self.assertEqual(len(uops), 5)
|
||||||
out = uops[-1]
|
out = uops[-1]
|
||||||
self.assertEqual(out.op, UOps.ALU)
|
self.assertEqual(out.op, UOps.ALU)
|
||||||
|
@ -290,7 +292,7 @@ class TestUOpGraph(TestUOps):
|
||||||
idx = UOp.const(dtypes.int, 0)
|
idx = UOp.const(dtypes.int, 0)
|
||||||
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
|
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
|
||||||
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
|
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
|
||||||
uops = linearize_uop([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
|
uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
|
||||||
ld0, ld1 = uops[-1].src[2].src
|
ld0, ld1 = uops[-1].src[2].src
|
||||||
# ld0 becomes the invalid value
|
# ld0 becomes the invalid value
|
||||||
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
|
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
|
||||||
|
@ -305,7 +307,7 @@ class TestUOpGraph(TestUOps):
|
||||||
barrier = UOp(UOps.BARRIER, None, (st, ))
|
barrier = UOp(UOps.BARRIER, None, (st, ))
|
||||||
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier))
|
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier))
|
||||||
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier))
|
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier))
|
||||||
uops = linearize_uop([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
|
uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
|
||||||
ld0, ld1 = uops[-1].src[2].src
|
ld0, ld1 = uops[-1].src[2].src
|
||||||
# ld0 becomes the invalid value
|
# ld0 becomes the invalid value
|
||||||
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
|
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
|
||||||
|
@ -319,7 +321,7 @@ class TestUOpGraph(TestUOps):
|
||||||
val = UOp.const(dtypes.int, 42)
|
val = UOp.const(dtypes.int, 42)
|
||||||
st0 = UOp(UOps.STORE, None, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
|
st0 = UOp(UOps.STORE, None, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
|
||||||
st1 = UOp(UOps.STORE, None, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
|
st1 = UOp(UOps.STORE, None, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
|
||||||
uops = linearize_uop([st0, st1])
|
uops = to_uops_list([st0, st1])
|
||||||
# only the second store happens
|
# only the second store happens
|
||||||
self.assertEqual(len(uops), 4)
|
self.assertEqual(len(uops), 4)
|
||||||
self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
|
self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
|
||||||
|
@ -328,7 +330,7 @@ class TestUOpGraph(TestUOps):
|
||||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
||||||
idx = UOp.const(dtypes.int, 0)
|
idx = UOp.const(dtypes.int, 0)
|
||||||
bad_gate = UOp.const(dtypes.int, 1)
|
bad_gate = UOp.const(dtypes.int, 1)
|
||||||
with self.assertRaises(AssertionError): linearize_uop([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
|
with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
|
||||||
|
|
||||||
def test_switched_range_order(self):
|
def test_switched_range_order(self):
|
||||||
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
||||||
|
@ -339,7 +341,7 @@ class TestUOpGraph(TestUOps):
|
||||||
r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False))
|
r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False))
|
||||||
alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
|
alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
|
||||||
store = UOp(UOps.STORE, None, (glbl, alu, cf))
|
store = UOp(UOps.STORE, None, (glbl, alu, cf))
|
||||||
uops = linearize_uop([store])
|
uops = to_uops_list([store])
|
||||||
ranges = [x for x in uops if x.op is UOps.RANGE]
|
ranges = [x for x in uops if x.op is UOps.RANGE]
|
||||||
endranges = [x for x in uops if x.op is UOps.ENDRANGE]
|
endranges = [x for x in uops if x.op is UOps.ENDRANGE]
|
||||||
# ranges are closed in the right order
|
# ranges are closed in the right order
|
||||||
|
|
|
@ -9,11 +9,13 @@ from tinygrad.ops import UOps, NOp, UOp, UnaryOps, BinaryOps, TernaryOps, Reduce
|
||||||
from tinygrad.renderer import Program
|
from tinygrad.renderer import Program
|
||||||
from tinygrad.engine.schedule import create_schedule
|
from tinygrad.engine.schedule import create_schedule
|
||||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||||
from tinygrad.codegen.uopgraph import linearize_uop
|
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
|
||||||
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
|
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
|
||||||
|
|
||||||
|
def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
|
||||||
|
|
||||||
def _uops_to_prg(uops_list):
|
def _uops_to_prg(uops_list):
|
||||||
uops = linearize_uop(uops_list, opts=Device[Device.DEFAULT].renderer)
|
uops = linearize_uop(full_graph_rewrite(UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer))
|
||||||
src = Device[Device.DEFAULT].renderer.render("test", uops)
|
src = Device[Device.DEFAULT].renderer.render("test", uops)
|
||||||
has_local = Device[Device.DEFAULT].renderer.has_local
|
has_local = Device[Device.DEFAULT].renderer.has_local
|
||||||
return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops,
|
return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops,
|
||||||
|
@ -255,7 +257,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||||
val = UOp.const(dtypes.float, 42.0)
|
val = UOp.const(dtypes.float, 42.0)
|
||||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||||
store = UOp(UOps.STORE, None, (gmem, idx, val, gate))
|
store = UOp(UOps.STORE, None, (gmem, idx, val, gate))
|
||||||
uops = linearize_uop([store])
|
uops = to_uops_list([store])
|
||||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||||
if_uop = next(u for u in uops if u.op is UOps.IF)
|
if_uop = next(u for u in uops if u.op is UOps.IF)
|
||||||
endif = next(u for u in uops if u.op is UOps.ENDIF)
|
endif = next(u for u in uops if u.op is UOps.ENDIF)
|
||||||
|
@ -334,7 +336,7 @@ class TestAssembly(unittest.TestCase):
|
||||||
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
||||||
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
|
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
|
||||||
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
|
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
|
||||||
uops = linearize_uop([a1,a2], opts=Device[Device.DEFAULT].renderer)
|
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
|
||||||
Device[Device.DEFAULT].renderer.render("test", uops)
|
Device[Device.DEFAULT].renderer.render("test", uops)
|
||||||
self.assertEqual(uops[-1].arg, BinaryOps.SHL)
|
self.assertEqual(uops[-1].arg, BinaryOps.SHL)
|
||||||
self.assertEqual(uops[-2].arg, BinaryOps.MUL)
|
self.assertEqual(uops[-2].arg, BinaryOps.MUL)
|
||||||
|
@ -346,7 +348,7 @@ class TestAssembly(unittest.TestCase):
|
||||||
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
||||||
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
|
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
|
||||||
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
|
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
|
||||||
uops = linearize_uop([a1,a2], opts=Device[Device.DEFAULT].renderer)
|
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
|
||||||
Device[Device.DEFAULT].renderer.render("test", uops)
|
Device[Device.DEFAULT].renderer.render("test", uops)
|
||||||
self.assertEqual(uops[-1].arg, BinaryOps.SHR)
|
self.assertEqual(uops[-1].arg, BinaryOps.SHR)
|
||||||
self.assertEqual(uops[-2].arg, BinaryOps.IDIV)
|
self.assertEqual(uops[-2].arg, BinaryOps.IDIV)
|
||||||
|
@ -385,7 +387,7 @@ class TestIndexingOrdering(unittest.TestCase):
|
||||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
||||||
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
||||||
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
||||||
uops = linearize_uop([st1, st0], skip_check=True)
|
uops = to_uops_list([st1, st0], skip_check=True)
|
||||||
stores = [st for st in uops if st.op is UOps.STORE]
|
stores = [st for st in uops if st.op is UOps.STORE]
|
||||||
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
|
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
|
||||||
|
|
||||||
|
@ -397,7 +399,7 @@ class TestIndexingOrdering(unittest.TestCase):
|
||||||
st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
||||||
st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
||||||
st1_1 = UOp(UOps.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
st1_1 = UOp(UOps.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
||||||
uops = linearize_uop([st0_0, st1_0, st0_1, st1_1], skip_check=True)
|
uops = to_uops_list([st0_0, st1_0, st0_1, st1_1], skip_check=True)
|
||||||
stores = [st for st in uops if st.op is UOps.STORE]
|
stores = [st for st in uops if st.op is UOps.STORE]
|
||||||
print("\n".join(map(str, stores)))
|
print("\n".join(map(str, stores)))
|
||||||
# buf0 stores come first
|
# buf0 stores come first
|
||||||
|
@ -413,7 +415,7 @@ class TestIndexingOrdering(unittest.TestCase):
|
||||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||||
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
||||||
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
||||||
uops = linearize_uop([st1, st0], skip_check=True)
|
uops = linearize_uop(UOp.sink(st1, st0), skip_check=True)
|
||||||
stores = [st for st in uops if st.op is UOps.STORE]
|
stores = [st for st in uops if st.op is UOps.STORE]
|
||||||
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
|
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
|
||||||
|
|
||||||
|
|
|
@ -105,7 +105,7 @@ class TestUOpsStats(unittest.TestCase):
|
||||||
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
|
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
|
||||||
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
|
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
|
||||||
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
|
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
|
||||||
uops = linearize_uop([u5])
|
uops = linearize_uop(u5.sink())
|
||||||
|
|
||||||
globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple())
|
globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple())
|
||||||
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
|
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
|
||||||
|
@ -114,7 +114,7 @@ class TestUOpsStats(unittest.TestCase):
|
||||||
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
|
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
|
||||||
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
|
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
|
||||||
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
|
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
|
||||||
uops_fma = linearize_uop([u4])
|
uops_fma = linearize_uop(u4.sink())
|
||||||
|
|
||||||
self.assertEqual(flops_mem(uops), flops_mem(uops_fma))
|
self.assertEqual(flops_mem(uops), flops_mem(uops_fma))
|
||||||
|
|
||||||
|
|
|
@ -9,14 +9,14 @@ from typing import Tuple
|
||||||
|
|
||||||
from tinygrad.helpers import DEBUG
|
from tinygrad.helpers import DEBUG
|
||||||
from tinygrad.dtype import dtypes, PtrDType, ConstType
|
from tinygrad.dtype import dtypes, PtrDType, ConstType
|
||||||
from tinygrad.codegen.uopgraph import linearize_uop
|
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
|
||||||
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
|
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
def render(self) -> Tuple[str, ConstType, ConstType]:
|
def render(self) -> Tuple[str, ConstType, ConstType]:
|
||||||
# NOTE: we need STORE so the ALU op has children
|
# NOTE: we need STORE so the ALU op has children
|
||||||
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0)
|
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0)
|
||||||
uops = linearize_uop([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))])
|
uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
|
||||||
if DEBUG>=5: print_uops(uops)
|
if DEBUG>=5: print_uops(uops)
|
||||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||||
class TestRenderer(CStyleLanguage):
|
class TestRenderer(CStyleLanguage):
|
||||||
|
|
|
@ -13,7 +13,7 @@ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DE
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.symbolic import Variable, sint
|
from tinygrad.shape.symbolic import Variable, sint
|
||||||
from tinygrad.shape.view import strides_for_shape
|
from tinygrad.shape.view import strides_for_shape
|
||||||
from tinygrad.codegen.uopgraph import linearize_uop
|
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
|
||||||
from tinygrad.codegen.lowerer import ast_to_uop
|
from tinygrad.codegen.lowerer import ast_to_uop
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
|
|
||||||
|
@ -745,7 +745,7 @@ class Kernel:
|
||||||
print(self.applied_opts)
|
print(self.applied_opts)
|
||||||
verify_ast(modified_ast)
|
verify_ast(modified_ast)
|
||||||
|
|
||||||
self.uops:List[UOp] = linearize_uop(ast_to_uop(modified_ast, self.opts), self.opts)
|
self.uops:List[UOp] = linearize_uop(full_graph_rewrite(ast_to_uop(modified_ast, self.opts), self.opts))
|
||||||
if DEBUG >= 5: print_uops(self.uops)
|
if DEBUG >= 5: print_uops(self.uops)
|
||||||
if getenv("GRAPHUOPS"):
|
if getenv("GRAPHUOPS"):
|
||||||
from tinygrad.engine.graph import graph_uops
|
from tinygrad.engine.graph import graph_uops
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict, Callable
|
from typing import Optional, Tuple, Dict, List, Set, cast, TYPE_CHECKING, Any, DefaultDict, Callable
|
||||||
import functools, itertools, heapq, math, operator
|
import functools, itertools, heapq, math, operator
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
|
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
|
||||||
|
@ -512,9 +512,8 @@ def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[U
|
||||||
return srcs[u]
|
return srcs[u]
|
||||||
|
|
||||||
linearize_cnt = 0
|
linearize_cnt = 0
|
||||||
def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, skip_check=False) -> List[UOp]:
|
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
|
||||||
global linearize_cnt, acc_number
|
global linearize_cnt, acc_number
|
||||||
sink: UOp = sink_in if isinstance(sink_in, UOp) else UOp(UOps.SINK, None, tuple(sink_in))
|
|
||||||
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
|
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
|
||||||
folder = constant_folder + transcendental_folding(tuple() if TRANSCENDENTAL >= 2 or opts is None else tuple(opts.code_for_op.keys()))
|
folder = constant_folder + transcendental_folding(tuple() if TRANSCENDENTAL >= 2 or opts is None else tuple(opts.code_for_op.keys()))
|
||||||
|
|
||||||
|
@ -533,7 +532,10 @@ def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, s
|
||||||
|
|
||||||
# for PTX only
|
# for PTX only
|
||||||
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher)
|
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher)
|
||||||
|
return sink
|
||||||
|
|
||||||
|
def linearize_uop(sink:UOp, skip_check:bool=False) -> List[UOp]:
|
||||||
|
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
|
||||||
# filter nodes that don't link to a sink
|
# filter nodes that don't link to a sink
|
||||||
# BFS toposort
|
# BFS toposort
|
||||||
children: Dict[UOp, List[UOp]] = {}
|
children: Dict[UOp, List[UOp]] = {}
|
||||||
|
|
|
@ -123,6 +123,7 @@ class UOp:
|
||||||
ret = self.src[0 if self.op is UOps.CONST else 1]
|
ret = self.src[0 if self.op is UOps.CONST else 1]
|
||||||
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
|
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
|
||||||
return ret.arg
|
return ret.arg
|
||||||
|
def sink(self, *srcs): return UOp(UOps.SINK, None, (self,)+srcs)
|
||||||
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
|
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
|
||||||
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
|
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
|
||||||
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
|
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
|
||||||
|
|
Loading…
Reference in New Issue