From bbb0ad48003756111bdb720639db103b92ebfb4a Mon Sep 17 00:00:00 2001 From: Francis Lam Date: Mon, 22 Apr 2024 13:50:31 -0700 Subject: [PATCH] wmma: widen TC usage in search by using PADTO on TC axes when possible (#4216) * wmma: widen TC usage in search by using PADTO on TC axes when possible * test: start tests for the new padding TC behavior * search: upgrade padded TC search to TC_OPT >= 2 * test: add behavior and correctness test for padded TC added optional argument to apply_tensor_cores to set TC_OPT level * linearizer: add tests for the PADTO behavior and docs --- .github/workflows/benchmark.yml | 7 ++++++ extra/gemm/fuzz_matmul.py | 2 ++ test/test_linearizer.py | 43 +++++++++++++++++++++++++++++++++ tinygrad/codegen/kernel.py | 38 ++++++++++++++++++++++------- tinygrad/features/search.py | 2 +- 5 files changed, 82 insertions(+), 10 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 9b043fda..af4f4c98 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -38,6 +38,8 @@ jobs: run: | DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt DEBUG=2 HALF=1 python3 extra/gemm/simple_matmul.py | tee matmul_half.txt + - name: Fuzz Padded Tensor Core GEMM + run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py - name: Run LLaMA run: | JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." 
--count 10 --temperature 0 --timing | tee llama_unjitted.txt @@ -116,6 +118,8 @@ jobs: CUDA=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt - name: Run Tensor Core GEMM (PTX) run: CUDA=1 PTX=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt + - name: Fuzz Padded Tensor Core GEMM(CUDA) + run: CUDA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py - name: Run LLaMA run: | CUDA=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt @@ -197,6 +201,9 @@ jobs: run: HSA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt - name: Run Tensor Core GEMM (KFD) run: KFD=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_kfd.txt + # TODO: AMD compiler bug causes this to fail + #- name: Fuzz Padded Tensor Core GEMM + # run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py - name: Run Stable Diffusion run: HSA=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt - name: Run LLaMA 7B diff --git a/extra/gemm/fuzz_matmul.py b/extra/gemm/fuzz_matmul.py index 4a017400..9d22d553 100644 --- a/extra/gemm/fuzz_matmul.py +++ b/extra/gemm/fuzz_matmul.py @@ -40,3 +40,5 @@ if __name__ == "__main__": pass print(f"failed sizes: {failed}") print(f"num failures: {len(failed)}") + if len(failed) > 0: + raise RuntimeError(f"failed on {len(failed)} kernels") \ No newline at end of file diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 11c30cc8..dc3bcc7e 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -196,6 +196,49 @@ class TestLinearizer(unittest.TestCase): else: tc_atol, tc_rtol = 5e-3, 1e-4 np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol) + def 
test_tensor_cores_padded(self): + if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores: + self.skipTest("device doesn't have tensor cores") + for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]: + if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue + pad = 1 + + def ensure_uops_and_opts_count(m:int, k:int, n:int, tc_opt:int, ensure_triggered:bool=True): + a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in) + r = a.matmul(b, acc_dtype=tc.dtype_out) + sched = create_schedule([r.lazydata]) + realized_ast = sched[-1].ast[0] + k = Linearizer(realized_ast) + k.apply_tensor_cores(1, tc_opt=tc_opt) + k.linearize() + wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA]) + tcs = len([x for x in k.applied_opts if x.op is OptOps.TC]) + if ensure_triggered: + assert wmmas > 0, "tensor core not triggered" + assert tcs == 1, "tensor core opt not included" + else: + assert wmmas == 0, "tensor core is incorrectly triggered" + assert tcs == 0, "tensor core opt is incorrectly included" + + # check that TC is triggered for TC_OPT=2 + ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=2, ensure_triggered=True) + + # check that TC is not triggered for TC_OPT<2 + ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=1, ensure_triggered=False) + + # check excessive padding doesn't trigger padded TC in TC_OPT=2 + ensure_uops_and_opts_count(tc.dims[0]//2, tc.dims[2], tc.dims[1], tc_opt=2, ensure_triggered=False) + ensure_uops_and_opts_count(tc.dims[0], tc.dims[2]//2, tc.dims[1], tc_opt=2, ensure_triggered=False) + ensure_uops_and_opts_count(tc.dims[0], tc.dims[2], tc.dims[1]//2, tc_opt=2, ensure_triggered=False) + + # check correctness + a, b = Tensor.rand(tc.dims[1]+pad, tc.dims[2]+pad, dtype=tc.dtype_in), Tensor.rand(tc.dims[2]+pad, tc.dims[0]+pad, dtype=tc.dtype_in) + r = a.matmul(b, 
acc_dtype=tc.dtype_out) + (atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4) + helper_linearizer_opt(r, [ + [Opt(OptOps.TC, axis=0, amt=2)], + ], atol=atol, rtol=rtol) + def test_limit_dims_to_max_5d_global(self): t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1 sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps] diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index bd539a37..cfc897ed 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -347,21 +347,25 @@ class Kernel: return None if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue - buf0_strides, buf1_strides, reduce_sz = self.sts[buf0].real_strides(), self.sts[buf1].real_strides(), self.full_shape[self.first_reduce] - axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] # noqa: E501 - axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] # noqa: E501 - if not(axis_buf0 and axis_buf1 and reduce_sz%tc.dims[2] == 0 and reduce_sz >= tc.dims[2]): continue - if not((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1)): continue + buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides() + axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0] + axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0] + if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue axis_choices = list(itertools.product(axis_buf0, axis_buf1)) if not(axis < len(axis_choices)): continue s0, s1 = 
axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0] # s0 is n, s1 is m - assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0 + axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, self.first_reduce]) if self.full_shape[x]%tc.dims[i] != 0] + if axis_pads and (opt_level < 2): continue # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern - if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc) self.tensor_core_opts = (tc_opts:=TensorCoreOptions(bufs=(buf0, buf1), axes=[s0, s1], axes_exist=[True, True])) + + # attempt to pad the tensor axes that require it + try: + for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail + except KernelOptError: continue self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]), append_opt=False) for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False) @@ -369,14 +373,30 @@ class Kernel: self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False) # assert tensor core + if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc) if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA return True return False - def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None) -> bool: + + def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, tc_opt:Optional[int]=getenv("TC_OPT")) -> bool: + """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false. + Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N). 
+ + Keyword arguments: + use_tensor_cores -- controls how tensor cores are applied (default 1) + 0: will disable any tensor core matching + 1: enable tensor cores + 2: apply tensor core shape but don't use UOp.WMMA + extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None) + tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise) + 0: applies to only kernels with a single reduce axis and direct BufferOps.LOAD into BinaryOps.MUL + 1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers + 2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed + """ if not self.opts.has_tensor_cores and use_tensor_cores != 2: return False try: # check TC first and apply hand-coded opts if successful - self.apply_opt(Opt(OptOps.TC, 0, 0)) + self.apply_opt(Opt(OptOps.TC, 0, tc_opt)) if (tc_opts:=self.tensor_core_opts) is not None: if extra_opts is not None: diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index 18da92a3..80cc2814 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -17,7 +17,7 @@ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32, actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)] actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)] actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)] -actions += [Opt(op=OptOps.TC, axis=axis, amt=1) for axis in range(4)] +actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(4)] if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)] def _get_test_global_size(global_size, max_global_size, var_vals):