split full_graph_rewrite and linearize_uop [run_process_replay] (#6215)

* split full_graph_rewrite and linearize_uop

* fix tests

* graph rewrite in test uops

* add types
This commit is contained in:
George Hotz 2024-08-20 20:12:33 -07:00 committed by GitHub
parent 9faf205601
commit 16f420f7a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 48 additions and 41 deletions

View File

@ -148,7 +148,7 @@ class TestIndexing(unittest.TestCase):
def test_index_mnist_opt(self): self.test_index_mnist(0)
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_llama_embedding(self, noopt=1, op_limit=100):
def test_llama_embedding(self, noopt=1, op_limit=65536):
# llama3 is 128256
vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
emb = nn.Embedding(vocab_size, embed_size)

View File

@ -1,10 +1,11 @@
from typing import List
import unittest
from test.helpers import TestUOps
from tinygrad import dtypes, Variable
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps, UOps, UOp, NOp, PatternMatcher
from tinygrad.codegen.uopgraph import linearize_uop, graph_rewrite, expander, reducer, constant_folder, float4_folding
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, graph_rewrite, expander, reducer, constant_folder, float4_folding
simple_pm = PatternMatcher([
(NOp.cvar('x', dtypes.int), lambda x: UOp.const(dtypes.float, 1.0) + UOp.const(dtypes.float, 2.0)),
@ -13,6 +14,8 @@ simple_pm = PatternMatcher([
((NOp.var('x') + NOp.cvar('c1')) + NOp.cvar('c2'), lambda x,c1,c2: x + x.const(c1.arg+c2.arg)),
])
def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u)))
class TestGraphRewrite(unittest.TestCase):
def test_dedup(self):
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
@ -94,7 +97,7 @@ class TestUOpGraph(TestUOps):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
uops = linearize_uop([out])
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(out.op, UOps.CONST)
@ -106,7 +109,7 @@ class TestUOpGraph(TestUOps):
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
uops = linearize_uop([out])
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(out.op, UOps.CONST)
@ -117,7 +120,7 @@ class TestUOpGraph(TestUOps):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
uops = linearize_uop([out])
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(out.op, UOps.CONST)
@ -126,7 +129,7 @@ class TestUOpGraph(TestUOps):
def test_const_cast(self):
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
out = UOp(UOps.CAST, dtypes.int, (bf,))
uops = linearize_uop([out])
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(out.op, UOps.CONST)
@ -140,7 +143,7 @@ class TestUOpGraph(TestUOps):
x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
out = UOp(UOps.STORE, None, (d0, idx, alu))
uops = linearize_uop([out])
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)
def test_gep_vec_fold(self):
@ -151,7 +154,7 @@ class TestUOpGraph(TestUOps):
def _test_vec(geps, count=4):
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
out = UOp(UOps.STORE, None, (d0, idx, vec))
uops = linearize_uop([out])
uops = to_uops_list([out])
if DEBUG >= 4:
from tinygrad import Device
print(Device[Device.DEFAULT].renderer.render("test", uops))
@ -186,8 +189,7 @@ class TestUOpGraph(TestUOps):
for vec_size in [2, 4, 8]:
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)]
uops = linearize_uop(geps)
uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)])
for uop, const in zip(uops, consts):
self.assert_equiv_uops(uop, const)
@ -197,7 +199,7 @@ class TestUOpGraph(TestUOps):
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = linearize_uop([wmma])
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[0], acc)
self.assertEqual(len(uops), 1)
@ -206,7 +208,7 @@ class TestUOpGraph(TestUOps):
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = linearize_uop([wmma])
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[0], acc)
self.assertEqual(len(uops), 1)
@ -218,7 +220,7 @@ class TestUOpGraph(TestUOps):
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = linearize_uop([wmma])
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
for i in [4, 8]:
@ -228,7 +230,7 @@ class TestUOpGraph(TestUOps):
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = linearize_uop([wmma])
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
for i in [2, 4, 8]:
@ -237,7 +239,7 @@ class TestUOpGraph(TestUOps):
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = linearize_uop([wmma])
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
for i in [2, 4, 8]:
@ -246,7 +248,7 @@ class TestUOpGraph(TestUOps):
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = linearize_uop([wmma])
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
def test_cast_alu_fold(self):
@ -256,7 +258,7 @@ class TestUOpGraph(TestUOps):
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
alu = ld.lt(1).cast(dtypes.bool)
out = UOp(UOps.STORE, None, (d0, idx, alu))
uops = linearize_uop([out])
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)
def test_double_cast_fold(self):
@ -266,7 +268,7 @@ class TestUOpGraph(TestUOps):
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
alu = ld.cast(dtypes.float).cast(dtypes.float)
out = UOp(UOps.STORE, None, (d0, idx, alu))
uops = linearize_uop([out])
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
def test_depth_2_const_fold(self):
@ -275,7 +277,7 @@ class TestUOpGraph(TestUOps):
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
uops = linearize_uop([out])
uops = to_uops_list([out])
self.assertEqual(len(uops), 5)
out = uops[-1]
self.assertEqual(out.op, UOps.ALU)
@ -290,7 +292,7 @@ class TestUOpGraph(TestUOps):
idx = UOp.const(dtypes.int, 0)
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
uops = linearize_uop([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
@ -305,7 +307,7 @@ class TestUOpGraph(TestUOps):
barrier = UOp(UOps.BARRIER, None, (st, ))
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False), barrier))
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True), barrier))
uops = linearize_uop([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
@ -319,7 +321,7 @@ class TestUOpGraph(TestUOps):
val = UOp.const(dtypes.int, 42)
st0 = UOp(UOps.STORE, None, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
st1 = UOp(UOps.STORE, None, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
uops = linearize_uop([st0, st1])
uops = to_uops_list([st0, st1])
# only the second store happens
self.assertEqual(len(uops), 4)
self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
@ -328,7 +330,7 @@ class TestUOpGraph(TestUOps):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
idx = UOp.const(dtypes.int, 0)
bad_gate = UOp.const(dtypes.int, 1)
with self.assertRaises(AssertionError): linearize_uop([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, None, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
def test_switched_range_order(self):
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@ -339,7 +341,7 @@ class TestUOpGraph(TestUOps):
r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False))
alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
store = UOp(UOps.STORE, None, (glbl, alu, cf))
uops = linearize_uop([store])
uops = to_uops_list([store])
ranges = [x for x in uops if x.op is UOps.RANGE]
endranges = [x for x in uops if x.op is UOps.ENDRANGE]
# ranges are closed in the right order

View File

@ -9,11 +9,13 @@ from tinygrad.ops import UOps, NOp, UOp, UnaryOps, BinaryOps, TernaryOps, Reduce
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.uopgraph import linearize_uop
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
def _uops_to_prg(uops_list):
uops = linearize_uop(uops_list, opts=Device[Device.DEFAULT].renderer)
uops = linearize_uop(full_graph_rewrite(UOp.sink(*uops_list), opts=Device[Device.DEFAULT].renderer))
src = Device[Device.DEFAULT].renderer.render("test", uops)
has_local = Device[Device.DEFAULT].renderer.has_local
return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops,
@ -255,7 +257,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
store = UOp(UOps.STORE, None, (gmem, idx, val, gate))
uops = linearize_uop([store])
uops = to_uops_list([store])
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.op is UOps.IF)
endif = next(u for u in uops if u.op is UOps.ENDIF)
@ -334,7 +336,7 @@ class TestAssembly(unittest.TestCase):
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
uops = linearize_uop([a1,a2], opts=Device[Device.DEFAULT].renderer)
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].arg, BinaryOps.SHL)
self.assertEqual(uops[-2].arg, BinaryOps.MUL)
@ -346,7 +348,7 @@ class TestAssembly(unittest.TestCase):
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
uops = linearize_uop([a1,a2], opts=Device[Device.DEFAULT].renderer)
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].arg, BinaryOps.SHR)
self.assertEqual(uops[-2].arg, BinaryOps.IDIV)
@ -385,7 +387,7 @@ class TestIndexingOrdering(unittest.TestCase):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = linearize_uop([st1, st0], skip_check=True)
uops = to_uops_list([st1, st0], skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE]
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
@ -397,7 +399,7 @@ class TestIndexingOrdering(unittest.TestCase):
st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1_1 = UOp(UOps.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = linearize_uop([st0_0, st1_0, st0_1, st1_1], skip_check=True)
uops = to_uops_list([st0_0, st1_0, st0_1, st1_1], skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE]
print("\n".join(map(str, stores)))
# buf0 stores come first
@ -413,7 +415,7 @@ class TestIndexingOrdering(unittest.TestCase):
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = linearize_uop([st1, st0], skip_check=True)
uops = linearize_uop(UOp.sink(st1, st0), skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE]
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"

View File

@ -105,7 +105,7 @@ class TestUOpsStats(unittest.TestCase):
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
uops = linearize_uop([u5])
uops = linearize_uop(u5.sink())
globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
@ -114,7 +114,7 @@ class TestUOpsStats(unittest.TestCase):
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
uops_fma = linearize_uop([u4])
uops_fma = linearize_uop(u4.sink())
self.assertEqual(flops_mem(uops), flops_mem(uops_fma))

View File

@ -9,14 +9,14 @@ from typing import Tuple
from tinygrad.helpers import DEBUG
from tinygrad.dtype import dtypes, PtrDType, ConstType
from tinygrad.codegen.uopgraph import linearize_uop
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
import functools
def render(self) -> Tuple[str, ConstType, ConstType]:
# NOTE: we need STORE so the ALU op has children
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0)
uops = linearize_uop([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))])
uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
if DEBUG>=5: print_uops(uops)
from tinygrad.renderer.cstyle import CStyleLanguage
class TestRenderer(CStyleLanguage):

