diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index a1fa6661..e40dd00f 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -108,7 +108,9 @@ jobs:
     - name: Test speed vs torch
       run: CUDA=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
     - name: Run Tensor Core GEMM
-      run: CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
+      run: |
+        CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
+        CUDA=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
     - name: Run LLaMA
       run: |
         CUDA=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@@ -140,6 +142,7 @@ jobs:
           onnx_inference_speed.csv
           torch_speed.txt
           matmul.txt
+          matmul_bfloat16.txt
           llama_unjitted.txt
           llama_jitted.txt
           llama_beam.txt
diff --git a/extra/gemm/simple_conv.py b/extra/gemm/simple_conv.py
index c6d4bc3f..d7e08ef8 100644
--- a/extra/gemm/simple_conv.py
+++ b/extra/gemm/simple_conv.py
@@ -1,8 +1,8 @@
 from tinygrad.helpers import getenv
 from tinygrad import dtypes, Tensor
 
-dtype_in = dtypes.half if getenv("HALF") else dtypes.float
-acc_dtype = dtypes.half if getenv("ACC_HALF") else None
+dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
+acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
 
 CNT = getenv("CNT", 8)
 BS = getenv("BS", 16)
@@ -12,6 +12,8 @@ HW = getenv("HW", 128)
 K = getenv("K", 3)
 PADDING = getenv("PADDING", 1)
 COMP = getenv("COMP", 0)
+ATOL = getenv("ATOL", 1e-4)
+RTOL = getenv("RTOL", 3e-2)
 
 FLOPS = BS*K*K*CIN*HW*HW*COUT*2
 def rand_input(): return Tensor.rand(BS, CIN, HW, HW, dtype=dtype_in).realize(), Tensor.rand(COUT, CIN, K, K, dtype=dtype_in).realize()
@@ -28,4 +30,4 @@ if __name__ == "__main__":
     torch_device = "cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")
     ta, tb = torch.from_numpy(a.numpy()).to(torch_device), torch.from_numpy(b.numpy()).to(torch_device)
     tc = torch.nn.functional.conv2d(ta, tb, padding=PADDING)
-    np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=1e-4, rtol=3e-2)
+    np.testing.assert_allclose(c.numpy(), tc.cpu(), atol=ATOL, rtol=RTOL)
diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py
index 9512f65e..f8e0e1fe 100644
--- a/extra/gemm/simple_matmul.py
+++ b/extra/gemm/simple_matmul.py
@@ -1,10 +1,12 @@
 import numpy as np
 from tinygrad.helpers import getenv
 from tinygrad import dtypes, Tensor
-dtype_in = dtypes.half if getenv("HALF") else dtypes.float
-acc_dtype = dtypes.half if getenv("ACC_HALF") else None
+dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
+acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
 N = getenv("N", 4096)
 CNT = getenv("CNT", 10)
+ATOL = getenv("ATOL", 1e-4)
+RTOL = getenv("RTOL", 3e-2)
 
 if __name__ == "__main__":
   a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
@@ -14,4 +16,4 @@ if __name__ == "__main__":
     c = a.matmul(b, acc_dtype=acc_dtype).realize()
   comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
   nc = c.numpy()
-  np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=3e-2)
+  np.testing.assert_allclose(nc, comp, atol=ATOL, rtol=RTOL)
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 86046100..57834215 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.engine.jit import CacheCollector
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
-from tinygrad.helpers import prod, Context
+from tinygrad.helpers import prod, Context, getenv
 from tinygrad.dtype import DType, dtypes
 from tinygrad.codegen.uops import UOpGraph
 
@@ -186,6 +186,7 @@ class TestLinearizer(unittest.TestCase):
     if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores:
       self.skipTest("device doesn't have tensor cores")
     for tc in tensor_cores[Device[Device.DEFAULT].compiler.linearizer_opts.device]:
+      if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
       a, b = Tensor.rand(tc.dims[1], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[0], dtype=tc.dtype_in)
       np_a, np_b = a.numpy(), b.numpy()
       r = a.matmul(b, acc_dtype=tc.dtype_out)
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 76e0b902..d6711223 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -30,16 +30,14 @@ class Opt:
 
 @dataclass(frozen=True)
 class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
-  dims: List[int] # N, M, K
+  dims: Tuple[int,int,int] # N, M, K
   dtype_in: DType # dtype for A and B
   dtype_out: DType # dtype for C and D
   threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
   thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
   thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
