mirror of https://github.com/commaai/tinygrad.git
work to make GEMV fast (#5824)
* work to make GEMV fast
* half8 cast
* align struct
* fix amd
* float8 is a later problem
parent 2d90b7a103
commit e6879035a0
@@ -1,26 +1,39 @@
 from tinygrad import Tensor, dtypes, Device
+from tinygrad.helpers import getenv, DEBUG
+from tinygrad.engine.graph import print_globalcounters
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps
 from tinygrad.engine.realize import CompiledRunner, ExecItem
+from dataclasses import replace
 
 N = 4096
 if __name__ == "__main__":
-  A, B = Tensor.empty(N, N, dtype=dtypes.float16), Tensor.empty(N, N, dtype=dtypes.float16)
+  if getenv("GEMV"):
+    A, B = Tensor.empty(1, N, dtype=dtypes.float), Tensor.empty(14336, N, dtype=dtypes.float16).T
+  else:
+    A, B = Tensor.empty(N, N, dtype=dtypes.float16), Tensor.empty(N, N, dtype=dtypes.float16)
   C = A.matmul(B)
   si = C.schedule()[-1]
   ast = si.ast
   k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-  opts = [Opt(op=OptOps.TC, axis=0, amt=0),
-          Opt(op=OptOps.UPCAST, axis=1, amt=16),
-          Opt(op=OptOps.UPCAST, axis=0, amt=2),
-          Opt(op=OptOps.LOCAL, axis=0, amt=4),
-          Opt(op=OptOps.UNROLL, axis=0, amt=4),
-          Opt(op=OptOps.LOCAL, axis=1, amt=2),
-  ]
+  if getenv("GEMV"):
+    opts = [
+      Opt(op=OptOps.UNROLL, axis=0, amt=8),
+      Opt(op=OptOps.GROUP, axis=0, amt=32),
+    ]
+  else:
+    opts = [
+      Opt(op=OptOps.TC, axis=0, amt=0),
+      Opt(op=OptOps.UPCAST, axis=0, amt=4),
+      Opt(op=OptOps.UPCAST, axis=1, amt=8),
+      Opt(op=OptOps.LOCAL, axis=0, amt=2),
+      Opt(op=OptOps.LOCAL, axis=1, amt=2),
+      Opt(op=OptOps.LOCAL, axis=0, amt=2),
+    ]
   for opt in opts: k.apply_opt(opt)
   prg = k.to_program()
+  new_src = prg.src
+  # can mod source here
+  prg = replace(prg, src=new_src)
   ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
-  tflops = []
-  for i in range(5):
-    tm = ei.run(wait=True)
-    tflops.append((2*N*N*N/tm)*1e-12)
-  print(f"TFLOPS: {sum(tflops)/len(tflops):.2f}")
+  for i in range(5): ei.run(wait=True)
+  if DEBUG < 2: print_globalcounters()
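
Note: the new loop above relies on print_globalcounters() for the throughput numbers. If you want the explicit per-run TFLOPS print that the old loop had, a minimal sketch reusing the script's variables (the 2*N*14336 count for the GEMV shapes is our own arithmetic for a (1,N) x (N,14336) product, not something the commit computes):

  tflops = []
  for _ in range(5):
    tm = ei.run(wait=True)                             # seconds for one kernel launch, as in the old loop
    flops = 2*N*14336 if getenv("GEMV") else 2*N*N*N   # GEMV FLOPs vs. the original 2*N^3 GEMM count
    tflops.append(flops/tm*1e-12)
  print(f"TFLOPS: {sum(tflops)/len(tflops):.2f}")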
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+from tinygrad import Tensor, TinyJit, nn
+from extra.models.llama import FeedForward
+
+if __name__ == "__main__":
+  model = FeedForward(4096, 14336)
+  for x in nn.state.get_parameters(model): x.replace(x.half()).realize()
+  jrun = TinyJit(model)
+  for i in range(5):
+    print(f"*** run {i}")
+    jrun(Tensor.rand(1, 4096))
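
For context on this script: TinyJit captures the kernels launched inside the wrapped module during the warm-up calls and replays that capture on later calls, which is why it loops five times instead of running once. The same pattern on a toy function, purely illustrative (the function and shapes are made up):

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  # whatever runs in here is captured, then replayed with new inputs of the same shape
  return (x @ x.T).relu().realize()

for i in range(5):
  out = step(Tensor.rand(8, 8))  # early calls trace/capture, later calls hit the replay path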
@@ -12,12 +12,11 @@ with contextlib.suppress(ImportError): import networkx as nx
 
 # **** debugging and graphing ****
 
-if DEBUG >= 2:
-  def print_globalcounters():
-    if GlobalCounters.time_sum_s == 0: return
-    print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
-          f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
-  atexit.register(print_globalcounters)
+def print_globalcounters():
+  if GlobalCounters.time_sum_s == 0: return
+  print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
+        f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
+if DEBUG >= 2: atexit.register(print_globalcounters)
 
 def save_graph(G, fn, opt=""):
   print("saving", G, f"to {fn}.svg")
@@ -177,7 +177,8 @@ class CStyleLanguage(Renderer):
      elif uop is UOps.GEP:
        assert src[0].dtype is not None
        from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
-       r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
+       r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + \
+         (f"[{args}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) else f".{'xyzwabcd'[args]}")
      else: raise RuntimeError(f"failed to render {u}")
 
    return self.render_kernel(name, kernel, bufs, uops)
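
The renderer change above widens GEP (vector element access): component names now come from 'xyzwabcd', and on CUDA/NV the named-field form is used up to 8-wide vectors, with bracket indexing only beyond the device's limit. A standalone sketch of just that selection logic, with a made-up helper name:

def gep_suffix(idx: int, count: int, device: str) -> str:
  # mirrors the rendered expression: swizzle names for narrow vectors, [i] beyond the limit
  limit = 8 if device in {"CUDA", "NV"} else 4
  return f"[{idx}]" if count > limit else f".{'xyzwabcd'[idx]}"

assert gep_suffix(5, 8, "CUDA") == ".b"    # element 5 of a half8 on CUDA
assert gep_suffix(5, 8, "METAL") == "[5]"  # other backends still index once past 4 wide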
@@ -257,9 +258,9 @@ code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dt
                     UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
 
 _nms = "xyzwabcdefghijkl"
-def _make_cuda_dtype(base_type, name, cnt):
+def _make_cuda_dtype(base_type, name, cnt, align):
   vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
-  return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
+  return f"struct __align__({align}) {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
 
 class CUDARenderer(CStyleLanguage):
   device = "CUDA"
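
To see what the aligned helper emits, it can be called directly; the half8 case (align 16, matching the x*2 passed in the next hunk) is what the "half8 cast" and "align struct" bullets refer to. A sketch, assuming the helper is importable from tinygrad.renderer.cstyle at this revision (output abbreviated in the comment):

from tinygrad.renderer.cstyle import _make_cuda_dtype
print(_make_cuda_dtype("half", "half", 8, 16))
# struct __align__(16) half8 { half x, y, z, w, a, b, c, d; }; __device__ half8 make_half8(half x, ..., half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }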
@@ -286,10 +287,10 @@ class CUDARenderer(CStyleLanguage):
 
    prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
    if any(uop.dtype == dtypes.half for uop in uops):
-     prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
+     prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x, x*2) for x in [4, 8]]
 
    if any(uop.dtype == dtypes.bfloat16 for uop in uops):
-     prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
+     prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x, x*2) for x in [4, 8]]
 
    # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
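
The x, x*2 arguments above encode "element count, element count times sizeof(element)": half and nv_bfloat16 are both 2-byte types, so the 4- and 8-wide structs get __align__(8) and __align__(16), and a 16-byte-aligned half8 can move as a single vectorized access. float elements are 4 bytes, so the same x*2 factor would under-align an 8-wide float vector, which is presumably what "float8 is a later problem" means. Spelled out as a throwaway check (names made up):

def vec_bytes(elem_bytes: int, cnt: int) -> int:
  # total size of a packed vector struct; for the 2-byte types this is also the alignment passed above
  return elem_bytes * cnt

assert vec_bytes(2, 8) == 16   # half8 / 8-wide bfloat16 -> __align__(16)
assert vec_bytes(4, 8) == 32   # an 8-wide float vector is 32 bytes, so x*2 would no longer match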