mirror of https://github.com/commaai/tinygrad.git
move WMMA out of lowerer [run_process_replay] (#6647)
This commit is contained in:
parent
84703d5b77
commit
0eb710de84
|
@ -711,7 +711,14 @@ class Kernel:
|
|||
# MUL/SUM instead of WMMA
|
||||
ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
|
||||
else:
|
||||
ret = UOp(UOps.WMMA, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
|
||||
# real WMMA, use CONTRACT/EXPAND to get the vectorization right
|
||||
wmma_upcast_axes = wmma_arg[-2]
|
||||
wmma_sz = [prod(x[1] for x in l) for l in wmma_upcast_axes]
|
||||
wmma = UOp(UOps.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
|
||||
UOp(UOps.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
|
||||
UOp(UOps.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
|
||||
UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
|
||||
ret = UOp(UOps.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
|
||||
new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in reduce_axes)
|
||||
return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
|
||||
if self.group_for_reduces:
|
||||
|
@ -789,7 +796,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
|
|||
if op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[-1]))
|
||||
elif op is UOps.SWIZZLE: st = arg
|
||||
else:
|
||||
assert op in {UOps.SHAPETRACKER, UOps.SWIZZLE, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
|
||||
assert op in {UOps.SHAPETRACKER, UOps.SWIZZLE, UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.CONTRACT, UOps.EXPAND, *BUFFER_UOPS}, \
|
||||
f"bad UOp in intermediate uops {uop}"
|
||||
# movementops are pushed to the edges with SHAPETRACKER
|
||||
# elementwise inherits shape
|
||||
st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else 0]]
|
||||
|
|
|
@ -35,16 +35,6 @@ def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int
|
|||
idx //= dims[c]
|
||||
return ret[::-1] if reverse else ret
|
||||
|
||||
# TODO: move this to kernel.py, it doesn't depend on axes
|
||||
def lower_wmma(ctx: IndependentLowerer, x: UOp):
|
||||
upcast_axes = x.arg[-2]
|
||||
wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
|
||||
ret = UOp(UOps.WMMA, dtype=x.dtype.vec(wmma_sz[2]), src=(
|
||||
UOp(UOps.CONTRACT, dtype=x.src[0].dtype.vec(wmma_sz[0]), src=(x.src[0],), arg=upcast_axes[0]),
|
||||
UOp(UOps.CONTRACT, dtype=x.src[1].dtype.vec(wmma_sz[1]), src=(x.src[1],), arg=upcast_axes[1]),
|
||||
UOp.const(x.dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
|
||||
return UOp(UOps.EXPAND, x.dtype, (ret,), arg=upcast_axes[2])
|
||||
|
||||
def lower_reduce_axis(ctx: IndependentLowerer, x: UOp):
|
||||
# NOTE: always using ridxs is fine here
|
||||
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
|
||||
|
@ -75,7 +65,6 @@ def lower_load_store(ctx: IndependentLowerer, x: UOp):
|
|||
return UOp(UOps.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
|
||||
|
||||
pm_lowerer = PatternMatcher([
|
||||
(UPat(UOps.WMMA, src=(UPat(), UPat()), name="x"), lower_wmma), # 2 param -> 3 param WMMA
|
||||
(UPat(UOps.REDUCE_AXIS, name="x"), lower_reduce_axis),
|
||||
(UPat(UOps.VALID, src=(UPat(UOps.SHAPETRACKER),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
|
||||
# rewrite LOAD/STORE SHAPETRACKER to LOAD/STORE with indexed
|
||||
|
|
|
@ -618,10 +618,10 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b
|
|||
(UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_ACC), UPat())), lambda: True),
|
||||
(UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True),
|
||||
|
||||
# early WMMA has 2 args, <x, w>
|
||||
(UPat(UOps.WMMA, src=(UPat(), UPat())), lambda: True),
|
||||
# late WMMA has 3 args, <x, w, acc>
|
||||
# all WMMA has 3 args, <x, w, acc>
|
||||
(UPat(UOps.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
|
||||
(UPat(UOps.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(UOps.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
# if has a <gate, barrier>
|
||||
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),
|
||||
|
|
Loading…
Reference in New Issue