-  wmma_func: str # name of wmma function to call
-  def __str__(self): return f"tensor_core<{self.dims}, {self.dtype_in}, {self.dtype_out}>"
-  def num_threads(self): return len(self.threads)
-  def num_upcasts(self): return len(self.thread_local_aliases[0]) - self.num_threads()
+  def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+  def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)
 
 class TensorCoreOptions(NamedTuple):
   bufs: Tuple[int, int] # the local aliased buffers for A and B
@@ -51,18 +49,9 @@
       elif removed_axis == self.axes[tc_dim]: self.axes_exist[tc_dim] = False
 
 tensor_cores: Dict[str, List[TensorCore]] = {
-  "METAL": [
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-  ],
-  "HSA": [
-    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ]), # noqa: E501
-    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ]), # noqa: E501
-  ],
-  "CUDA": [
-    TensorCore(dims=[8,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__cuda_mma_m16n8k16_f16_f32", threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ]), # noqa: E501
-  ],
+  "METAL": [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
+  "HSA": [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
+  "CUDA": [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)]], # noqa: E501
 }
 
 class LocalBuffer(NamedTuple):
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 231382df..1c1da11d 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -291,10 +291,10 @@ class Linearizer(Kernel):
       if (tc:=self.tensor_core):
         min_alias_idx = min(self.local_alias.keys())
         replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
-        for n in range(tc.num_threads()):
-          buf_idxs[self.first_reduce-tc.num_threads()+n] = replace_input_idxs[n] # replace locals
+        for n in range(len(tc.threads)):
+          buf_idxs[self.first_reduce-len(tc.threads)+n] = replace_input_idxs[n] # replace locals
         for n in range(tc.num_upcasts()):
-          buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[tc.num_threads()+n] # replace upcasts
+          buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
       if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
       ll = self.global_load(i, buf_idxs)
       locals_to_store.append((localbuf_idx, buf_idxs, ll))
@@ -315,13 +315,13 @@ class Linearizer(Kernel):
           strides.append((0 if stride == 0 else next, sz))
           next *= 1 if stride == 0 else sz
         return strides
-      upcasts = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]]
-      for iter in [x[::-1] for x in [x for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]]:
+      upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
+      for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
        offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
        ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
               self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
-              self.uops.add(UOps.CAST, tc.dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
-       ret = self.uops.add(UOps.WMMA, tc.dtype_out.vec(wmma_sz[2]), ops, tc.wmma_func)
+              self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
+       ret = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(map(prod, tc.thread_local_sizes)), dev))
        for z in range(wmma_sz[2]):
          acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
     else:
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index 879e019e..41878464 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import functools, math, operator, itertools
 from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast
 from collections import defaultdict
-from tinygrad.helpers import DEBUG, flatten
+from tinygrad.helpers import DEBUG, flatten, prod
 from tinygrad.dtype import dtypes, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
 from tinygrad.shape.symbolic import sint, Variable, Node, NumNode, MulNode, DivNode, SumNode
@@ -385,8 +385,6 @@ class UOpGraph:
         assert u.vin[2].dtype is not None
         mem += u.vin[2].dtype.itemsize * mults
       elif u.uop is UOps.WMMA:
-        if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults
-        elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults
-        elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults
-        else: raise NotImplementedError(f"not implemented wmma {u.arg=}")
+        assert u.arg[1] is not None
+        flops += 2 * prod(u.arg[1]) // 32 * mults
     return flops, mem
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 0eaeea12..ed484ee7 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -164,7 +164,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
         assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
         bufs.append((args[1], (dtype,args[2])))
         r[u] = args[1]
-      elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
+      elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa(u, 'wmma')} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
       elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
       elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
       elif uop is UOps.GEP:
@@ -215,11 +215,11 @@ class MetalLanguage(CStyleLanguage):
     return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
-    prefix = ["#include <metal_stdlib>","using namespace metal;"]
-    if any(uop.uop == UOps.WMMA for uop in uops): prefix.append("""template<typename T, typename S, typename U> U __metal_wmma(T m, T n, U o) {
-  S a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x; b.thread_elements()[1] = n.y;
-  c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
-  return U(c.thread_elements()[0], c.thread_elements()[1]);\n}""")
+    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop == UOps.WMMA])
+    for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
+  simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
+  b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
+  return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
 MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
@@ -229,6 +229,11 @@ code_for_op_half = {BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype
                     UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})",
                     UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})",}
 
