work to make GEMV fast (#5824)

* work to make GEMV fast

* half8 cast

* align struct

* fix amd

* float8 is a later problem
George Hotz 2024-07-30 17:41:40 -07:00 committed by GitHub
parent 2d90b7a103
commit e6879035a0
4 changed files with 49 additions and 24 deletions

View File

@@ -1,26 +1,39 @@
 from tinygrad import Tensor, dtypes, Device
+from tinygrad.helpers import getenv, DEBUG
+from tinygrad.engine.graph import print_globalcounters
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps
 from tinygrad.engine.realize import CompiledRunner, ExecItem
+from dataclasses import replace
 
 N = 4096
 if __name__ == "__main__":
-  A, B = Tensor.empty(N, N, dtype=dtypes.float16), Tensor.empty(N, N, dtype=dtypes.float16)
+  if getenv("GEMV"):
+    A, B = Tensor.empty(1, N, dtype=dtypes.float), Tensor.empty(14336, N, dtype=dtypes.float16).T
+  else:
+    A, B = Tensor.empty(N, N, dtype=dtypes.float16), Tensor.empty(N, N, dtype=dtypes.float16)
   C = A.matmul(B)
   si = C.schedule()[-1]
   ast = si.ast
   k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-  opts = [Opt(op=OptOps.TC, axis=0, amt=0),
-          Opt(op=OptOps.UPCAST, axis=1, amt=16),
-          Opt(op=OptOps.UPCAST, axis=0, amt=2),
-          Opt(op=OptOps.LOCAL, axis=0, amt=4),
-          Opt(op=OptOps.UNROLL, axis=0, amt=4),
-          Opt(op=OptOps.LOCAL, axis=1, amt=2),
-  ]
+  if getenv("GEMV"):
+    opts = [
+      Opt(op=OptOps.UNROLL, axis=0, amt=8),
+      Opt(op=OptOps.GROUP, axis=0, amt=32),
+    ]
+  else:
+    opts = [
+      Opt(op=OptOps.TC, axis=0, amt=0),
+      Opt(op=OptOps.UPCAST, axis=0, amt=4),
+      Opt(op=OptOps.UPCAST, axis=1, amt=8),
+      Opt(op=OptOps.LOCAL, axis=0, amt=2),
+      Opt(op=OptOps.LOCAL, axis=1, amt=2),
+      Opt(op=OptOps.LOCAL, axis=0, amt=2),
+    ]
   for opt in opts: k.apply_opt(opt)
   prg = k.to_program()
+  new_src = prg.src
+  # can mod source here
+  prg = replace(prg, src=new_src)
   ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
-  tflops = []
-  for i in range(5):
-    tm = ei.run(wait=True)
-    tflops.append((2*N*N*N/tm)*1e-12)
-  print(f"TFLOPS: {sum(tflops)/len(tflops):.2f}")
+  for i in range(5): ei.run(wait=True)
+  if DEBUG < 2: print_globalcounters()
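
The TFLOPS figure the removed loop computed comes from 2*N^3 FLOPs per square N x N matmul; the GEMV shape above (1x4096 times the transposed 14336x4096 weight) does only 2*4096*14336 FLOPs per call, so it is bandwidth-bound rather than compute-bound. A standalone sketch of that arithmetic (the 5 ms timing below is a made-up placeholder, not a measurement):

N = 4096
gemm_flops = 2 * N * N * N      # ~137.4 GFLOP per 4096x4096x4096 matmul
gemv_flops = 2 * N * 14336      # ~0.117 GFLOP per 1x4096 @ 4096x14336 GEMV
tm = 5e-3                       # placeholder kernel time in seconds, not a real measurement
print(f"GEMM: {gemm_flops/tm*1e-12:.2f} TFLOPS at {tm*1e3:.1f} ms")
print(f"GEMV: {gemv_flops/tm*1e-12:.4f} TFLOPS at {tm*1e3:.1f} ms")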

View File

@ -0,0 +1,12 @@
#!/usr/bin/env python3
from tinygrad import Tensor, TinyJit, nn
from extra.models.llama import FeedForward
if __name__ == "__main__":
model = FeedForward(4096, 14336)
for x in nn.state.get_parameters(model): x.replace(x.half()).realize()
jrun = TinyJit(model)
for i in range(5):
print(f"*** run {i}")
jrun(Tensor.rand(1, 4096))
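
FeedForward(4096, 14336) is the llama-style SwiGLU MLP; with a single input row every matmul inside it degenerates to a GEMV against a half-precision 4096x14336 or 14336x4096 weight, which is exactly the shape the kernel search above targets. A minimal sketch of what it computes (an illustrative approximation, not necessarily the exact code in extra/models/llama.py):

from tinygrad import Tensor, nn

class FeedForwardSketch:
  # SwiGLU MLP: gate (w1) and up (w3) projections, elementwise silu gate, then down (w2) projection
  def __init__(self, dim:int, hidden_dim:int):
    self.w1 = nn.Linear(dim, hidden_dim, bias=False)
    self.w2 = nn.Linear(hidden_dim, dim, bias=False)
    self.w3 = nn.Linear(dim, hidden_dim, bias=False)
  def __call__(self, x:Tensor) -> Tensor:
    # for x of shape (1, 4096), each Linear call here is a GEMV
    return self.w2(self.w1(x).silu() * self.w3(x))

out = FeedForwardSketch(4096, 14336)(Tensor.rand(1, 4096))  # shape (1, 4096)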

View File

@@ -12,12 +12,11 @@ with contextlib.suppress(ImportError): import networkx as nx
 
 # **** debugging and graphing ****
 
-if DEBUG >= 2:
-  def print_globalcounters():
-    if GlobalCounters.time_sum_s == 0: return
-    print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
-          f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
-  atexit.register(print_globalcounters)
+def print_globalcounters():
+  if GlobalCounters.time_sum_s == 0: return
+  print(f"avg: {GlobalCounters.global_ops*1e-9/GlobalCounters.time_sum_s:8.2f} GFLOPS {GlobalCounters.global_mem*1e-9/GlobalCounters.time_sum_s:8.2f} GB/s", # noqa: E501
+        f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
+if DEBUG >= 2: atexit.register(print_globalcounters)
 
 def save_graph(G, fn, opt=""):
   print("saving", G, f"to {fn}.svg")
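
Moving the definition out of the DEBUG >= 2 block means print_globalcounters can be imported and called explicitly even when DEBUG < 2 (it only self-registers with atexit under DEBUG >= 2), which is what the benchmark script above relies on. A minimal usage sketch:

from tinygrad.engine.graph import print_globalcounters

# ... run kernels with wait=True so GlobalCounters.time_sum_s accumulates ...
print_globalcounters()  # prints avg GFLOPS / GB/s and totals; silently returns if time_sum_s == 0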

View File

@@ -177,7 +177,8 @@ class CStyleLanguage(Renderer):
       elif uop is UOps.GEP:
         assert src[0].dtype is not None
         from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
-        r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + (f"[{args}]" if src[0].dtype.count > 4 else f".{'xyzw'[args]}")
+        r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + \
+          (f"[{args}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) else f".{'xyzwabcd'[args]}")
       else: raise RuntimeError(f"failed to render {u}")
 
     return self.render_kernel(name, kernel, bufs, uops)
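
The effect of the GEP change is that on CUDA/NV a vector of up to 8 elements keeps named fields (indexing into 'xyzwabcd') instead of falling back to array subscripts, which the old 4-wide 'xyzw' table could not express. A standalone sketch of the suffix choice (the device names and widths below are just examples):

def gep_suffix(idx:int, vec_count:int, device:str) -> str:
  # mirrors the rendered suffix: named fields up to 8 wide on CUDA/NV, up to 4 wide elsewhere
  named_limit = 8 if device in {"CUDA", "NV"} else 4
  return f"[{idx}]" if vec_count > named_limit else f".{'xyzwabcd'[idx]}"

assert gep_suffix(5, 8, "CUDA") == ".b"    # element 5 of a half8 on CUDA gets a named field
assert gep_suffix(5, 8, "METAL") == "[5]"  # an 8-wide vector elsewhere falls back to indexing
assert gep_suffix(2, 4, "METAL") == ".z"   # 4-wide vectors stay named everywhere
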
@@ -257,9 +258,9 @@ code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dt
                     UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
 
 _nms = "xyzwabcdefghijkl"
-def _make_cuda_dtype(base_type, name, cnt):
+def _make_cuda_dtype(base_type, name, cnt, align):
   vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
-  return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
+  return f"struct __align__({align}) {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
 
 class CUDARenderer(CStyleLanguage):
   device = "CUDA"
@@ -286,10 +287,10 @@ class CUDARenderer(CStyleLanguage):
     prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
 
     if any(uop.dtype == dtypes.half for uop in uops):
-      prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
+      prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x, x*2) for x in [4, 8]]
 
     if any(uop.dtype == dtypes.bfloat16 for uop in uops):
-      prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
+      prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x, x*2) for x in [4, 8]]
 
     # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
     for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
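
With the new align argument (x*2 bytes in the callers, so 16 for half8), the generated half8 struct carries __align__(16), which allows the compiler to use a single 16-byte vector load/store instead of element-wise accesses. A standalone copy of _make_cuda_dtype, just to show the emitted CUDA source:

# standalone copy of _make_cuda_dtype from the diff, for illustration
_nms = "xyzwabcdefghijkl"
def make_cuda_dtype(base_type:str, name:str, cnt:int, align:int) -> str:
  vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
  return f"struct __align__({align}) {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"

print(make_cuda_dtype("half", "half", 8, 16))
# -> struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
#    __device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }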