mirror of https://github.com/commaai/tinygrad.git
wmma: refactor to remove wmma_func and create TC funcs as needed (#3945)
* wmma: refactor to remove wmma_func and create TC funcs as needed
* test_linearizer: disable bf16 CUDA during emulation testing
* cstyle: clean up creation of CUDA vec dtypes
* extra/gemm: add option to accumulate to bfloat16
* cleanups
* benchmark: add CUDA bfloat16 matmul
* more cleanups
This commit is contained in:
parent 88b24df40a
commit 7c5729a3bd
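
The heart of the refactor: a TensorCore no longer stores a pre-baked wmma_func string. Instead the linearizer attaches a descriptive arg tuple to each WMMA uop, and every renderer generates the device function that tuple calls for. A rough sketch of the naming scheme in plain Python (it mirrors TensorCore.__str__ in the diff below; illustrative, not the exact tinygrad API):

dims = (8, 16, 16)                                        # the CUDA half->float tensor core
name = "_".join(["WMMA"] + list(map(str, dims)) + ["half", "float"])
print(name)                                               # WMMA_8_16_16_half_float
# the WMMA uop arg then looks roughly like
#   (name, dims, dtype_in, dtype_out, tuple(map(prod, thread_local_sizes)), device)
# and each renderer emits one __WMMA_... function per distinct arg it finds in the uop graph.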
@@ -108,7 +108,9 @@ jobs:
 - name: Test speed vs torch
 run: CUDA=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
 - name: Run Tensor Core GEMM
-run: CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
+run: |
+CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
+CUDA=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
 - name: Run LLaMA
 run: |
 CUDA=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@@ -140,6 +142,7 @@ jobs:
 onnx_inference_speed.csv
 torch_speed.txt
 matmul.txt
+matmul_bfloat16.txt
 llama_unjitted.txt
 llama_jitted.txt
 llama_beam.txt

@@ -1,8 +1,8 @@
 from tinygrad.helpers import getenv
 from tinygrad import dtypes, Tensor
 
-dtype_in = dtypes.half if getenv("HALF") else dtypes.float
-acc_dtype = dtypes.half if getenv("ACC_HALF") else None
+dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
+acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
 
 CNT = getenv("CNT", 8)
 BS = getenv("BS", 16)
@@ -12,6 +12,8 @@ HW = getenv("HW", 128)
 K = getenv("K", 3)
 PADDING = getenv("PADDING", 1)
 COMP = getenv("COMP", 0)
+ATOL = getenv("ATOL", 1e-4)
+RTOL = getenv("RTOL", 3e-2)
 
 FLOPS = BS*K*K*CIN*HW*HW*COUT*2
 def rand_input(): return Tensor.rand(BS, CIN, HW, HW, dtype=dtype_in).realize(), Tensor.rand(COUT, CIN, K, K, dtype=dtype_in).realize()
@@ -28,4 +30,4 @@ if __name__ == "__main__":
 torch_device = "cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")
 ta, tb = torch.from_numpy(a.numpy()).to(torch_device), torch.from_numpy(b.numpy()).to(torch_device)
 tc = torch.nn.functional.conv2d(ta, tb, padding=PADDING)
-np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=1e-4, rtol=3e-2)
+np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=ATOL, rtol=RTOL)

@@ -1,10 +1,12 @@
 import numpy as np
 from tinygrad.helpers import getenv
 from tinygrad import dtypes, Tensor
-dtype_in = dtypes.half if getenv("HALF") else dtypes.float
-acc_dtype = dtypes.half if getenv("ACC_HALF") else None
+dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
+acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
 N = getenv("N", 4096)
 CNT = getenv("CNT", 10)
+ATOL = getenv("ATOL", 1e-4)
+RTOL = getenv("RTOL", 3e-2)
 
 if __name__ == "__main__":
 a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
@@ -14,4 +16,4 @@ if __name__ == "__main__":
 c = a.matmul(b, acc_dtype=acc_dtype).realize()
 comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
 nc = c.numpy()
-np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=3e-2)
+np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)
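
Note the precedence the chained conditionals imply: HALF=1 takes priority over BFLOAT16=1, and ACC_HALF over ACC_BFLOAT16. A tiny illustration in plain Python, independent of tinygrad:

def pick_dtype(half, bf16): return "half" if half else "bfloat16" if bf16 else "float"
assert pick_dtype(half=1, bf16=1) == "half"       # HALF=1 wins even when BFLOAT16=1 is also set
assert pick_dtype(half=0, bf16=1) == "bfloat16"
assert pick_dtype(half=0, bf16=0) == "float"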

@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.engine.jit import CacheCollector
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
-from tinygrad.helpers import prod, Context
+from tinygrad.helpers import prod, Context, getenv
 from tinygrad.dtype import DType, dtypes
 from tinygrad.codegen.uops import UOpGraph
 
@@ -186,6 +186,7 @@ class TestLinearizer(unittest.TestCase):
 if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores:
 self.skipTest("device doesn't have tensor cores")
 for tc in tensor_cores[Device[Device.DEFAULT].compiler.linearizer_opts.device]:
+if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
 a, b = Tensor.rand(tc.dims[1], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[0], dtype=tc.dtype_in)
 np_a, np_b = a.numpy(), b.numpy()
 r = a.matmul(b, acc_dtype=tc.dtype_out)

@@ -30,16 +30,14 @@ class Opt:
 
 @dataclass(frozen=True)
 class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
-dims: List[int] # N, M, K
+dims: Tuple[int,int,int] # N, M, K
 dtype_in: DType # dtype for A and B
 dtype_out: DType # dtype for C and D
 threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
 thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
 thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
-wmma_func: str # name of wmma function to call
-def __str__(self): return f"tensor_core<{self.dims}, {self.dtype_in}, {self.dtype_out}>"
-def num_threads(self): return len(self.threads)
-def num_upcasts(self): return len(self.thread_local_aliases[0]) - self.num_threads()
+def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)
 
 class TensorCoreOptions(NamedTuple):
 bufs: Tuple[int, int] # the local aliased buffers for A and B
@@ -51,18 +49,9 @@ class TensorCoreOptions(NamedTuple):
 elif removed_axis == self.axes[tc_dim]: self.axes_exist[tc_dim] = False
 
 tensor_cores: Dict[str, List[TensorCore]] = {
-"METAL": [
-TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma<float2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma<half2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma<half2,simdgroup_half8x8,half2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-],
-"HSA": [
-TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ]), # noqa: E501
-TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ]), # noqa: E501
-],
-"CUDA": [
-TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__cuda_mma_m16n8k16_f16_f32", threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ]), # noqa: E501
-],
+"METAL": [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
+"HSA": [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
+"CUDA": [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)]], # noqa: E501
 }
 
 class LocalBuffer(NamedTuple):
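
As the TensorCore comment above states, dims is ordered (N, M, K) for D = A * B + C with A of shape (M, K), B of shape (K, N), and C, D of shape (M, N). A small numpy shape check for the CUDA entry (numpy stands in for the hardware op; illustrative only):

import numpy as np
N, M, K = 8, 16, 16                                  # dims of the CUDA tensor core above
A = np.random.rand(M, K).astype(np.float16)          # dtype_in for A and B
B = np.random.rand(K, N).astype(np.float16)
C = np.zeros((M, N), dtype=np.float32)               # dtype_out for C and D
D = A.astype(np.float32) @ B.astype(np.float32) + C
assert D.shape == (M, N)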

@@ -291,10 +291,10 @@ class Linearizer(Kernel):
 if (tc:=self.tensor_core):
 min_alias_idx = min(self.local_alias.keys())
 replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
-for n in range(tc.num_threads()):
-buf_idxs[self.first_reduce-tc.num_threads()+n] = replace_input_idxs[n] # replace locals
+for n in range(len(tc.threads)):
+buf_idxs[self.first_reduce-len(tc.threads)+n] = replace_input_idxs[n] # replace locals
 for n in range(tc.num_upcasts()):
-buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[tc.num_threads()+n] # replace upcasts
+buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
 if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
 ll = self.global_load(i, buf_idxs)
 locals_to_store.append((localbuf_idx, buf_idxs, ll))
@@ -315,13 +315,13 @@ class Linearizer(Kernel):
 strides.append((0 if stride == 0 else next, sz))
 next *= 1 if stride == 0 else sz
 return strides
-upcasts = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]]
-for iter in [x[::-1] for x in [x for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]]:
+upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
+for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
 offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
 ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
 self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
-self.uops.add(UOps.CAST, tc.dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
-ret = self.uops.add(UOps.WMMA, tc.dtype_out.vec(wmma_sz[2]), ops, tc.wmma_func)
+self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
+ret = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(map(prod, tc.thread_local_sizes)), dev))
 for z in range(wmma_sz[2]):
 acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
 else:

@@ -2,7 +2,7 @@ from __future__ import annotations
 import functools, math, operator, itertools
 from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast
 from collections import defaultdict
-from tinygrad.helpers import DEBUG, flatten
+from tinygrad.helpers import DEBUG, flatten, prod
 from tinygrad.dtype import dtypes, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
 from tinygrad.shape.symbolic import sint, Variable, Node, NumNode, MulNode, DivNode, SumNode
@@ -385,8 +385,6 @@ class UOpGraph:
 assert u.vin[2].dtype is not None
 mem += u.vin[2].dtype.itemsize * mults
 elif u.uop is UOps.WMMA:
-if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
-elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
-elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
-else: raise NotImplementedError(f"not implemented wmma {u.arg=}")
+assert u.arg[1] is not None
+flops += 2 * prod(u.arg[1]) // 32 * mults
 return flops, mem
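
The generic count 2 * prod(dims) // 32 reproduces the constants that were hard-coded per wmma function before (32 being the warp/wavefront width used in the old branches). A quick sanity check in plain Python:

from math import prod
for dims in [(8, 8, 8), (16, 16, 16), (8, 16, 16)]:   # METAL, HSA, CUDA tensor core dims
    print(dims, 2 * prod(dims) // 32)                 # prints 32, 256, 128, matching the removed branches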

@@ -164,7 +164,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
 assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
 bufs.append((args[1], (dtype,args[2])))
 r[u] = args[1]
-elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
+elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa(u, 'wmma')} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
 elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
 elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
 elif uop is UOps.GEP:
@@ -215,11 +215,11 @@ class MetalLanguage(CStyleLanguage):
 return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 
 def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
-prefix = ["#include <metal_stdlib>","using namespace metal;"]
-if any(uop.uop == UOps.WMMA for uop in uops): prefix.append("""template<typename T, typename S, typename U> U __metal_wmma(T m, T n, U o) {
-S a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x; b.thread_elements()[1] = n.y;
-c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
-return U(c.thread_elements()[0], c.thread_elements()[1]);\n}""")
+prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop == UOps.WMMA])
+for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
+simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
+b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
+return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
 return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
 
@@ -229,6 +229,11 @@ code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype
 UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
 UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
 
+_nms = "xyzwabcdefghijkl"
+def _make_cuda_dtype(base_type, name, cnt):
+vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
+return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
+
 class CUDALanguage(CStyleLanguage):
 kernel_prefix = "extern \"C\" __global__ "
 smem_prefix = "__shared__ "
@@ -241,17 +246,24 @@ class CUDALanguage(CStyleLanguage):
 type_map = {dtypes.bfloat16: "nv_bfloat16"}
 
 def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
+# TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
+dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
+
 prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
 if any(uop.dtype == dtypes.half for uop in uops):
-prefix += ["#include <cuda_fp16.h>", "struct half4 { half x, y, z, w; };", "struct half8 { half x, y, z, w, a, b, c, d; };",
-"__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }",
-"__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }",
-"""__device__ float4 __cuda_mma_m16n8k16_f16_f32(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
-asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
-: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
-return c;}""",]
+prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
+
+if any(uop.dtype == dtypes.bfloat16 for uop in uops):
+prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
+
+# TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
+for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]):
+fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
+prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
+asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
+: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
+return c;}}""")
 
-if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include <cuda_bf16.h>")
 return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
 CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
 
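The _make_cuda_dtype helper regenerates the vector structs that were previously written out by hand; a quick check in plain Python (the function body is copied from the diff above):

_nms = "xyzwabcdefghijkl"
def _make_cuda_dtype(base_type, name, cnt):
  vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
  return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"

print(_make_cuda_dtype("half", "half", 4))
# -> struct half4 { half x, y, z, w; }; __device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
# identical to the half4 struct and make_half4 that the old hard-coded prefix contained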
@@ -273,43 +285,31 @@ def _make_hip_code_for_op():
 return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
 
 def _make_hip_dtype(base_type, name, cnt):
-nms = "xyzwabcdefghijkl"[:cnt]
+elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
 return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
-f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}(" + ', '.join([f"{base_type} {x}" for x in nms]) + \
-") { return {" + ', '.join(nms) + "}; }"
+f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
 
 class HIPLanguage(CStyleLanguage):
-kernel_prefix = """
-#define half _Float16
-
-extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
-extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
-extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
-
-extern "C" {
-__attribute__((device)) __attribute__((const)) float __ocml_fmax_f32(float, float);
-__attribute__((device)) __attribute__((pure)) float __ocml_exp2_f32(float);
-__attribute__((device)) __attribute__((pure)) float __ocml_log2_f32(float);
-__attribute__((device)) float __ocml_sin_f32(float);
-__attribute__((device)) __attribute__((const)) float __ocml_sqrt_f32(float);
-__attribute__((device)) __attribute__((const)) double __ocml_fmax_f64(double, double);
-__attribute__((device)) __attribute__((pure)) double __ocml_exp2_f64(double);
-__attribute__((device)) __attribute__((pure)) double __ocml_log2_f64(double);
-__attribute__((device)) double __ocml_sin_f64(double);
-__attribute__((device)) __attribute__((const)) double __ocml_sqrt_f64(double);
-__attribute__((device)) __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16);
-__attribute__((device)) __attribute__((pure)) _Float16 __ocml_exp2_f16(_Float16);
-__attribute__((device)) __attribute__((pure)) _Float16 __ocml_log2_f16(_Float16);
-__attribute__((device)) _Float16 __ocml_sin_f16(_Float16);
-__attribute__((device)) __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16);
-}\n""" + '\n'.join([_make_hip_dtype(*x) for x in [
-("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16),
-("float", "float", 8)]]) + """
-static __attribute__((device)) half8 __hip_wmma_f16_f16(half16 a, half16 b, half8 c) {
-half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
-c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
-for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;
-}\nextern "C" __attribute__((global))"""
+kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
+extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
+extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
+extern "C" {
+__attribute__((device)) __attribute__((const)) float __ocml_fmax_f32(float, float);
+__attribute__((device)) __attribute__((pure)) float __ocml_exp2_f32(float);
+__attribute__((device)) __attribute__((pure)) float __ocml_log2_f32(float);
+__attribute__((device)) float __ocml_sin_f32(float);
+__attribute__((device)) __attribute__((const)) float __ocml_sqrt_f32(float);
+__attribute__((device)) __attribute__((const)) double __ocml_fmax_f64(double, double);
+__attribute__((device)) __attribute__((pure)) double __ocml_exp2_f64(double);
+__attribute__((device)) __attribute__((pure)) double __ocml_log2_f64(double);
+__attribute__((device)) double __ocml_sin_f64(double);
+__attribute__((device)) __attribute__((const)) double __ocml_sqrt_f64(double);
+__attribute__((device)) __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16);
+__attribute__((device)) __attribute__((pure)) _Float16 __ocml_exp2_f16(_Float16);
+__attribute__((device)) __attribute__((pure)) _Float16 __ocml_log2_f16(_Float16);
+__attribute__((device)) _Float16 __ocml_sin_f16(_Float16);
+__attribute__((device)) __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16);
+}\nextern "C" __attribute__((global))"""
 code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
 "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
 code_for_op = _make_hip_code_for_op()
@@ -321,8 +321,10 @@ class HIPLanguage(CStyleLanguage):
 type_map = {dtypes.bfloat16: "hip_bfloat16"}
 
 def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
-prefix = ["#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))",
-"typedef long unsigned int size_t;"]
+prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
+vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
+
+# TODO: add BF16 vec dts
 if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
 struct hip_bfloat16 {
 unsigned short data;
@@ -339,8 +341,19 @@ struct hip_bfloat16 {
 static __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
 static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
 """)
-prefix.append('\n'.join(_make_hip_dtype(*x) for x in [("float", "float", 2), ("float", "float", 4),
-("signed int", "int", 4), ("signed int", "int", 2)]))
+
+if any(uop.dtype == dtypes.half for uop in uops):
+prefix.append("#define half _Float16")
+vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
+
+prefix += [_make_hip_dtype(*x) for x in vec_dts]
+
+for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
+else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
+half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
+c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
+for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
 return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
 def get_kernel_modifier(self, uops:UOpGraph) -> str:
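
Concretely, for the two HSA tensor cores defined in kernel.py the loop above should resolve to something like the following (names follow TensorCore.__str__; a sketch, not captured compiler output):

name = "_".join(["WMMA", "16", "16", "16", "half", "float"])   # str(tc) for the half->float tensor core
print(f"#define __{name} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
# the half->half tensor core instead gets a small __WMMA_16_16_16_half_half device wrapper that packs and
# unpacks the c fragment around __builtin_amdgcn_wmma_f16_16x16x16_f16_w32, as shown in the diff above.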

@@ -145,13 +145,14 @@ class PythonProgram:
 out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
 return out
 
-if arg.startswith('__metal_wmma'):
+# TODO: refactor these to a shared TensorCoreLayout in kernel.py
+if arg[5] == "METAL":
 # A (2 elements on 32 threads): row major
 def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
 # (i, j), C, D (2 elements on 32 threads): row major same as A/B
 def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
 ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
-elif arg == '__builtin_amdgcn_wmma_f32_16x16x16_f16_w32' or arg == '__hip_wmma_f16_f16':
+elif arg[5] == "HSA":
 # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
 def a_elem(x, i, j, goff):
 assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
@@ -160,7 +161,7 @@ class PythonProgram:
 def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)
 def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
 ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
-elif arg == '__cuda_mma_m16n8k16_f16_f32':
+elif arg[5] == "CUDA":
 # A (8 elements on 32 threads)
 def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
 # B (4 elements on 32 threads)