+_nms = "xyzwabcdefghijkl"
+def _make_cuda_dtype(base_type, name, cnt):
+  vec, elems, header = f"{name}{cnt}", ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
+  return f"struct {vec} {{ {base_type} {elems}; }}; __device__ {vec} make_{vec}({header}) {{ {vec} r={{{elems}}}; return r; }}"
+
 class CUDALanguage(CStyleLanguage):
   kernel_prefix = "extern \"C\" __global__ "
   smem_prefix = "__shared__ "
@@ -241,17 +246,24 @@
   type_map = {dtypes.bfloat16: "nv_bfloat16"}
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
+    # TODO: why is dtypes.bfloat16.name == "__bf16"? would be easier not to override dtypes.name
+    dt_map = { dtypes.float: ("float","f32"), dtypes.half: ("half","f16"), dtypes.bfloat16: ("bfloat16","bf16"), }
+
     prefix = ["#define INFINITY (__int_as_float(0x7f800000))","#define NAN (__int_as_float(0x7fffffff))"]
     if any(uop.dtype == dtypes.half for uop in uops):
-      prefix += ["#include <cuda_fp16.h>", "struct half4 { half x, y, z, w; };", "struct half8 { half x, y, z, w, a, b, c, d; };",
-        "__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }",
-        "__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }",
-        """__device__ float4 __cuda_mma_m16n8k16_f16_f32(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
-  asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
-    : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
-  return c;}""",]
+      prefix += ["#include <cuda_fp16.h>"] + [_make_cuda_dtype("half", "half", x) for x in [4, 8]]
+
+    if any(uop.dtype == dtypes.bfloat16 for uop in uops):
+      prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
+
+    # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
+    for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]):
+      fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
+      prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
+asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"
+  : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
+return c;}}""")
 
-    if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("#include <cuda_bf16.h>")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
 
 CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
@@ -273,43 +285,31 @@ def _make_hip_code_for_op():
   return { k:wrapper(k,v) for k,v in {**CStyleLanguage().code_for_op, **code_for_op_hip}.items() }
 
 def _make_hip_dtype(base_type, name, cnt):
-  nms = "xyzwabcdefghijkl"[:cnt]
+  elems, header = ', '.join(_nms[:cnt]), ', '.join([f"{base_type} {x}" for x in _nms[:cnt]])
   return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
-    f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}(" + ', '.join([f"{base_type} {x}" for x in nms]) + \
-    ") { return {" + ', '.join(nms) + "}; }"
+    f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}({header}) {{ return {{{elems}}}; }}"
 
 class HIPLanguage(CStyleLanguage):
-  kernel_prefix = """
-  #define half _Float16
-
-  extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
-  extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
-  extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
-
-  extern "C" {
-  __attribute__((device)) __attribute__((const)) float __ocml_fmax_f32(float, float);
-  __attribute__((device)) __attribute__((pure)) float __ocml_exp2_f32(float);
-  __attribute__((device)) __attribute__((pure)) float __ocml_log2_f32(float);
-  __attribute__((device)) float __ocml_sin_f32(float);
-  __attribute__((device)) __attribute__((const)) float __ocml_sqrt_f32(float);
-  __attribute__((device)) __attribute__((const)) double __ocml_fmax_f64(double, double);
-  __attribute__((device)) __attribute__((pure)) double __ocml_exp2_f64(double);
-  __attribute__((device)) __attribute__((pure)) double __ocml_log2_f64(double);
-  __attribute__((device)) double __ocml_sin_f64(double);
-  __attribute__((device)) __attribute__((const)) double __ocml_sqrt_f64(double);
-  __attribute__((device)) __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16);
-  __attribute__((device)) __attribute__((pure)) _Float16 __ocml_exp2_f16(_Float16);
-  __attribute__((device)) __attribute__((pure)) _Float16 __ocml_log2_f16(_Float16);
-  __attribute__((device)) _Float16 __ocml_sin_f16(_Float16);
-  __attribute__((device)) __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16);
-  }\n""" + '\n'.join([_make_hip_dtype(*x) for x in [
-    ("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16),
-    ("float", "float", 8)]]) + """
-  static __attribute__((device)) half8 __hip_wmma_f16_f16(half16 a, half16 b, half8 c) {
-    half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
-    c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
-    for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;
-  }\nextern "C" __attribute__((global))"""
+  kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
+extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
+extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
+extern "C" {
+__attribute__((device)) __attribute__((const)) float __ocml_fmax_f32(float, float);
+__attribute__((device)) __attribute__((pure)) float __ocml_exp2_f32(float);
+__attribute__((device)) __attribute__((pure)) float __ocml_log2_f32(float);
+__attribute__((device)) float __ocml_sin_f32(float);
+__attribute__((device)) __attribute__((const)) float __ocml_sqrt_f32(float);
+__attribute__((device)) __attribute__((const)) double __ocml_fmax_f64(double, double);
+__attribute__((device)) __attribute__((pure)) double __ocml_exp2_f64(double);
+__attribute__((device)) __attribute__((pure)) double __ocml_log2_f64(double);
+__attribute__((device)) double __ocml_sin_f64(double);
+__attribute__((device)) __attribute__((const)) double __ocml_sqrt_f64(double);
+__attribute__((device)) __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16);
+__attribute__((device)) __attribute__((pure)) _Float16 __ocml_exp2_f16(_Float16);
+__attribute__((device)) __attribute__((pure)) _Float16 __ocml_log2_f16(_Float16);
+__attribute__((device)) _Float16 __ocml_sin_f16(_Float16);
+__attribute__((device)) __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16);
+}\nextern "C" __attribute__((global))"""
   code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
                        "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
   code_for_op = _make_hip_code_for_op()
