mirror of https://github.com/commaai/tinygrad.git
minimum new style expand (#6534)
* minimum new style expand [run_process_replay]
* float4 folding works
* fix uop graph
* if means or
* dtype.count idx overload
* fix test arange
* expand nope
* fix expand contract
* fix amd tensor core
* oh, that's a good test with a real failure
* remove prints
* early reduce
* tomorrow, we remove sorted on expand args
* fix wmma issue
* that makes test_arange pass
* vectorized folding
* no check
* broadcast
* fix clang with self assign rule
This commit is contained in:
parent f5dd25d376
commit a2239c812e
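The core change in this diff is the representation of UOps.EXPAND: instead of carrying one scalar source per element of the expansion, an EXPAND now carries a single vectorized source, and element selection happens through GEP/VECTORIZE. A minimal sketch of the two styles, using the UOp/UOps/dtypes API exactly as it appears in the tests below (illustrative only, not part of the commit):

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps

# old style: one scalar CONST source per element of the ((axis, size),) expansion
old_expand = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1, 4),))

# new style: a single source whose dtype already holds the full vector of elements
new_expand = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), (0, 1, 2, 3)),), ((1, 4),))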
@@ -295,6 +295,7 @@ class TestUOpGraph(unittest.TestCase):
assert_equiv_uops(uops[0], acc)
self.assertEqual(len(uops), 1)
@unittest.skip("wmma is wrong here, it needs an arg")
def test_wmma_vectorize_no_fold(self):
for i in [4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),

@@ -437,63 +438,60 @@ def float4_rewrite(sink): return graph_rewrite(sink, constant_folder + expander
class TestExpander(unittest.TestCase):
def test_expand_add_broadcast(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+3)
assert sink.op is UOps.EXPAND and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [3,4,5,6])
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [3,4,5,6])
def test_contract_simple(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
def test_contract_axis_1(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((2,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,8,12])
self.assertListEqual([x.arg for x in sink.src[3].src], [3,7,11,15])
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((2,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,4,8,12])
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [3,7,11,15])
def test_contract_axis_2(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [0,1,2,3])
self.assertListEqual([x.arg for x in sink.src[3].src], [12,13,14,15])
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16 and sink.arg == ((1,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 16
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0,1,2,3])
self.assertListEqual([x.arg for x in sink.src[0].src][12:], [12,13,14,15])
def test_contract_axis_2_big(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4])
self.assertListEqual([x.arg for x in sink.src[6].src], [10,14])
self.assertListEqual([x.arg for x in sink.src[0].src][0:2], [0,4])
self.assertListEqual([x.arg for x in sink.src[0].src][12:14], [10,14])
def test_contract_multi_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,2,6])
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 4, 2, 6])
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,4,6])
self.assertListEqual([x.arg for x in sink.src[0].src][0:4], [0, 2, 4, 6])
def test_contract_mid(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,2),(3,2))
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 2
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2])
self.assertListEqual([x.arg for x in sink.src[1].src], [1,3])
self.assertListEqual([x.arg for x in sink.src[2].src], [4,6])
self.assertListEqual([x.arg for x in sink.src[3].src], [5,7])
assert sink.op is UOps.EXPAND and sink.arg == ((1,2),(3,2))
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 8
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,1,3,4,6,5,7])
def test_contract_no_expand(self):
e1 = UOp(UOps.DEFINE_VAR, dtypes.int)

@@ -503,7 +501,7 @@ class TestExpander(unittest.TestCase):
assert sink.src[0] == sink.src[1]
def test_contract_half_expand(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 8

@@ -512,19 +510,19 @@ class TestExpander(unittest.TestCase):
assert sink.src[6] == sink.src[7]
def test_expand_same_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+e2)
assert sink.op is UOps.EXPAND and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [0,5,10,15])
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 4
self.assertListEqual([x.arg for x in sink.src[0].src], [0,5,10,15])
def test_expand_different_axis(self, flip=False):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4*x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
sink = expander_rewrite((e2+e1) if flip else (e1+e2))
assert sink.op is UOps.EXPAND and len(sink.src) == 16
assert sink.op is UOps.EXPAND and len(sink.src[0].src) == 16
assert sink.arg == ((1, 4), (2, 4))
self.assertListEqual([x.arg for x in sink.src], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
self.assertListEqual([x.arg for x in sink.src[0].src], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
def test_expand_different_axis_flip(self): self.test_expand_different_axis(True)

@@ -544,6 +542,7 @@ class TestExpander(unittest.TestCase):
assert sink.op is UOps.CONST
self.assertEqual(sink.arg, 3*4)
@unittest.skip("no longer supported")
def test_double_expand(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))

@@ -553,6 +552,7 @@ class TestExpander(unittest.TestCase):
assert sink.arg == ((1, 2), (2, 4))
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
@unittest.skip("no longer supported")
def test_double_expand_reverse(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))

@@ -562,6 +562,7 @@ class TestExpander(unittest.TestCase):
assert sink.arg == ((1, 4), (2, 2))
self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
@unittest.skip("no longer supported")
def test_double_expand_middle(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))

@@ -585,14 +586,14 @@ class TestLoadStoreFolder(unittest.TestCase):
def test_simple_load_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
def test_two_load_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,8),))
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2

@@ -600,7 +601,7 @@ class TestLoadStoreFolder(unittest.TestCase):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]

@@ -611,7 +612,7 @@ class TestLoadStoreFolder(unittest.TestCase):
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3

@@ -103,7 +103,7 @@ class TestUOpsStats(unittest.TestCase):
c = a.matmul(b)
c.realize()
expected_ops = N ** 3 * 2
assert expected_ops == GlobalCounters.global_ops
self.assertEqual(expected_ops, GlobalCounters.global_ops)
#MULACC should have the same stats as MUL + ADD
def test_mulacc(self):

@@ -72,7 +72,7 @@ class IndependentLowerer:
# upcast loops
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
assert isinstance(g, int), "needs to be int to upcast/unroll"
self.idxs.append(UOp(UOps.EXPAND, dtypes.pyint, tuple(UOp.const(dtypes.pyint, j) for j in range(0, g)), ((i,g),)))
self.idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
# late indexes (group for reduce)
self.ridxs = self.idxs[:]

@@ -117,7 +117,7 @@ class IndependentLowerer:
UOp(UOps.CONTRACT, dtype=in_uops[0].dtype.vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
UOp(UOps.CONTRACT, dtype=in_uops[1].dtype.vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
UOp.const(x.dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, x.dtype, tuple(UOp(UOps.GEP, x.dtype, (ret,), (i,)) for i in range(wmma_sz[2])), arg=upcast_axes[2])
return UOp(UOps.EXPAND, x.dtype, (ret,), arg=upcast_axes[2])
if x.op is UOps.REDUCE_AXIS:
# NOTE: always using ridxs is fine here
reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)

@@ -5,7 +5,7 @@ from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
if TYPE_CHECKING: from tinygrad.renderer import Renderer

@@ -71,7 +71,7 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load.dtype, (vec_load,), (i,))), range(4), load.const_like(float('nan')))
float4_folding = PatternMatcher([
(UPat(UOps.EXPAND, src=UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
(UPat(UOps.VECTORIZE, src=UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True), name="ex"), fold_expanded),
(UPat((UOps.BARRIER, UOps.SINK), src=UPat(UOps.STORE, src=(UPat.var("buf"), UPat(), UPat()), allow_any_len=True), name="ex"), fold_expanded),
])

@@ -201,7 +201,7 @@ def reduce_before_expand(reduce, expand, x):
red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce.src[1:], reduce.arg)
return UOp(expand.op, expand.dtype, tuple(UOp(UOps.GEP, reduce.dtype, (red,), (i,)) for i in range(x.dtype.count)), expand.arg)
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None):
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None, vec=None):
if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE
if mval.arg >= 0 or loop_start.arg != 0:
# TODO: support and test this with other mvals and loop_starts

@@ -209,6 +209,10 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, redu
return None
if idx2 is not None: idx = idx + idx2
if idx3 is not None: idx = idx + idx3
if vec is not None:
# idx, mval, loop_start, loop_end
def dvec(x): return UOp(UOps.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
idx, mval, loop_start, loop_end = dvec(idx), dvec(mval), dvec(loop_start), dvec(loop_end)
comprange = UOp.min(loop_end, UOp.max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start))
new_reduce_op = comprange.cast(multconst.dtype) * multconst
ret = UOp(UOps.REDUCE, reduce.dtype, (new_reduce_op,) + tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)

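The new vec branch above broadcasts the scalar pieces of the collapsed index math so they can be compared element-wise against the VECTORIZE'd RANGE that the vectorized patterns below match. A standalone sketch of that broadcast step (the real dvec closes over vec.dtype.count; the explicit count parameter here is illustrative):

# repeat a scalar UOp `count` times so it can take part in vectorized arithmetic
def dvec_sketch(x, count):
  return UOp(UOps.VECTORIZE, x.dtype.vec(count), src=(x,)*count)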
@@ -225,6 +229,8 @@ constant_folder = PatternMatcher([
# bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.ADD, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
# self ASSIGN is just self
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# VECTORIZE/GEP: the expander rule allows tuple GEP creation, this is just for removal
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"),
lambda vec,x: x if x.dtype == vec.dtype and tuple(y.arg[0] for y in vec.src) == tuple(range(len(vec.src))) else None),

@@ -234,6 +240,7 @@ constant_folder = PatternMatcher([
# VECTORIZE of a single element is just that element
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# VECTORIZE void is SINK
(UPat(UOps.VECTORIZE, dtype=dtypes.void, src=UPat(UOps.BARRIER, name='b')), lambda b: b),
(UPat(UOps.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(UOps.SINK, dtypes.void, x.src)),
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(UOps.GEP, src=(UPat(UOps.GEP, name='g2'),), name='g1'),

@@ -242,9 +249,12 @@ constant_folder = PatternMatcher([
lambda gep, vec: UOp(UOps.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(UOps.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(UOps.GEP, src=(UPat(UOps.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# push all GEPs through ALUs (fix arange stuff)
(UPat(UOps.GEP, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
# GEP add push (non-shrinking only)
(UPat(UOps.GEP, None, (UPat.var('x') + UPat.cvar('c1'),), name="gep"),
lambda x,c1,gep: x.gep(gep.arg) + c1.gep(gep.arg) if len(gep.arg) >= x.dtype.count else None),
#(UPat(UOps.GEP, None, (UPat.var('x') + UPat.cvar('c1'),), name="gep"),
# lambda x,c1,gep: x.gep(gep.arg) + c1.gep(gep.arg) if len(gep.arg) >= x.dtype.count else None),
# tensor core with a 0 input is acc
*[(UPat(UOps.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc) for i in [2, 4, 8]],
*[(UPat(UOps.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc) for i in [2, 4, 8]],

@@ -265,12 +275,22 @@ constant_folder = PatternMatcher([
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (reduce)
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
.lt(UPat.cvar("compval"))
.where(UPat.cvar("multconst"), UPat.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (unrolled)
(UPat(UOps.REDUCE, src=((UPat.var("idx") + UPat.cvar("mval") * UPat(UOps.RANGE, src=(UPat.var("loop_start"), UPat.var("loop_end")), name="rng"))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (vectorized)
(UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) *
UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng")))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# arange loop folding (unrolled, vectorized)
(UPat(UOps.REDUCE, src=(UPat(UOps.VECTORIZE, name="vec", src=(UPat.var("idx") + UPat.cvar("mval", vec=False) *
UPat(UOps.RANGE, src=(UPat.cvar("loop_start", vec=False), UPat.var("loop_end")), name="rng")))
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)) + UPat.var("extra"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# unrolled arange div folding
(UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs),
# indexing (with a multiply offset)!

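All four pattern variants above feed loop_collapse; the intended effect, per the commit notes about test_arange, is that arange-style kernels keep folding their RANGE away even after the new vectorized expand. A quick, hedged way to exercise this from a standard tinygrad install (DEBUG=3 prints the generated kernel, which should contain no loop over the 16 outputs if the fold applies):

from tinygrad import Tensor
# exercise arange loop folding end to end
print(Tensor.arange(16).numpy())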
@@ -406,22 +426,47 @@ def do_expand(root:UOp):
if len(expands) == 0: return None
# NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is UOps.WMMA else ()
expand_args = tuple(x for x in sorted(dedup(flatten([x.arg for x in expands]))) if x[0] not in exclude_args)
esrcs = [[src.src[x] for x in _swizzle_args(expand_args, src.arg, exclude_args)] \
if src.op is UOps.EXPAND else itertools.repeat(src) for src in root.src]
new_srcs = [UOp(root.op, root.dtype, new_src, root.arg) for new_src in zip(*esrcs)]
if root.op is UOps.EXPAND:
# merge two expands
expand_args, old_args = tuple(sorted(root.arg+expand_args)), expand_args
assert len(expand_args) == (len(old_args) + len(root.arg))
new_srcs = [new_srcs[_expand_arg_to_idx(old_args, rpk)].src[_expand_arg_to_idx(root.arg, rpk)] for rpk in _choices_from_args(expand_args)]
if root.op is UOps.IF:
# merge ifs into an or
conditions = functools.reduce(lambda x,y: x|y, dedup(x.src[0] for x in new_srcs if x.src[0].op is not UOps.CONST))
barriers = tuple(set(x.src[1] for x in new_srcs))
new_srcs = [UOp(UOps.IF, src=(conditions,)+barriers) for _ in new_srcs]
assert prod([x[1] for x in expand_args]) == len(new_srcs)
return UOp(UOps.EXPAND, root.dtype, tuple(new_srcs), expand_args)
if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
# if there's only one expand arg, it's okay to use it (optimization)
expand_args = expands[0].arg
else:
# otherwise, we sort them and GEP
expand_args = tuple(x for x in sorted(dedup(flatten(expands_args))) if x[0] not in exclude_args)
expand_sz = prod([x[1] for x in expand_args])
new_srcs = []
for i,src in enumerate(root.src):
if src.op is UOps.EXPAND:
if root.op is UOps.IF and i == 0:
# IF means OR on first arg to IF
new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
elif expand_args == src.arg:
# just remove the expand
new_srcs.append(src.src[0])
else:
lst = _swizzle_args(expand_args, src.arg, exclude_args)
# if the base dtype is > 1, put those at the end
if src.dtype.count > 1: lst = flatten([[i*src.dtype.count+j for j in range(src.dtype.count)] for i in lst])
new_srcs.append(src.src[0].gep(tuple(lst)))
else:
# non-EXPAND input
if (root.op in {UOps.LOAD, UOps.STORE} and i == 0) or (root.op is UOps.REDUCE and i != 0):
# for the first arg of LOAD/STORE and the RANGE args of REDUCE, just pass them through ignoring EXPANDS
new_srcs.append(src)
elif src.dtype.count > 1:
# put any input dtype > 1 grouped together
new_srcs.append(UOp(UOps.VECTORIZE,
src.dtype.scalar().vec(expand_sz*src.dtype.count), tuple(src.gep(i) for i in range(src.dtype.count))*expand_sz))
else:
# repeat the arg
new_srcs.append(src.broadcast(expand_sz))
new_arg = root.arg
if root.op is UOps.GEP:
assert root.dtype.count == 1
# is this right?
new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
return UOp(UOps.EXPAND, root.dtype, (nsrc,), expand_args)
acc_number = 0
def do_reduce(root:UOp):

@@ -435,7 +480,7 @@ def do_reduce(root:UOp):
ret = UOp(UOps.ASSIGN, root.dtype, (acc, acc.alu(root.arg, ret)))
# for MAX, we can just ignore the unparented
if root.arg is BinaryOps.ADD:
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype)
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
def do_contract(con:UOp):

@@ -444,16 +489,16 @@ def do_contract(con:UOp):
if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
# CONTRACT may remove several axes from EXPAND
assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
srcs = []
idxs = []
for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
lsrcs = [ex.src[_expand_arg_to_idx(ex.arg, {**rpk, **lrpk})] for lrpk in _choices_from_args(con.arg)]
srcs.append(UOp(UOps.VECTORIZE, con.dtype, tuple(lsrcs)))
return srcs[0] if len(srcs) == 1 else UOp(UOps.EXPAND, con.dtype, tuple(srcs), new_ex_args)
idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
return UOp(UOps.EXPAND, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
def no_vectorized_alu(alu):
if alu.dtype.count == 1: return None
alus = tuple(UOp(alu.op, alu.dtype.scalar(),
tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), (i,)) for s in alu.src), alu.arg) for i in range(alu.dtype.count))
tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), (i,)) if alu.op is not UOps.REDUCE or j == 0 else s for j,s in enumerate(alu.src)),
alu.arg) for i in range(alu.dtype.count))
return UOp(UOps.VECTORIZE, alu.dtype, alus)
def create_gate(root:UOp) -> Optional[UOp]:

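do_contract (and do_expand above) index into the flattened expansion through _choices_from_args and _expand_arg_to_idx, which are not shown in this diff. A hypothetical reimplementation for illustration only, assuming the expansion is laid out row-major with the last axis varying fastest; the real helpers in uopgraph.py may differ:

import itertools

def choices_from_args(args):
  # every combination of per-axis positions, e.g. ((1,2),(2,2)) -> {1:0,2:0}, {1:0,2:1}, {1:1,2:0}, {1:1,2:1}
  return [dict(zip([axis for axis,_ in args], vals)) for vals in itertools.product(*[range(sz) for _,sz in args])]

def expand_arg_to_idx(args, rpk):
  # flat row-major index of one choice, matching how a single vectorized EXPAND source would be laid out
  idx, mul = 0, 1
  for axis, sz in reversed(args):
    idx += rpk[axis] * mul
    mul *= sz
  return idx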
@@ -470,9 +515,12 @@ expander = PatternMatcher([
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
# create gate MUST BE BEFORE expander
(UPat(UOps.STORE, name="root"), create_gate),
# double expand
(UPat(UOps.EXPAND, name="outer", src=(UPat(UOps.EXPAND, name="inner"),)),
lambda outer, inner: UOp(UOps.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE,
UOps.VECTORIZE, UOps.REDUCE, UOps.EXPAND, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
UOps.VECTORIZE, UOps.REDUCE, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
(UPat(UOps.CONTRACT, name="con"), do_contract),
# remove EXPANDs from SINK
(UPat(UOps.SINK, name="root"),

@@ -488,6 +536,37 @@ expander = PatternMatcher([
lambda ex,x,y: UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])
just_reduce = PatternMatcher([
# do reduce (in expander now)
(UPat(UOps.REDUCE, name="root"), do_reduce),
])
def no_vectorized_load_store(ls:UOp):
idx = ls.src[1]
if idx.dtype.count == 1: return None
# ugh, the meaning of a dtype.count idx is overloaded
if ls.op is UOps.LOAD and idx.dtype.count != ls.dtype.count: return None
if ls.op is UOps.STORE and idx.dtype.count != ls.src[2].dtype.count: return None
tv = [UOp(ls.op, ls.dtype.scalar(), (ls.src[0],) + tuple(j.gep(i) for j in ls.src[1:])) for i in range(idx.dtype.count)]
return UOp(UOps.VECTORIZE, ls.dtype, tuple(tv))
def no_vectorized_acc(acc:UOp):
if acc.dtype.count == 1: return None
alus = tuple(UOp(acc.op, acc.dtype.scalar(),
tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), (i,)) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
return UOp(UOps.VECTORIZE, acc.dtype, alus)
def no_vectorized_wmma(wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
if wmma.dtype.count == out_sz: return None
tsrcs = []
for s,sz in zip(wmma.src, wmma.arg[6]):
ssz = prod(x[1] for x in sz)
tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
wmmas = [UOp(UOps.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
def delete_redundant_gates(root:UOp) -> Optional[UOp]:
@functools.lru_cache(None)
def find_gate(x:UOp) -> Optional[UOp]:

@@ -496,14 +575,19 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
if len(root.src) == 3 or (gate:=find_gate(root)) is None or gate.src[0] is not root.src[3]: return None
return UOp(UOps.STORE, root.dtype, root.src[:3], root.arg)
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.REDUCE), name="alu"), no_vectorized_alu),
(UPat(UOps.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(UOps.DEFINE_ACC, name="acc"), no_vectorized_acc),
(UPat((UOps.LOAD, UOps.STORE), name="ls"), no_vectorized_load_store),
])
reducer = PatternMatcher([
(UPat(UOps.REDUCE, name="root"), do_reduce),
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.count) if c.dtype.count > 1 else None),
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
# no ALU on vectorized dtypes
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name="alu"), no_vectorized_alu),
# delete_redundant_gates (after expand, is this still needed?)
(UPat(UOps.STORE, name="root"), delete_redundant_gates),
# late fixup of unfoldable image loads

@@ -544,8 +628,11 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
# expand
linearize_cnt += 1
if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:
sink = graph_rewrite(sink, folder+(expander+float4_folding if opts is not None and opts.supports_float4 else expander))
if getenv("DO_REDUCE", 1): sink = graph_rewrite(sink, folder+reducer)
sink = graph_rewrite(sink, folder+expander)
if getenv("DO_REDUCE", 1):
sink = graph_rewrite(sink, folder+just_reduce)
sink = graph_rewrite(sink, folder+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize))
sink = graph_rewrite(sink, folder+reducer)
# for PTX only
if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher)

@@ -349,6 +349,9 @@ class UOp(MathTrait):
def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
# TODO: instant check rules here make debugging easier
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
#if op is UOps.VECTORIZE and dtype != dtypes.void: assert len(src) == dtype.count, f"{len(src)} invalid for {dtype}"
#if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
#if op is UOps.CAST: assert dtype.count == src[0].dtype.count, f"cast can't change vectorization {src[0].dtype} --> {dtype}"
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None):
return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg)

@@ -387,6 +390,10 @@ class UOp(MathTrait):
def sink(self, *srcs): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return type(self).const(self.dtype, b)
def broadcast(self, count:int):
assert self.dtype.count == 1
if count == 1: return self
return UOp(UOps.VECTORIZE, self.dtype.vec(count), (self,)*count)
def cast(self, dtype:DType): return type(self)(UOps.CAST, dtype, (self,))
def bitcast(self, dtype:DType): return type(self)(UOps.BITCAST, dtype, (self,))
def gep(self, i:Union[Tuple[int, ...], int]):

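The new UOp.broadcast helper is what do_expand and do_reduce use to repeat a scalar across the expansion width. A small usage sketch based on the definition above, using the UOp/UOps/dtypes names from this diff:

x = UOp.const(dtypes.int, 3)
x4 = x.broadcast(4)         # a UOps.VECTORIZE of dtype dtypes.int.vec(4) with (x,)*4 as src
assert x.broadcast(1) is x  # count == 1 returns the UOp unchanged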