wmma: widen TC usage in search by using PADTO on TC axes when possible (#4216)

* wmma: widen TC usage in search by using PADTO on TC axes when possible

* test: start tests for the new padding TC behavior

* search: upgrade padded TC search to TC_OPT >= 2

* test: add behavior and correctness test for padded TC

added optional argument to apply_tensor_core to set TC_OPT level

* linearizer: add tests for the PADTO behvaior and docs
This commit is contained in:
Francis Lam 2024-04-22 13:50:31 -07:00 committed by GitHub
parent 9e53d6cffa
commit bbb0ad4800
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 82 additions and 10 deletions

View File

@ -38,6 +38,8 @@ jobs:
run: | run: |
DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
DEBUG=2 HALF=1 python3 extra/gemm/simple_matmul.py | tee matmul_half.txt DEBUG=2 HALF=1 python3 extra/gemm/simple_matmul.py | tee matmul_half.txt
- name: Fuzz Padded Tensor Core GEMM
run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
- name: Run LLaMA - name: Run LLaMA
run: | run: |
JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@ -116,6 +118,8 @@ jobs:
CUDA=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt CUDA=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
- name: Run Tensor Core GEMM (PTX) - name: Run Tensor Core GEMM (PTX)
run: CUDA=1 PTX=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt run: CUDA=1 PTX=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
- name: Fuzz Padded Tensor Core GEMM(CUDA)
run: CUDA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
- name: Run LLaMA - name: Run LLaMA
run: | run: |
CUDA=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt CUDA=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@ -197,6 +201,9 @@ jobs:
run: HSA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt run: HSA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
- name: Run Tensor Core GEMM (KFD) - name: Run Tensor Core GEMM (KFD)
run: KFD=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_kfd.txt run: KFD=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_kfd.txt
# TODO: AMD compiler bug causes this to fail
#- name: Fuzz Padded Tensor Core GEMM
# run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
- name: Run Stable Diffusion - name: Run Stable Diffusion
run: HSA=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt run: HSA=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
- name: Run LLaMA 7B - name: Run LLaMA 7B

View File

@ -40,3 +40,5 @@ if __name__ == "__main__":
pass pass
print(f"failed sizes: {failed}") print(f"failed sizes: {failed}")
print(f"num failures: {len(failed)}") print(f"num failures: {len(failed)}")
if len(failed) > 0:
raise RuntimeError(f"failed on {len(failed)} kernels")

View File

@ -196,6 +196,49 @@ class TestLinearizer(unittest.TestCase):
else: tc_atol, tc_rtol = 5e-3, 1e-4 else: tc_atol, tc_rtol = 5e-3, 1e-4
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol) np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
def test_tensor_cores_padded(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
self.skipTest("device doesn't have tensor cores")
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
pad = 1
def ensure_uops_and_opts_count(m:int, k:int, n:int, tc_opt:int, ensure_triggered:bool=True):
a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
r = a.matmul(b, acc_dtype=tc.dtype_out)
sched = create_schedule([r.lazydata])
realized_ast = sched[-1].ast[0]
k = Linearizer(realized_ast)
k.apply_tensor_cores(1, tc_opt=tc_opt)
k.linearize()
wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA])
tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
if ensure_triggered:
assert wmmas > 0, "tensor core not triggered"
assert tcs == 1, "tensor core opt not included"
else:
assert wmmas == 0, "tensor core is incorrectly triggered"
assert tcs == 0, "tensor core opt is incorrectly included"
# check that TC is triggered for TC_OPT=2
ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=2, ensure_triggered=True)
# check that TC is not triggered for TC_OPT<2
ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=1, ensure_triggered=False)
# check excessive padding doesn't trigger padded TC in TC_OPT=2
ensure_uops_and_opts_count(tc.dims[0]//2, tc.dims[2], tc.dims[1], tc_opt=2, ensure_triggered=False)
ensure_uops_and_opts_count(tc.dims[0], tc.dims[2]//2, tc.dims[1], tc_opt=2, ensure_triggered=False)
ensure_uops_and_opts_count(tc.dims[0], tc.dims[2], tc.dims[1]//2, tc_opt=2, ensure_triggered=False)
# check correctness
a, b = Tensor.rand(tc.dims[1]+pad, tc.dims[2]+pad, dtype=tc.dtype_in), Tensor.rand(tc.dims[2]+pad, tc.dims[0]+pad, dtype=tc.dtype_in)
r = a.matmul(b, acc_dtype=tc.dtype_out)
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
helper_linearizer_opt(r, [
[Opt(OptOps.TC, axis=0, amt=2)],
], atol=atol, rtol=rtol)
def test_limit_dims_to_max_5d_global(self): def test_limit_dims_to_max_5d_global(self):
t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1 t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps] sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]

View File

@ -347,21 +347,25 @@ class Kernel:
return None return None
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
buf0_strides, buf1_strides, reduce_sz = self.sts[buf0].real_strides(), self.sts[buf1].real_strides(), self.full_shape[self.first_reduce] buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] # noqa: E501 axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] # noqa: E501 axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
if not(axis_buf0 and axis_buf1 and reduce_sz%tc.dims[2] == 0 and reduce_sz >= tc.dims[2]): continue if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue
if not((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1)): continue
axis_choices = list(itertools.product(axis_buf0, axis_buf1)) axis_choices = list(itertools.product(axis_buf0, axis_buf1))
if not(axis < len(axis_choices)): continue if not(axis < len(axis_choices)): continue
s0, s1 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0] # s0 is n, s1 is m s0, s1 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0] # s0 is n, s1 is m
assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0 axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, self.first_reduce]) if self.full_shape[x]%tc.dims[i] != 0]
if axis_pads and (opt_level < 2): continue
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
self.tensor_core_opts = (tc_opts:=TensorCoreOptions(bufs=(buf0, buf1), axes=[s0, s1], axes_exist=[True, True])) self.tensor_core_opts = (tc_opts:=TensorCoreOptions(bufs=(buf0, buf1), axes=[s0, s1], axes_exist=[True, True]))
# attempt to pad the tensor axes that require it
try:
for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
except KernelOptError: continue
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]), append_opt=False) self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]), append_opt=False)
for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False) if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
@ -369,14 +373,30 @@ class Kernel:
self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False) self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
# assert tensor core # assert tensor core
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
return True return True
return False return False
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None) -> bool:
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, tc_opt:Optional[int]=getenv("TC_OPT")) -> bool:
""" Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
Keyword arguments:
use_tensor_cores -- controls how tensor cores are applied (default 1)
0: will disable any tensor core matching
1: enable tensor cores
2: apply tensor core shape but don't use UOp.WMMA
extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
0: applies to only kernels with a single reduce axis and direct BufferOps.LOAD into BinaryOps.MUL
1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
"""
if not self.opts.has_tensor_cores and use_tensor_cores != 2: return False if not self.opts.has_tensor_cores and use_tensor_cores != 2: return False
try: # check TC first and apply hand-coded opts if successful try: # check TC first and apply hand-coded opts if successful
self.apply_opt(Opt(OptOps.TC, 0, 0)) self.apply_opt(Opt(OptOps.TC, 0, tc_opt))
if (tc_opts:=self.tensor_core_opts) is not None: if (tc_opts:=self.tensor_core_opts) is not None:
if extra_opts is not None: if extra_opts is not None:

View File

@ -17,7 +17,7 @@ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,
actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)] actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)] actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)] actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
actions += [Opt(op=OptOps.TC, axis=axis, amt=1) for axis in range(4)] actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(4)]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)] if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
def _get_test_global_size(global_size, max_global_size, var_vals): def _get_test_global_size(global_size, max_global_size, var_vals):