wmma: refactor to remove wmma_func and create TC funcs as needed (#3945)

* wmma: refactor to remove wmma_func and create TC funcs as needed

* test_linearizer: disable bf16 CUDA during emulation testing

* cstyle: clean up creation of CUDA vec dtypes

* extra/gemm: add option to accumulate to bfloat16

* cleanups

* benchmark: add CUDA bfloat16 matmul

* more cleanups
Francis Lam 2024-03-27 13:43:09 -07:00 committed by GitHub
parent 88b24df40a
commit 7c5729a3bd
9 changed files with 101 additions and 92 deletions
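
For context on the refactor: the WMMA UOp no longer carries a backend-specific function name (the old wmma_func string); it now carries a tuple that fully describes the tensor core, and each renderer generates whatever wrapper it needs from that tuple. A rough sketch of the before/after shape of that arg, using the CUDA half->float tensor core as an illustration (field values are derived from the diffs below and are illustrative, not authoritative):

from tinygrad.dtype import dtypes
# before: a hard-coded backend function name, e.g.
#   self.uops.add(UOps.WMMA, out_dtype, ops, "__cuda_mma_m16n8k16_f16_f32")
# after: a self-describing tuple built in the linearizer
wmma_arg = (
  "WMMA_8_16_16_half_float",  # str(tc): "WMMA" joined with dims and the dtype names
  (8, 16, 16),                # tc.dims (N, M, K)
  dtypes.half,                # tc.dtype_in  (A and B)
  dtypes.float,               # tc.dtype_out (C and D)
  (8, 4, 4),                  # tuple(map(prod, tc.thread_local_sizes))
  "CUDA",                     # self.opts.device, used by the renderers and the Python emulator
)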

View File

@@ -108,7 +108,9 @@ jobs:
- name: Test speed vs torch
run: CUDA=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
- name: Run Tensor Core GEMM
run: CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
run: |
CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
CUDA=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
- name: Run LLaMA
run: |
CUDA=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@@ -140,6 +142,7 @@ jobs:
onnx_inference_speed.csv
torch_speed.txt
matmul.txt
matmul_bfloat16.txt
llama_unjitted.txt
llama_jitted.txt
llama_beam.txt

View File

@@ -1,8 +1,8 @@
from tinygrad.helpers import getenv
from tinygrad import dtypes, Tensor
dtype_in = dtypes.half if getenv("HALF") else dtypes.float
acc_dtype = dtypes.half if getenv("ACC_HALF") else None
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
CNT = getenv("CNT", 8)
BS = getenv("BS", 16)
@@ -12,6 +12,8 @@ HW = getenv("HW", 128)
K = getenv("K", 3)
PADDING = getenv("PADDING", 1)
COMP = getenv("COMP", 0)
ATOL = getenv("ATOL", 1e-4)
RTOL = getenv("RTOL", 3e-2)
FLOPS = BS*K*K*CIN*HW*HW*COUT*2
def rand_input(): return Tensor.rand(BS, CIN, HW, HW, dtype=dtype_in).realize(), Tensor.rand(COUT, CIN, K, K, dtype=dtype_in).realize()
@@ -28,4 +30,4 @@ if __name__ == "__main__":
torch_device = "cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")
ta, tb = torch.from_numpy(a.numpy()).to(torch_device), torch.from_numpy(b.numpy()).to(torch_device)
tc = torch.nn.functional.conv2d(ta, tb, padding=PADDING)
np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=1e-4, rtol=3e-2)
np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=ATOL, rtol=RTOL)

View File

@@ -1,10 +1,12 @@
import numpy as np
from tinygrad.helpers import getenv
from tinygrad import dtypes, Tensor
dtype_in = dtypes.half if getenv("HALF") else dtypes.float
acc_dtype = dtypes.half if getenv("ACC_HALF") else None
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
N = getenv("N", 4096)
CNT = getenv("CNT", 10)
ATOL = getenv("ATOL", 1e-4)
RTOL = getenv("RTOL", 3e-2)
if __name__ == "__main__":
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
@@ -14,4 +16,4 @@ if __name__ == "__main__":
c = a.matmul(b, acc_dtype=acc_dtype).realize()
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
nc = c.numpy()
np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=3e-2)
np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)

View File

@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
from tinygrad.engine.jit import CacheCollector
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.helpers import prod, Context
from tinygrad.helpers import prod, Context, getenv
from tinygrad.dtype import DType, dtypes
from tinygrad.codegen.uops import UOpGraph
@@ -186,6 +186,7 @@ class TestLinearizer(unittest.TestCase):
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores:
self.skipTest("device doesn't have tensor cores")
for tc in tensor_cores[Device[Device.DEFAULT].compiler.linearizer_opts.device]:
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
a, b = Tensor.rand(tc.dims[1], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[0], dtype=tc.dtype_in)
np_a, np_b = a.numpy(), b.numpy()
r = a.matmul(b, acc_dtype=tc.dtype_out)

View File

@@ -30,16 +30,14 @@ class Opt:
@dataclass(frozen=True)
class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
dims: List[int] # N, M, K
dims: Tuple[int,int,int] # N, M, K
dtype_in: DType # dtype for A and B
dtype_out: DType # dtype for C and D
threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
wmma_func: str # name of wmma function to call
def __str__(self): return f"tensor_core<{self.dims}, {self.dtype_in}, {self.dtype_out}>"
def num_threads(self): return len(self.threads)
def num_upcasts(self): return len(self.thread_local_aliases[0]) - self.num_threads()
def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)
class TensorCoreOptions(NamedTuple):
bufs: Tuple[int, int] # the local aliased buffers for A and B
@@ -51,18 +49,9 @@ class TensorCoreOptions(NamedTuple):
elif removed_axis == self.axes[tc_dim]: self.axes_exist[tc_dim] = False
tensor_cores: Dict[str, List[TensorCore]] = {
"METAL": [
TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma<float2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma<half2,simdgroup_float8x8,float2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma<half2,simdgroup_half8x8,half2>", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
],
"HSA": [
TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ]), # noqa: E501
TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ]), # noqa: E501
],
"CUDA": [
TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__cuda_mma_m16n8k16_f16_f32", threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ]), # noqa: E501
],
"METAL": [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
"HSA": [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
"CUDA": [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)]], # noqa: E501
}
class LocalBuffer(NamedTuple):
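
With wmma_func gone, the function name a renderer sees is now derived from the dataclass via __str__ above. A quick sketch of what that yields for one METAL entry (assuming this file is tinygrad/codegen/kernel.py, so tensor_cores imports from there; values read off the table above):

from tinygrad.codegen.kernel import tensor_cores  # assumed module path
tc = tensor_cores["METAL"][1]        # the half -> float entry
assert str(tc) == "WMMA_8_8_8_half_float"
assert len(tc.threads) == 4          # (0,2),(1,4),(0,2),(1,2) -> 2*4*2*2 = 32 lanes per simdgroup
assert tc.num_upcasts() == 2         # len(thread_local_aliases[0]) - len(threads) = 6 - 4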

View File

@@ -291,10 +291,10 @@ class Linearizer(Kernel):
if (tc:=self.tensor_core):
min_alias_idx = min(self.local_alias.keys())
replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
for n in range(tc.num_threads()):
buf_idxs[self.first_reduce-tc.num_threads()+n] = replace_input_idxs[n] # replace locals
for n in range(len(tc.threads)):
buf_idxs[self.first_reduce-len(tc.threads)+n] = replace_input_idxs[n] # replace locals
for n in range(tc.num_upcasts()):
buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[tc.num_threads()+n] # replace upcasts
buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
ll = self.global_load(i, buf_idxs)
locals_to_store.append((localbuf_idx, buf_idxs, ll))
@@ -315,13 +315,13 @@ class Linearizer(Kernel):
strides.append((0 if stride == 0 else next, sz))
next *= 1 if stride == 0 else sz
return strides
upcasts = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]]
for iter in [x[::-1] for x in [x for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]]:
upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
self.uops.add(UOps.CAST, tc.dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
ret = self.uops.add(UOps.WMMA, tc.dtype_out.vec(wmma_sz[2]), ops, tc.wmma_func)
self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
ret = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(map(prod, tc.thread_local_sizes)), dev))
for z in range(wmma_sz[2]):
acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
else:

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import functools, math, operator, itertools
from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast
from collections import defaultdict
from tinygrad.helpers import DEBUG, flatten
from tinygrad.helpers import DEBUG, flatten, prod
from tinygrad.dtype import dtypes, DType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.shape.symbolic import sint, Variable, Node, NumNode, MulNode, DivNode, SumNode
@@ -385,8 +385,6 @@ class UOpGraph:
assert u.vin[2].dtype is not None
mem += u.vin[2].dtype.itemsize * mults
elif u.uop is UOps.WMMA:
if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
else: raise NotImplementedError(f"not implemented wmma {u.arg=}")
assert u.arg[1] is not None
flops += 2 * prod(u.arg[1]) // 32 * mults
return flops, mem
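
A worked example of the new generic FLOPS rule above, which replaces the per-backend string matching: one warp-wide WMMA performs N*M*K multiply-accumulates (2*N*M*K FLOPs), and dividing by the 32 lanes avoids counting the warp repeatedly through mults, exactly as the old hard-coded cases did.

# u.arg[1] is tc.dims; e.g. for the CUDA tensor core (N, M, K) = (8, 16, 16):
dims = (8, 16, 16)
flops_per_lane = 2 * (dims[0] * dims[1] * dims[2]) // 32  # 2*2048//32 = 128, matching the old "__cuda_mma_m16n8k16_f16_f32" case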

View File

@@ -164,7 +164,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
bufs.append((args[1], (dtype,args[2])))
r[u] = args[1]
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa(u, 'wmma')} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
elif uop is UOps.GEP:
@@ -215,11 +215,11 @@ class MetalLanguage(CStyleLanguage):
return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
prefix = ["#include <metal_stdlib>","using namespace metal;"]
if any(uop.uop == UOps.WMMA for uop in uops): prefix.append("""template<typename T, typename S, typename U> U __metal_wmma(T m, T n, U o) {
S a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x; b.thread_elements()[1] = n.y;
c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
return U(c.thread_elements()[0], c.thread_elements()[1]);\n}""")
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop == UOps.WMMA])
for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
@@ -229,6 +229,11 @@ code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype
UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
_nms = "xyzwabcdefghijkl"
def _make_cuda_dtype(base_type, name, cnt):
vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
class CUDALanguage(CStyleLanguage):
kernel_prefix = "extern \"C\" __global__ "
smem_prefix = "__shared__ "
@@ -241,17 +246,24 @@ class CUDALanguage(CStyleLanguage):
type_map = {dtypes.bfloat16: "nv_bfloat16"}
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
# TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not override dtypes.name
dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
if any(uop.dtype == dtypes.half for uop in uops):
prefix += ["#include <cuda_fp16.h>", "struct half4 { half x, y, z, w; };", "struct half8 { half x, y, z, w, a, b, c, d; };",
"__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }",
"__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }",
"""__device__ float4 __cuda_mma_m16n8k16_f16_f32(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;}""",]
prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
if any(uop.dtype == dtypes.bfloat16 for uop in uops):
prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
# TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]):
fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
: "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;}}""")
if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include <cuda_bf16.h>")
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
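
To make the generation above concrete, here is roughly how the pieces of the emitted CUDA wrapper fall out of a WMMA arg; a sketch for the half->float tensor core (the bfloat16 case swaps in the bfloat16 vector structs and the bf16 asm suffixes):

from tinygrad.dtype import dtypes
dt_map = {dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16")}
arg = ("WMMA_8_16_16_half_float", (8,16,16), dtypes.half, dtypes.float, (8,4,4), "CUDA")
fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
# rendered signature: __device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c)
# rendered asm:       mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32  (same as the old hard-coded wrapper)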
@@ -273,43 +285,31 @@ def _make_hip_code_for_op():
return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
def _make_hip_dtype(base_type, name, cnt):
nms = "xyzwabcdefghijkl"[:cnt]
elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}(" + ', '.join([f"{base_type} {x}" for x in nms]) + \
") { return {" + ', '.join(nms) + "}; }"
f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
class HIPLanguage(CStyleLanguage):
kernel_prefix = """
#define half _Float16
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
extern "C" {
__attribute__((device)) __attribute__((const)) float __ocml_fmax_f32(float, float);
__attribute__((device)) __attribute__((pure)) float __ocml_exp2_f32(float);
__attribute__((device)) __attribute__((pure)) float __ocml_log2_f32(float);
__attribute__((device)) float __ocml_sin_f32(float);
__attribute__((device)) __attribute__((const)) float __ocml_sqrt_f32(float);
__attribute__((device)) __attribute__((const)) double __ocml_fmax_f64(double, double);
__attribute__((device)) __attribute__((pure)) double __ocml_exp2_f64(double);
__attribute__((device)) __attribute__((pure)) double __ocml_log2_f64(double);
__attribute__((device)) double __ocml_sin_f64(double);
__attribute__((device)) __attribute__((const)) double __ocml_sqrt_f64(double);
__attribute__((device)) __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16);
__attribute__((device)) __attribute__((pure)) _Float16 __ocml_exp2_f16(_Float16);
__attribute__((device)) __attribute__((pure)) _Float16 __ocml_log2_f16(_Float16);
__attribute__((device)) _Float16 __ocml_sin_f16(_Float16);
__attribute__((device)) __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16);
}\n""" + '\n'.join([_make_hip_dtype(*x) for x in [
("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16),
("float", "float", 8)]]) + """
static __attribute__((device)) half8 __hip_wmma_f16_f16(half16 a, half16 b, half8 c) {
half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;
}\nextern "C" __attribute__((global))"""
kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
extern "C" {
__attribute__((device)) __attribute__((const)) float __ocml_fmax_f32(float, float);
__attribute__((device)) __attribute__((pure)) float __ocml_exp2_f32(float);
__attribute__((device)) __attribute__((pure)) float __ocml_log2_f32(float);
__attribute__((device)) float __ocml_sin_f32(float);
__attribute__((device)) __attribute__((const)) float __ocml_sqrt_f32(float);
__attribute__((device)) __attribute__((const)) double __ocml_fmax_f64(double, double);
__attribute__((device)) __attribute__((pure)) double __ocml_exp2_f64(double);
__attribute__((device)) __attribute__((pure)) double __ocml_log2_f64(double);
__attribute__((device)) double __ocml_sin_f64(double);
__attribute__((device)) __attribute__((const)) double __ocml_sqrt_f64(double);
__attribute__((device)) __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16);
__attribute__((device)) __attribute__((pure)) _Float16 __ocml_exp2_f16(_Float16);
__attribute__((device)) __attribute__((pure)) _Float16 __ocml_log2_f16(_Float16);
__attribute__((device)) _Float16 __ocml_sin_f16(_Float16);
__attribute__((device)) __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16);
}\nextern "C" __attribute__((global))"""
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
code_for_op = _make_hip_code_for_op()
@@ -321,8 +321,10 @@ class HIPLanguage(CStyleLanguage):
type_map = {dtypes.bfloat16: "hip_bfloat16"}
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = ["#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))",
"typedef long unsigned int size_t;"]
prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
# TODO: add BF16 vec dts
if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
struct hip_bfloat16 {
unsigned short data;
@@ -339,8 +341,19 @@ struct hip_bfloat16 {
static __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
""")
prefix.append('\n'.join(_make_hip_dtype(*x) for x in [("float", "float", 2), ("float", "float", 4),
("signed int", "int", 4), ("signed int", "int", 2)]))
if any(uop.dtype == dtypes.half for uop in uops):
prefix.append("#define half _Float16")
vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
prefix += [_make_hip_dtype(*x) for x in vec_dts]
for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
def get_kernel_modifier(self, uops:UOpGraph) -> str:
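
On HIP the same per-arg loop either aliases the generated name straight onto the compiler builtin (f32 accumulate) or emits the register-packing wrapper (f16 accumulate). A hypothetical helper sketching just that branch, with the name following the __str__ scheme from kernel.py:

from tinygrad.dtype import dtypes
def hip_wmma_prefix(arg):  # hypothetical, mirrors the branch in render_kernel above
  if arg[3] == dtypes.float: return f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32"
  return f"static __attribute__((device)) half8 __{arg[0]}(half16 a, half16 b, half8 c) {{ /* pack, builtin, unpack */ }}"
print(hip_wmma_prefix(("WMMA_16_16_16_half_float", (16,16,16), dtypes.half, dtypes.float, (16,16,8), "HSA")))
# -> #define __WMMA_16_16_16_half_float __builtin_amdgcn_wmma_f32_16x16x16_f16_w32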

View File

@@ -145,13 +145,14 @@ class PythonProgram:
out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
return out
if arg.startswith('__metal_wmma'):
# TODO: refactor these to a shared TensorCoreLayout in kernel.py
if arg[5] == "METAL":
# A (2 elements on 32 threads): row major
def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
# (i, j), C, D (2 elements on 32 threads): row major same as A/B
def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
elif arg == '__builtin_amdgcn_wmma_f32_16x16x16_f16_w32' or arg == '__hip_wmma_f16_f16':
elif arg[5] == "HSA":
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
def a_elem(x, i, j, goff):
assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
@@ -160,7 +161,7 @@ class PythonProgram:
def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif arg == '__cuda_mma_m16n8k16_f16_f32':
elif arg[5] == "CUDA":
# A (8 elements on 32 threads)
def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
# B (4 elements on 32 threads)