mirror of https://github.com/commaai/tinygrad.git
fix acc folding for NV tensor cores (#5658)
* fix acc folding for NV tensor cores
* fix correctness of reduce_before_expand
This commit is contained in:
parent 01fe00e055
commit 4d47968580
@@ -0,0 +1,26 @@
+from tinygrad import Tensor, dtypes, Device
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps
+from tinygrad.engine.realize import CompiledRunner, ExecItem
+
+N = 4096
+if __name__ == "__main__":
+  A, B = Tensor.empty(N, N, dtype=dtypes.float16), Tensor.empty(N, N, dtype=dtypes.float16)
+  C = A.matmul(B, acc_dtype=dtypes.float32)
+  si = C.schedule()[-1]
+  ast = si.ast
+  k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
+  opts = [Opt(op=OptOps.TC, axis=0, amt=0),
+          Opt(op=OptOps.UPCAST, axis=1, amt=16),
+          Opt(op=OptOps.UPCAST, axis=0, amt=2),
+          Opt(op=OptOps.LOCAL, axis=0, amt=4),
+          Opt(op=OptOps.UNROLL, axis=0, amt=4),
+          Opt(op=OptOps.LOCAL, axis=1, amt=2),
+  ]
+  for opt in opts: k.apply_opt(opt)
+  prg = k.to_program()
+  ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
+  tflops = []
+  for i in range(5):
+    tm = ei.run(wait=True)
+    tflops.append((2*N*N*N/tm)*1e-12)
+  print(f"TFLOPS: {sum(tflops)/len(tflops):.2f}")
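The TFLOPS figure above is just the matmul's 2*N^3 floating point operations divided by the measured kernel time. A quick sanity check of that arithmetic (illustrative only, not part of the commit; it assumes ei.run(wait=True) returns the elapsed time in seconds):

N = 4096
tm = 1e-3  # hypothetical 1 ms kernel time, in seconds
# 2*N**3 ~= 1.374e11 FLOPs, so a 1 ms run corresponds to ~137 TFLOPS
print((2*N*N*N/tm)*1e-12)  # -> ~137.44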
@@ -451,7 +451,7 @@ class Kernel:
     elif opt.op is OptOps.UPCAST: # yellow
       check(axis < self.first_reduce, "upcast is for non-reduce")
       check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
-      check(amt <= 8, "don't upcast more than 8")
+      check(amt <= 16, "don't upcast more than 16")
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
     elif opt.op is OptOps.UPCASTMID: # white
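This lifts the per-axis upcast cap from 8 to 16, which the benchmark above depends on via Opt(op=OptOps.UPCAST, axis=1, amt=16). A minimal sketch of the guard's effect, assuming check is a helper that raises when its condition is false (a hypothetical stand-in, not the real kernel.py helper):

def check(cond, msg=""):
  if not cond: raise RuntimeError(msg)  # stand-in: reject an invalid opt

amt = 16
check(amt <= 16, "don't upcast more than 16")  # passes with the new bound
# check(amt <= 8, "don't upcast more than 8")  # the old bound would have raised here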
@@ -729,6 +729,7 @@ class Kernel:
     if DEBUG >= 3:
       print(self.name)
       print(modified_ast)
+      print(self.applied_opts)
     verify_lazyop(modified_ast)

     uop_sink = lazyop_to_uop(modified_ast, self.opts)
@@ -105,6 +105,10 @@ def threefry2x32(x: UOp, seed: UOp):
 # ***** main rewriter *****

 def reduce_before_expand(reduce_allow_any_len, expand, x):
+  # if the expand is being reduced, you can't push it through
+  # NOTE: could do a partial push here in some cases
+  expands = flatten([x.arg for x in reduce_allow_any_len.src[1:] if x.op is UOps.EXPAND])
+  if any(x in expands for x in expand.arg): return None
   red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce_allow_any_len.src[1:], reduce_allow_any_len.arg)
   gep = tuple(UOp(UOps.GEP, reduce_allow_any_len.dtype, (red,), i) for i in range(x.dtype.count))
   return UOp(expand.op, expand.dtype, gep, expand.arg)
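reduce_before_expand rewrites REDUCE(EXPAND(GEP(x, i), ...)) into EXPAND(GEP(REDUCE(x), i), ...); the new early return refuses the rewrite when the EXPAND's axes appear among the axes being reduced, because then the reduce cannot be pushed past the expand. A rough NumPy analogy of that condition (illustrative only, not tinygrad UOps):

import numpy as np

x = np.arange(8, dtype=np.float32).reshape(2, 4)
expanded = np.broadcast_to(x, (3, 2, 4))       # "expand" along a new axis of size 3

# reducing over an axis the expand did not create commutes with the expand
a = expanded.sum(axis=1)                       # reduce after expand
b = np.broadcast_to(x.sum(axis=0), (3, 4))     # reduce before expand
assert np.allclose(a, b)

# reducing over the expanded axis itself does not commute: the factor of 3 would be lost
c = expanded.sum(axis=0)                       # equals 3 * x
assert not np.allclose(c, x)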
@@ -154,10 +158,8 @@ constant_folder = PatternMatcher([
   (UOp(UOps.WMMA, src=(UOp.const(None, 0.0), UOp.var(), UOp.var('acc'))), lambda acc: acc),
   (UOp(UOps.WMMA, src=(UOp.var(), UOp.const(None, 0.0), UOp.var('acc'))), lambda acc: acc),
   # tensor core cleanups
-  (UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(2))).name("expand"),))
-   .name("reduce_allow_any_len"), reduce_before_expand),
-  (UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(8))).name("expand"),))
-   .name("reduce_allow_any_len"), reduce_before_expand),
+  *[(UOp(UOps.REDUCE, src=(UOp(UOps.EXPAND, src=tuple(UOp(UOps.GEP, dtypes.float, src=(UOp.var('x'),), arg=i) for i in range(j))).name("expand"),))
+   .name("reduce_allow_any_len"), reduce_before_expand) for j in [2,4,8]],
   (UOp.var("add") + UOp(UOps.WMMA).name("wmma"),
    lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
   # threefry
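The surrounding patterns fold the accumulator through WMMA, which behaves like a fused multiply-accumulate on matrix fragments: a zero multiplicand leaves only the accumulator, and an add applied to the output can be moved into the accumulator input. A small NumPy model of those identities (illustrative only; the real WMMA works on vectorized register fragments, not full matrices):

import numpy as np

def wmma(a, b, acc): return a @ b + acc  # toy model: matmul plus accumulator

rng = np.random.default_rng(0)
a, b, acc, add = (rng.random((16, 16), dtype=np.float32) for _ in range(4))

# WMMA(0, b, acc) -> acc and WMMA(a, 0, acc) -> acc: the two const-zero patterns
assert np.allclose(wmma(np.zeros_like(a), b, acc), acc)
assert np.allclose(wmma(a, np.zeros_like(b), acc), acc)

# add + WMMA(a, b, acc) -> WMMA(a, b, acc + add): the acc-folding pattern shown above
assert np.allclose(add + wmma(a, b, acc), wmma(a, b, acc + add))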