@@ -321,8 +321,10 @@ class HIPLanguage(CStyleLanguage):
   type_map = {dtypes.bfloat16: "hip_bfloat16"}
 
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
-    prefix = ["#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))",
-      "typedef long unsigned int size_t;"]
+    prefix = ["#define INFINITY (__builtin_inff())", "#define NAN (__builtin_nanf(\"\"))", "typedef long unsigned int size_t;"]
+    vec_dts = [("float", "float", 2), ("float", "float", 4), ("float", "float", 8), ("signed int", "int", 4), ("signed int", "int", 2)]
+
+    # TODO: add BF16 vec dts
     if any(uop.dtype == dtypes.bfloat16 for uop in uops): prefix.append("""
 struct hip_bfloat16 {
   unsigned short data;
@@ -339,8 +341,19 @@ struct hip_bfloat16 {
 static __attribute__((device)) bool operator<(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) < ((float)b); }
 static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return ((float)a) == ((float)b); }
 """)
-    prefix.append('\n'.join(_make_hip_dtype(*x) for x in [("float", "float", 2), ("float", "float", 4),
-      ("signed int", "int", 4), ("signed int", "int", 2)]))
+
+    if any(uop.dtype == dtypes.half for uop in uops):
+      prefix.append("#define half _Float16")
+      vec_dts += [("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16)]
+
+    prefix += [_make_hip_dtype(*x) for x in vec_dts]
+
+    for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+      if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
+      else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
+  half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
+  c_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(a, b, c_frag, false);
+  for (int n = 0; n < 8; n++) { d[n] = c_frag[n*2]; } return d;\n}""")
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
   def get_kernel_modifier(self, uops:UOpGraph) -> str:
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index 764a52f5..0cad1b91 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -145,13 +145,14 @@ class PythonProgram:
             out[elem_idx][goff+lane_id] += sum(a_elem(inp[0], _k, c_j, goff) * b_elem(inp[1], c_i, _k, goff) for _k in range(K))
           return out
 
-        if arg.startswith('__metal_wmma'):
+        # TODO: refactor these to a shared TensorCoreLayout in kernel.py
+        if arg[5] == "METAL":
          # A (2 elements on 32 threads): row major
           def a_b_elem(x, i, j, goff): return x[(i%2)][goff+(i//2)%2+(j%4)*2+(i//4)*8+(j//4)*16]
           # (i, j), C, D (2 elements on 32 threads): row major same as A/B
           def c_map(lane, elem): return (elem + ((lane%2)*2) + ((lane//8)%2)*4, ((lane//2)%4) + (lane//16)*4)
           ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
-        elif arg == '__builtin_amdgcn_wmma_f32_16x16x16_f16_w32' or arg == '__hip_wmma_f16_f16':
+        elif arg[5] == "HSA":
           # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
           def a_elem(x, i, j, goff):
             assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
@@ -160,7 +161,7 @@ class PythonProgram:
           def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)
           def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
           ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
-        elif arg == '__cuda_mma_m16n8k16_f16_f32':
+        elif arg[5] == "CUDA":
           # A (8 elements on 32 threads)
           def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
           # B (4 elements on 32 threads)