switch contract arg to match expand arg [run_process_replay] (#5667)

* switch contract arg to match expand arg [run_process_replay]

* support multiaxis contract too, it's easy

* cancel contract/expand
This commit is contained in:
George Hotz 2024-07-23 18:08:33 -07:00 committed by GitHub
parent ea99efe815
commit fa14f7b4fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 27 deletions

View File

@ -285,14 +285,14 @@ class TestExpander(unittest.TestCase):
def test_contract_simple(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 4
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3])
def test_contract_axis_1(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (1,))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((2,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
@ -301,7 +301,7 @@ class TestExpander(unittest.TestCase):
def test_contract_axis_2(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,4),)
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 4
@ -310,25 +310,24 @@ class TestExpander(unittest.TestCase):
def test_contract_axis_2_big(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4])
self.assertListEqual([x.arg for x in sink.src[6].src], [10,14])
@unittest.skip("TODO: add support for this")
def test_contract_multi_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(16)), ((1,2),(2,2),(3,2),(4,2)))
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (3,2)))
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3,2),(2,2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src], [0,4,2,6])
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), (2,3)))
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,2),(3,2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertListEqual([x.arg for x in sink.src[0].src], [0,2,4,6])
def test_contract_mid(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(8)), ((1,2),(2,2),(3,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), (2,))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src) == 4 and sink.arg == ((1,2),(3,2))
assert sink.src[0].op is UOps.VECTORIZE and len(sink.src[0].src) == 2

View File

@ -173,8 +173,8 @@ class IndependentLowerer:
if x.op is ReduceOps.WMMA:
wmma_sz, upcast_axis = x.arg[4], x.arg[6]
ret = UOp(UOps.WMMA, dtype=dtype.vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=(upcast_axis[0],)),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=(upcast_axis[1],)),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[0].dtype).vec(wmma_sz[0]), src=(in_uops[0],), arg=((upcast_axis[0], wmma_sz[0]),)),
UOp(UOps.CONTRACT, dtype=cast(DType, in_uops[1].dtype).vec(wmma_sz[1]), src=(in_uops[1],), arg=((upcast_axis[1], wmma_sz[1]),)),
UOp.const(dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, dtype, tuple(UOp(UOps.GEP, dtype, (ret,), i) for i in range(wmma_sz[2])), arg=((upcast_axis[2], wmma_sz[2]),))
# NOTE: always using ridxs is fine here

View File

@ -23,7 +23,7 @@ def image_contract_load(buf, idx, idy, id4, ls_allow_any_len):
ls_allow_any_len.const(float('nan')))
def image_contract_store(buf, ex, idx, idy, ls_allow_any_len, var):
new_var = UOp(UOps.CONTRACT, var.dtype.vec(4), (var,), (ex.arg[0][0],))
new_var = UOp(UOps.CONTRACT, var.dtype.vec(4), (var,), ((ex.arg[0][0],4),))
return UOp(UOps.STORE, None, (buf, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx, idy)), new_var) + ls_allow_any_len.src[3:])
# ***** float4 handling *****
@ -47,7 +47,7 @@ def float4_contract_store(buf, ex, var, store_allow_any_len, idx=UOp.const(dtype
if idx3 is not None: idx = idx + idx3
if not idx.divides(len(ex.src)): return None
new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), (ex.arg[0][0],))
new_var = UOp(UOps.CONTRACT, var.dtype.vec(len(ex.src)), (var,), ((ex.arg[0][0],len(ex.src)),))
return UOp(UOps.STORE, None, (buf, idx, new_var) + store_allow_any_len.src[3:])
float4_folding = PatternMatcher([
@ -379,23 +379,17 @@ def do_contract(con:UOp):
ex = con.src[0]
assert con.dtype is not None
# CONTRACT without EXPAND repeats the element VECTORIZED
if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
# simple CONTRACT and EXPAND cancel out
if len(ex.arg) == 1 and len(con.arg) == 1 and ex.arg[0][0] in con.arg: return UOp(UOps.VECTORIZE, con.dtype, ex.src)
# complex CONTRACT may only remove one axis from EXPAND
assert len(con.arg) == 1, "contract arg one is all that's supported"
try:
split_index = [x[0] for x in ex.arg].index(con.arg[0])
except ValueError:
# CONTRACT without EXPAND (still) repeats the element VECTORIZED
if ex.op is not UOps.EXPAND or not all(x in ex.arg for x in con.arg):
assert ex.op is not UOps.EXPAND or not any(x in ex.arg for x in con.arg), "partial contract not supported"
return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
assert con.dtype.count == ex.arg[split_index][1], "contract arg must match"
number_after = prod([x[1] for x in ex.arg[split_index+1:]])
to_join = [ex.src[i:i+number_after] for i in range(0, len(ex.src), number_after)]
# simple CONTRACT and EXPAND cancel out
if len(ex.arg) == 1 and len(con.arg) == 1 and ex.arg == con.arg: return UOp(UOps.VECTORIZE, con.dtype, ex.src)
# complex CONTRACT may remove several axes from EXPAND
srcs = []
for i in range(0, len(to_join), con.dtype.count):
srcs += [UOp(UOps.VECTORIZE, con.dtype, tuple(src)) for src in zip(*to_join[i:i+con.dtype.count])]
return UOp(UOps.EXPAND, con.dtype, tuple(srcs), tuple(x for x in ex.arg if x[0] != con.arg[0]))
for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
lsrcs = [ex.src[_expand_arg_to_idx(ex.arg, {**rpk, **lrpk})] for lrpk in _choices_from_args(con.arg)]
srcs.append(UOp(UOps.VECTORIZE, con.dtype, tuple(lsrcs)))
return UOp(UOps.EXPAND, con.dtype, tuple(srcs), new_ex_args)
def no_vectorized_alu(alu):
if alu.dtype.count == 1: return None