move WMMA out of lowerer [run_process_replay] (#6647)

This commit is contained in:
George Hotz 2024-09-22 14:05:51 +08:00 committed by GitHub
parent 84703d5b77
commit 0eb710de84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 16 deletions

View File

@ -711,7 +711,14 @@ class Kernel:
# MUL/SUM instead of WMMA
ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
else:
ret = UOp(UOps.WMMA, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
# real WMMA, use CONTRACT/EXPAND to get the vectorization right
wmma_upcast_axes = wmma_arg[-2]
wmma_sz = [prod(x[1] for x in l) for l in wmma_upcast_axes]
wmma = UOp(UOps.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
UOp(UOps.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
ret = UOp(UOps.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in reduce_axes)
return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
if self.group_for_reduces:
@ -789,7 +796,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
if op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[-1]))
elif op is UOps.SWIZZLE: st = arg
else:
assert op in {UOps.SHAPETRACKER, UOps.SWIZZLE, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
assert op in {UOps.SHAPETRACKER, UOps.SWIZZLE, UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.CONTRACT, UOps.EXPAND, *BUFFER_UOPS}, \
f"bad UOp in intermediate uops {uop}"
# movementops are pushed to the edges with SHAPETRACKER
# elementwise inherits shape
st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else 0]]

View File

@ -35,16 +35,6 @@ def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int
idx //= dims[c]
return ret[::-1] if reverse else ret
# TODO: move this to kernel.py, it doesn't depend on axes
def lower_wmma(ctx: IndependentLowerer, x: UOp):
  """Turn a 2-source WMMA into the 3-source form, wrapped in CONTRACT/EXPAND.

  `ctx` is accepted but not used in this function. `x.arg[-2]` supplies three axis
  groups: the first two become the CONTRACT args of the two sources, the third
  becomes the EXPAND arg of the result. Returns an EXPAND UOp of dtype `x.dtype`
  whose single source is the new vectorized WMMA.
  """
  axis_groups = x.arg[-2]
  # vector width per group: product of the second field of every entry in the group
  widths = [prod(entry[1] for entry in group) for group in axis_groups]
  lhs = UOp(UOps.CONTRACT, dtype=x.src[0].dtype.vec(widths[0]), src=(x.src[0],), arg=axis_groups[0])
  rhs = UOp(UOps.CONTRACT, dtype=x.src[1].dtype.vec(widths[1]), src=(x.src[1],), arg=axis_groups[1])
  # the added third source is a 0.0 constant in the output vector dtype
  zero_acc = UOp.const(x.dtype.vec(widths[2]), 0.0)
  wmma = UOp(UOps.WMMA, dtype=x.dtype.vec(widths[2]), src=(lhs, rhs, zero_acc), arg=x.arg)
  return UOp(UOps.EXPAND, x.dtype, (wmma,), arg=axis_groups[2])
def lower_reduce_axis(ctx: IndependentLowerer, x: UOp):
# NOTE: always using ridxs is fine here
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
@ -75,7 +65,6 @@ def lower_load_store(ctx: IndependentLowerer, x: UOp):
return UOp(UOps.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
pm_lowerer = PatternMatcher([
(UPat(UOps.WMMA, src=(UPat(), UPat()), name="x"), lower_wmma), # 2 param -> 3 param WMMA
(UPat(UOps.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(UOps.VALID, src=(UPat(UOps.SHAPETRACKER),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
# rewrite LOAD/STORE SHAPETRACKER to LOAD/STORE with indexed

View File

@ -618,10 +618,10 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b
(UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_ACC), UPat())), lambda: True),
(UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True),
# early WMMA has 2 args, <x, w>
(UPat(UOps.WMMA, src=(UPat(), UPat())), lambda: True),
# late WMMA has 3 args, <x, w, acc>
# all WMMA has 3 args, <x, w, acc>
(UPat(UOps.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
(UPat(UOps.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(UOps.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# if has a <gate, barrier>
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),