fix acc folding for NV tensor cores (#5658)

* fix acc folding for NV tensor cores

* fix correctness of reduce_before_expand
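
A WMMA op computes D = A@B + C, so an add on its output can be folded into the C accumulator, and a WMMA whose A or B input is all zeros collapses to the accumulator alone. A minimal NumPy sketch of those identities (illustrative only; the real rewrite operates on UOp graphs, and the wmma helper below is a hypothetical stand-in):

import numpy as np

def wmma(a, b, acc):
  # hypothetical stand-in: models the tensor core op D = A @ B + C
  return a.astype(np.float32) @ b.astype(np.float32) + acc

A = np.random.randn(16, 16).astype(np.float16)
B = np.random.randn(16, 16).astype(np.float16)
acc = np.random.randn(16, 16).astype(np.float32)
add = np.random.randn(16, 16).astype(np.float32)

# add + WMMA(A, B, acc) folds to WMMA(A, B, acc + add)
assert np.allclose(wmma(A, B, acc) + add, wmma(A, B, acc + add))
# WMMA(0, B, acc) folds to acc
assert np.allclose(wmma(np.zeros_like(A), B, acc), acc)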
commit 4d47968580 (parent 01fe00e055)
Author: George Hotz
Date: 2024-07-23 13:03:02 -07:00
3 changed files with 34 additions and 5 deletions


@@ -0,0 +1,26 @@
+from tinygrad import Tensor, dtypes, Device
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps
+from tinygrad.engine.realize import CompiledRunner, ExecItem
+
+N = 4096
+if __name__ == "__main__":
+  A, B = Tensor.empty(N, N, dtype=dtypes.float16), Tensor.empty(N, N, dtype=dtypes.float16)
+  C = A.matmul(B, acc_dtype=dtypes.float32)
+  si = C.schedule()[-1]
+  ast = si.ast
+  k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
+  opts = [Opt(op=OptOps.TC, axis=0, amt=0),
+          Opt(op=OptOps.UPCAST, axis=1, amt=16),
+          Opt(op=OptOps.UPCAST, axis=0, amt=2),
+          Opt(op=OptOps.LOCAL, axis=0, amt=4),
+          Opt(op=OptOps.UNROLL, axis=0, amt=4),
+          Opt(op=OptOps.LOCAL, axis=1, amt=2),
+         ]
+  for opt in opts: k.apply_opt(opt)
+  prg = k.to_program()
+  ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
+  tflops = []
+  for i in range(5):
+    tm = ei.run(wait=True)
+    tflops.append((2*N*N*N/tm)*1e-12)
+  print(f"TFLOPS: {sum(tflops)/len(tflops):.2f}")


@@ -451,7 +451,7 @@ class Kernel:
     elif opt.op is OptOps.UPCAST:  # yellow
       check(axis < self.first_reduce, "upcast is for non-reduce")
       check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
-      check(amt <= 8, "don't upcast more than 8")
+      check(amt <= 16, "don't upcast more than 16")
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
     elif opt.op is OptOps.UPCASTMID:  # white
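
The per-axis upcast cap goes from 8 to 16 here because the benchmark above applies Opt(op=OptOps.UPCAST, axis=1, amt=16), which the old amt <= 8 check would reject.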
@@ -729,6 +729,7 @@ class Kernel:
     if DEBUG >= 3:
       print(self.name)
       print(modified_ast)
+      print(self.applied_opts)
     verify_lazyop(modified_ast)
     uop_sink = lazyop_to_uop(modified_ast, self.opts)


@@ -105,6 +105,10 @@ def threefry2x32(x: UOp, seed: UOp):
 # ***** main rewriter *****
 def reduce_before_expand(reduce_allow_any_len, expand, x):
+  # if the expand is being reduced, you can't push it through
+  # NOTE: could do a partial push here in some cases
+  expands = flatten([x.arg for x in reduce_allow_any_len.src[1:] if x.op is UOps.EXPAND])
+  if any(x in expands for x in expand.arg): return None
   red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce_allow_any_len.src[1:], reduce_allow_any_len.arg)
   gep = tuple(UOp(UOps.GEP, reduce_allow_any_len.dtype, (red,), i) for i in range(x.dtype.count))
   return UOp(expand.op, expand.dtype, gep, expand.arg)
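
Why the new guard is needed, as a rough NumPy analogy (the actual pattern rewrites UOp graphs, not arrays): pushing a reduce through an expand is only sound when the reduce ranges are disjoint from the expand axes.

import numpy as np

x = np.arange(3, dtype=np.float32)       # a value to be expanded
expanded = np.broadcast_to(x, (4, 3))    # "expand" along a new axis 0

# reduce over an axis disjoint from the expand axis: the push is valid,
# reducing first and expanding after gives the same result
assert np.array_equal(expanded.sum(axis=1), np.broadcast_to(x.sum(), (4,)))

# reduce over the expand axis itself: pushing would just re-expand x and
# drop the sum across the expanded copies, so the rewrite must return None
assert not np.array_equal(expanded.sum(axis=0), x)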
@@ -154,10 +158,8 @@ constant_folder = PatternMatcher([
   (UOp(UOps.WMMA, src=(UOp.const(None, 0.0), UOp.var(), UOp.var('acc'))), lambda acc: acc),
   (UOp(UOps.WMMA, src=(UOp.var(), UOp.const(None, 0.0), UOp.var('acc'))), lambda acc: acc),
   # tensor core cleanups
-  (UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(2))).name("expand"),))
-    .name("reduce_allow_any_len"), reduce_before_expand),
-  (UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(8))).name("expand"),))
-    .name("reduce_allow_any_len"), reduce_before_expand),
+  *[(UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(j))).name("expand"),))
+    .name("reduce_allow_any_len"), reduce_before_expand) for j in [2,4,8]],
   (UOp.var("add") + UOp(UOps.WMMA).name("wmma"),
    lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
   # threefry
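
The two hardcoded REDUCE(EXPAND(GEP(...))) patterns for GEP widths 2 and 8 collapse into one comprehension over j in [2,4,8], which also picks up the previously unmatched width-4 case.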