minimum new style expand (#6534)

* minimum new style expand [run_process_replay]

* float4 folding works

* fix uop graph

* if means or

* dtype.count idx overload

* fix test arange

* expand nope

* fix expand contract

* fix amd tensor core

* oh, that's a good test with a real failure

* remove prints

* early reduce

* tomorrow, we remove sorted on expand args

* fix wmma issue

* that makes test_arange pass

* vectorized folding

* no check

* broadcast

* fix clang with self assign rule
George Hotz 2024-09-17 13:02:41 +08:00 committed by GitHub
parent f5dd25d376
commit a2239c812e
5 changed files with 175 additions and 80 deletions
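For orientation, a minimal sketch of the representation change this commit makes (editor's illustration, not part of the diff): an EXPAND used to carry one scalar source per expanded element; it now carries a single source whose dtype is a vector of the expanded size, and the expander rules swizzle that wide source with GEP/VECTORIZE instead of permuting tuples of scalar UOps.

# editor's sketch, assuming the UOp/UOps/dtypes names imported as in the diff below
from tinygrad.ops import UOp, UOps
from tinygrad.dtype import dtypes

# old style: one scalar CONST source per expanded element
old_e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))

# new style: a single vectorized CONST source holding all four values
new_e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), (0,1,2,3)),), ((1,4),))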


@@ -295,6 +295,7 @@ class TestUOpGraph(unittest.TestCase):
assert_equiv_uops(uops[0], acc)
self.assertEqual(len(uops), 1)
@unittest.skip("wmma is wrong here, it needs an arg")
def test_wmma_vectorize_no_fold(self):
for i in [4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
@@ -437,63 +438,60 @@ def float4_rewrite(sink): return graph_rewrite(sink, constant_folder + expander
class TestExpander(unittest.TestCase):
def test_expand_add_broadcast(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+3)
assert sink.op is UOps.EXPAND and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [3,4,5,6])
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [3,4,5,6])
def test_contract_simple(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
def test_contract_axis_1(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((2,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,8,12])
self.assertListEqual([x.arg for x in sink.src[3].src], [3,7,11,15])
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((2,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,4,8,12])
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [3,7,11,15])
def test_contract_axis_2(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [0,1,2,3])
self.assertListEqual([x.arg for x in sink.src[3].src], [12,13,14,15])
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((1,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,1,2,3])
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [12,13,14,15])
def test_contract_axis_2_big(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4])
self.assertListEqual([x.arg for x in sink.src[6].src], [10,14])
self.assertListEqual([x.arg for x in sink.src[0].src][0:2], [0,4])
self.assertListEqual([x.arg for x in sink.src[0].src][12:14], [10,14])
def test_contract_multi_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,2,6])
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 4, 2, 6])
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,4,6])
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 2, 4, 6])
def test_contract_mid(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,2),(3,2))
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 2
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2])
self.assertListEqual([x.arg for x in sink.src[1].src], [1,3])
self.assertListEqual([x.arg for x in sink.src[2].src], [4,6])
self.assertListEqual([x.arg for x in sink.src[3].src], [5,7])
assert sink.op is UOps.EXPAND and sink.arg == ((1,2),(3,2))
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 8
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,1,3,4,6,5,7])
def test_contract_no_expand(self):
e1 = UOp(UOps.DEFINE_VAR, dtypes.int)
@@ -503,7 +501,7 @@ class TestExpander(unittest.TestCase):
assert sink.src[0] == sink.src[1]
def test_contract_half_expand(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 8
@@ -512,19 +510,19 @@ class TestExpander(unittest.TestCase):
assert sink.src[6] == sink.src[7]
def test_expand_same_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+e2)
assert sink.op is UOps.EXPAND and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [0,5,10,15])
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [0,5,10,15])
def test_expand_different_axis(self, flip=False):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
sink = expander_rewrite((e2+e1) if flip else (e1+e2))
assert sink.op is UOps.EXPAND and len(sink.src) == 16
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16
assert sink.arg == ((1, 4), (2, 4))
self.assertListEqual([x.arg for x in sink.src], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
self.assertListEqual([x.arg for x in sink.src[0].src], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
def test_expand_different_axis_flip(self): self.test_expand_different_axis(True)
@@ -544,6 +542,7 @@ class TestExpander(unittest.TestCase):
assert sink.op is UOps.CONST
self.assertEqual(sink.arg, 3*4)
@unittest.skip("no longer supported")
def test_double_expand(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
@@ -553,6 +552,7 @@ class TestExpander(unittest.TestCase):
assert sink.arg == ((1, 2), (2, 4))
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
@unittest.skip("no longer supported")
def test_double_expand_reverse(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
@@ -562,6 +562,7 @@ class TestExpander(unittest.TestCase):
assert sink.arg == ((1, 4), (2, 2))
self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
@unittest.skip("no longer supported")
def test_double_expand_middle(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
@@ -585,14 +586,14 @@ class TestLoadStoreFolder(unittest.TestCase):
def test_simple_load_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
def test_two_load_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,8),))
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2
@@ -600,7 +601,7 @@ class TestLoadStoreFolder(unittest.TestCase):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
@@ -611,7 +612,7 @@ class TestLoadStoreFolder(unittest.TestCase):
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3


@@ -103,7 +103,7 @@ class TestUOpsStats(unittest.TestCase):
c = a.matmul(b)
c.realize()
expected_ops = N ** 3 * 2
assert expected_ops == GlobalCounters.global_ops
self.assertEqual(expected_ops, GlobalCounters.global_ops)
#MULACC should have the same stats as MUL + ADD
def test_mulacc(self):


@@ -72,7 +72,7 @@ class IndependentLowerer:
# upcast loops
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
assert isinstance(g, int), "needs to be int to upcast/unroll"
self.idxs.append(UOp(UOps.EXPAND, dtypes.pyint, tuple(UOp.const(dtypes.pyint, j) for j in range(0, g)), ((i,g),)))
self.idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
# late indexes (group for reduce)
self.ridxs = self.idxs[:]
@@ -117,7 +117,7 @@ class IndependentLowerer:
UOp(UOps.CONTRACT, dtype=in_uops[0].dtype.vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
UOp(UOps.CONTRACT, dtype=in_uops[1].dtype.vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
UOp.const(x.dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, x.dtype, tuple(UOp(UOps.GEP, x.dtype, (ret,), (i,)) for i in range(wmma_sz[2])), arg=upcast_axes[2])
return UOp(UOps.EXPAND, x.dtype, (ret,), arg=upcast_axes[2])
if x.op is UOps.REDUCE_AXIS:
# NOTE: always using ridxs is fine here
reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)


@@ -5,7 +5,7 @@ from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
if TYPE_CHECKING: from tinygrad.renderer import Renderer
@@ -71,7 +71,7 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load.dtype, (vec_load,), (i,))), range(4), load.const_like(float('nan')))
float4_folding = PatternMatcher([
(UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
(UPat(UOps.VECTORIZE, src=UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
(UPat((UOps.BARRIER, UOps.SINK), src=UPat(UOps.STORE, src=(UPat.var("buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded),
])
@@ -201,7 +201,7 @@ def reduce_before_expand(reduce, expand, x):
red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce.src[1:], reduce.arg)
return UOp(expand.op, expand.dtype, tuple(UOp(UOps.GEP, reduce.dtype, (red,), (i,)) for i in range(x.dtype.count)), expand.arg)
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None):
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None, vec=None):
if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE
if mval.arg >= 0 or loop_start.arg != 0:
# TODO: support and test this with other mvals and loop_starts
@@ -209,6 +209,10 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, redu
return None
if idx2 is not None: idx = idx + idx2
if idx3 is not None: idx = idx + idx3
if vec is not None:
# idx, mval, loop_start, loop_end
def dvec(x): return UOp(UOps.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
idx, mval, loop_start, loop_end = dvec(idx), dvec(mval), dvec(loop_start), dvec(loop_end)
comprange = UOp.min(loop_end, UOp.max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start))
new_reduce_op = comprange.cast(multconst.dtype) * multconst
ret = UOp(UOps.REDUCE, reduce.dtype, (new_reduce_op,) + tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
@@ -225,6 +229,8 @@ constant_folder = PatternMatcher([
# bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.ADD, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
# self ASSIGN is just self
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# VECTORIZE/GEP: the expander rule allows tuple GEP creation, this is just for removal
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"),
lambda vec,x: x if x.dtype == vec.dtype and tuple(y.arg[0] for y in vec.src) == tuple(range(len(vec.src))) else None),
@@ -234,6 +240,7 @@ constant_folder = PatternMatcher([
# VECTORIZE of a single element is just that element
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# VECTORIZE void is SINK
(UPat(UOps.VECTORIZE, dtype=dtypes.void, src=UPat(UOps.BARRIER, name='b')), lambda b: b),
(UPat(UOps.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(UOps.SINK, dtypes.void, x.src)),
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(UOps.GEP, src=(UPat(UOps.GEP, name='g2'),), name='g1'),
@@ -242,9 +249,12 @@ constant_folder = PatternMatcher([
lambda gep, vec: UOp(UOps.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(UOps.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(UOps.GEP, src=(UPat(UOps.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# push all GEPs through ALUs (fix arange stuff)
(UPat(UOps.GEP, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
# GEP add push (non-shrinking only)
(UPat(UOps.GEP, None, (UPat.var('x') + UPat.cvar('c1'),), name="gep"),
lambda x,c1,gep: x.gep(gep.arg) + c1.gep(gep.arg) if len(gep.arg) >= x.dtype.count else None),
#(UPat(UOps.GEP, None, (UPat.var('x') + UPat.cvar('c1'),), name="gep"),
# lambda x,c1,gep: x.gep(gep.arg) + c1.gep(gep.arg) if len(gep.arg) >= x.dtype.count else None),
# tensor core with a 0 input is acc
*[(UPat(UOps.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc) for i in [2, 4, 8]],
*[(UPat(UOps.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc) for i in [2, 4, 8]],
@@ -265,12 +275,22 @@ constant_folder = PatternMatcher([
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (reduce)
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
.lt(UPat.cvar("compval"))
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (unrolled)
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (vectorized)
(UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) *
UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng")))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (unrolled, vectorized)
(UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) *
UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng")))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# unrolled arange div folding
(UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs),
# indexing (with a multiply offset)!
@@ -406,22 +426,47 @@ def do_expand(root:UOp):
if len(expands) == 0: return None
# NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is UOps.WMMA else ()
expand_args = tuple(x for x in sorted(dedup(flatten([x.arg for x in expands]))) if x[0] not in exclude_args)
esrcs = [[src.src[x] for x in _swizzle_args(expand_args, src.arg, exclude_args)] \
if src.op is UOps.EXPAND else itertools.repeat(src) for src in root.src]
new_srcs = [UOp(root.op, root.dtype, new_src, root.arg) for new_src in zip(*esrcs)]
if root.op is UOps.EXPAND:
# merge two expands
expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args
assert len(expand_args) == (len(old_args) + len(root.arg))
new_srcs = [new_srcs[_expand_arg_to_idx(old_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)] for rpk in _choices_from_args(expand_args)]
if root.op is UOps.IF:
# merge ifs into an or
conditions = functools.reduce(lambda x,y: x|y, dedup(x.src[0] for x in new_srcs if x.src[0].op is not UOps.CONST))
barriers = tuple(set(x.src[1] for x in new_srcs))
new_srcs = [UOp(UOps.IF, src=(conditions,)+barriers) for _ in new_srcs]
assert prod([x[1] for x in expand_args]) == len(new_srcs)
return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)
if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
# if there's only one expand arg, it's okay to use it (optimization)
expand_args = expands[0].arg
else:
# otherwise, we sort them and GEP
expand_args = tuple(x for x in sorted(dedup(flatten(expands_args))) if x[0] not in exclude_args)
expand_sz = prod([x[1] for x in expand_args])
new_srcs = []
for i,src in enumerate(root.src):
if src.op is UOps.EXPAND:
if root.op is UOps.IF and i == 0:
# IF means OR on first arg to IF
new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
elif expand_args == src.arg:
# just remove the expand
new_srcs.append(src.src[0])
else:
lst = _swizzle_args(expand_args, src.arg, exclude_args)
# if the base dtype is > 1, put those at the end
if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
new_srcs.append(src.src[0].gep(tuple(lst)))
else:
# non-EXPAND input
if (root.op in {UOps.LOAD, UOps.STORE} and i == 0) or (root.op is UOps.REDUCE and i != 0):
# for the first arg of LOAD/STORE and the RANGE args of REDUCE, just pass them through ignoring EXPANDS
new_srcs.append(src)
elif src.dtype.count > 1:
# put any input dtype > 1 grouped together
new_srcs.append(UOp(UOps.VECTORIZE,
src.dtype.scalar().vec(expand_sz*src.dtype.count), tuple(src.gep(i) for i in range(src.dtype.count))*expand_sz))
else:
# repeat the arg
new_srcs.append(src.broadcast(expand_sz))
new_arg = root.arg
if root.op is UOps.GEP:
assert root.dtype.count == 1
# is this right?
new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
return UOp(UOps.EXPAND, root.dtype, (nsrc,), expand_args)
acc_number = 0
def do_reduce(root:UOp):
@@ -435,7 +480,7 @@ def do_reduce(root:UOp):
ret = UOp(UOps.ASSIGN, root.dtype, (acc, acc.alu(root.arg, ret)))
# for MAX, we can just ignore the unparented
if root.arg is BinaryOps.ADD:
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype)
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
def do_contract(con:UOp):
@@ -444,16 +489,16 @@ def do_contract(con:UOp):
if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
# CONTRACT may remove several axes from EXPAND
assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
srcs = []
idxs = []
for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
lsrcs = [ex.src[_expand_arg_to_idx(ex.arg, {**rpk, **lrpk})] for lrpk in _choices_from_args(con.arg)]
srcs.append(UOp(UOps.VECTORIZE, con.dtype, tuple(lsrcs)))
return srcs[0] if len(srcs) == 1 else UOp(UOps.EXPAND, con.dtype, tuple(srcs), new_ex_args)
idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
return UOp(UOps.EXPAND, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
def no_vectorized_alu(alu):
if alu.dtype.count == 1: return None
alus = tuple(UOp(alu.op, alu.dtype.scalar(),
tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), (i,)) for s in alu.src), alu.arg) for i in range(alu.dtype.count))
tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), (i,)) if alu.op is not UOps.REDUCE or j == 0 else s for j,s in enumerate(alu.src)),
alu.arg) for i in range(alu.dtype.count))
return UOp(UOps.VECTORIZE, alu.dtype, alus)
def create_gate(root:UOp) -> Optional[UOp]:
@@ -470,9 +515,12 @@ expander = PatternMatcher([
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
# create gate MUST BE BEFORE expander
(UPat(UOps.STORE, name="root"), create_gate),
# double expand
(UPat(UOps.EXPAND, name="outer", src=(UPat(UOps.EXPAND, name="inner"),)),
lambda outer, inner: UOp(UOps.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
UOps.VECTORIZE, UOps.REDUCE, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
(UPat(UOps.CONTRACT, name="con"), do_contract),
# remove EXPANDs from SINK
(UPat(UOps.SINK, name="root"),
@@ -488,6 +536,37 @@ expander = PatternMatcher([
lambda ex,x,y: UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])
just_reduce = PatternMatcher([
# do reduce (in expander now)
(UPat(UOps.REDUCE, name="root"), do_reduce),
])
def no_vectorized_load_store(ls:UOp):
idx = ls.src[1]
if idx.dtype.count == 1: return None
# ugh, the meaning of a dtype.count idx is overloaded
if ls.op is UOps.LOAD and idx.dtype.count != ls.dtype.count: return None
if ls.op is UOps.STORE and idx.dtype.count != ls.src[2].dtype.count: return None
tv = [UOp(ls.op, ls.dtype.scalar(), (ls.src[0],) + tuple(j.gep(i) for j in ls.src[1:])) for i in range(idx.dtype.count)]
return UOp(UOps.VECTORIZE, ls.dtype, tuple(tv))
def no_vectorized_acc(acc:UOp):
if acc.dtype.count == 1: return None
alus = tuple(UOp(acc.op, acc.dtype.scalar(),
tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), (i,)) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
return UOp(UOps.VECTORIZE, acc.dtype, alus)
def no_vectorized_wmma(wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
if wmma.dtype.count == out_sz: return None
tsrcs = []
for s,sz in zip(wmma.src, wmma.arg[6]):
ssz = prod(x[1] for x in sz)
tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
wmmas = [UOp(UOps.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
def delete_redundant_gates(root:UOp) -> Optional[UOp]:
@functools.lru_cache(None)
def find_gate(x:UOp) -> Optional[UOp]:
@@ -496,14 +575,19 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
if len(root.src) == 3 or (gate:=find_gate(root)) is None or gate.src[0] is not root.src[3]: return None
return UOp(UOps.STORE, root.dtype, root.src[:3], root.arg)
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.REDUCE), name="alu"), no_vectorized_alu),
(UPat(UOps.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(UOps.DEFINE_ACC, name="acc"), no_vectorized_acc),
(UPat((UOps.LOAD, UOps.STORE), name="ls"), no_vectorized_load_store),
])
reducer = PatternMatcher([
(UPat(UOps.REDUCE, name="root"), do_reduce),
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.count) if c.dtype.count > 1 else None),
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
# no ALU on vectorized dtypes
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name="alu"), no_vectorized_alu),
# delete_redundant_gates (after expand, is this still needed?)
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
# late fixup of unfoldable image loads
@@ -544,8 +628,11 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
# expand
linearize_cnt += 1
if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:
sink = graph_rewrite(sink, folder+(expander+float4_folding if opts is not None and opts.supports_float4 else expander))
if getenv("DO_REDUCE", 1): sink = graph_rewrite(sink, folder+reducer)
sink = graph_rewrite(sink, folder+expander)
if getenv("DO_REDUCE", 1):
sink = graph_rewrite(sink, folder+just_reduce)
sink = graph_rewrite(sink, folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
sink = graph_rewrite(sink, folder+reducer)
# for PTX only
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher)
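A hedged summary of the new lowering order in full_graph_rewrite, using only the matcher names from the hunk above; this is an editor's paraphrase, not code from the commit, and the matchers are passed in as arguments (the real function gates steps 2-4 behind the DO_REDUCE=1 default).

def rewrite_pipeline_sketch(sink, opts, graph_rewrite, folder, expander, just_reduce, devectorize, float4_folding, reducer):
  # 1. EXPAND/CONTRACT are rewritten into wide VECTORIZE/GEP ops
  sink = graph_rewrite(sink, folder+expander)
  # 2. REDUCE is lowered to DEFINE_ACC/ASSIGN accumulation
  sink = graph_rewrite(sink, folder+just_reduce)
  # 3. wide ALU/WMMA/DEFINE_ACC/LOAD/STORE are split back down, folding float4 where the backend supports it
  sink = graph_rewrite(sink, folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
  # 4. leftover vector CONST/VCONST/GEP cleanup and redundant gate removal
  return graph_rewrite(sink, folder+reducer)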


@@ -349,6 +349,9 @@ class UOp(MathTrait):
def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
# TODO: instant check rules here make debugging easier
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
#if op is UOps.VECTORIZE and dtype != dtypes.void: assert len(src) == dtype.count, f"{len(src)} invalid for {dtype}"
#if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
#if op is UOps.CAST: assert dtype.count == src[0].dtype.count, f"cast can't change vectorization {src[0].dtype} --> {dtype}"
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None):
return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg)
@@ -387,6 +390,10 @@ class UOp(MathTrait):
def sink(self, *srcs): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return type(self).const(self.dtype, b)
def broadcast(self, count:int):
assert self.dtype.count == 1
if count == 1: return self
return UOp(UOps.VECTORIZE, self.dtype.vec(count), (self,)*count)
def cast(self, dtype:DType): return type(self)(UOps.CAST, dtype, (self,))
def bitcast(self, dtype:DType): return type(self)(UOps.BITCAST, dtype, (self,))
def gep(self, i:Union[Tuple[int, ...], int]):
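A small usage sketch of the broadcast helper added above (editor's illustration; the asserts restate what the method constructs). In the uopgraph.py hunk, do_expand uses it to repeat a non-EXPAND source across the expansion size, and do_reduce uses it to widen the cast loop-extent before multiplying.

from tinygrad.ops import UOp, UOps
from tinygrad.dtype import dtypes

x = UOp.const(dtypes.int, 3)
v = x.broadcast(4)            # VECTORIZE of (x, x, x, x) with dtype int.vec(4)
assert v.op is UOps.VECTORIZE and len(v.src) == 4 and v.dtype == dtypes.int.vec(4)
assert x.broadcast(1) is x    # count == 1 returns self unchanged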