View File

@ -13,7 +13,7 @@ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DE
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.uopgraph import linearize_uop
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
from tinygrad.codegen.lowerer import ast_to_uop
from enum import Enum, auto
@ -745,7 +745,7 @@ class Kernel:
print(self.applied_opts)
verify_ast(modified_ast)
self.uops:List[UOp] = linearize_uop(ast_to_uop(modified_ast, self.opts), self.opts)
self.uops:List[UOp] = linearize_uop(full_graph_rewrite(ast_to_uop(modified_ast, self.opts), self.opts))
if DEBUG >= 5: print_uops(self.uops)
if getenv("GRAPHUOPS"):
from tinygrad.engine.graph import graph_uops

View File

@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING, Any, DefaultDict, Callable
from typing import Optional, Tuple, Dict, List, Set, cast, TYPE_CHECKING, Any, DefaultDict, Callable
import functools, itertools, heapq, math, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
@ -512,9 +512,8 @@ def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[U
return srcs[u]
linearize_cnt = 0
def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, skip_check=False) -> List[UOp]:
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
global linearize_cnt, acc_number
sink: UOp = sink_in if isinstance(sink_in, UOp) else UOp(UOps.SINK, None, tuple(sink_in))
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
folder = constant_folder + transcendental_folding(tuple() if TRANSCENDENTAL >= 2 or opts is None else tuple(opts.code_for_op.keys()))
@ -533,7 +532,10 @@ def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, s
# for PTX only
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher)
return sink
def linearize_uop(sink:UOp, skip_check:bool=False) -> List[UOp]:
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
# filter nodes that don't link to a sink
# BFS toposort
children: Dict[UOp, List[UOp]] = {}

View File

@ -123,6 +123,7 @@ class UOp:
ret = self.src[0 if self.op is UOps.CONST else 1]
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
return ret.arg
def sink(self, *srcs): return UOp(UOps.SINK, None, (self,)+srcs)
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))