mirror of https://github.com/commaai/tinygrad.git
wmma: widen TC usage in search by using PADTO on TC axes when possible (#4216)
* wmma: widen TC usage in search by using PADTO on TC axes when possible * test: start tests for the new padding TC behavior * search: upgrade padded TC search to TC_OPT >= 2 * test: add behavior and correctness test for padded TC added optional argument to apply_tensor_core to set TC_OPT level * linearizer: add tests for the PADTO behvaior and docs
This commit is contained in:
parent
9e53d6cffa
commit
bbb0ad4800
|
@ -38,6 +38,8 @@ jobs:
|
|||
run: |
|
||||
DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
|
||||
DEBUG=2 HALF=1 python3 extra/gemm/simple_matmul.py | tee matmul_half.txt
|
||||
- name: Fuzz Padded Tensor Core GEMM
|
||||
run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
|
||||
|
@ -116,6 +118,8 @@ jobs:
|
|||
CUDA=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
|
||||
- name: Run Tensor Core GEMM (PTX)
|
||||
run: CUDA=1 PTX=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
|
||||
- name: Fuzz Padded Tensor Core GEMM(CUDA)
|
||||
run: CUDA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
CUDA=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
|
||||
|
@ -197,6 +201,9 @@ jobs:
|
|||
run: HSA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
|
||||
- name: Run Tensor Core GEMM (KFD)
|
||||
run: KFD=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_kfd.txt
|
||||
# TODO: AMD compiler bug causes this to fail
|
||||
#- name: Fuzz Padded Tensor Core GEMM
|
||||
# run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
|
||||
- name: Run Stable Diffusion
|
||||
run: HSA=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
|
||||
- name: Run LLaMA 7B
|
||||
|
|
|
@ -40,3 +40,5 @@ if __name__ == "__main__":
|
|||
pass
|
||||
print(f"failed sizes: {failed}")
|
||||
print(f"num failures: {len(failed)}")
|
||||
if len(failed) > 0:
|
||||
raise RuntimeError(f"failed on {len(failed)} kernels")
|
|
@ -196,6 +196,49 @@ class TestLinearizer(unittest.TestCase):
|
|||
else: tc_atol, tc_rtol = 5e-3, 1e-4
|
||||
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
|
||||
|
||||
def test_tensor_cores_padded(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_tensor_cores:
|
||||
self.skipTest("device doesn't have tensor cores")
|
||||
for tc in tensor_cores[Device[Device.DEFAULT].compiler.compiler_opts.device]:
|
||||
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
|
||||
pad = 1
|
||||
|
||||
def ensure_uops_and_opts_count(m:int, k:int, n:int, tc_opt:int, ensure_triggered:bool=True):
|
||||
a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
|
||||
r = a.matmul(b, acc_dtype=tc.dtype_out)
|
||||
sched = create_schedule([r.lazydata])
|
||||
realized_ast = sched[-1].ast[0]
|
||||
k = Linearizer(realized_ast)
|
||||
k.apply_tensor_cores(1, tc_opt=tc_opt)
|
||||
k.linearize()
|
||||
wmmas = len([uop for uop in k.uops if uop.uop is UOps.WMMA])
|
||||
tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
|
||||
if ensure_triggered:
|
||||
assert wmmas > 0, "tensor core not triggered"
|
||||
assert tcs == 1, "tensor core opt not included"
|
||||
else:
|
||||
assert wmmas == 0, "tensor core is incorrectly triggered"
|
||||
assert tcs == 0, "tensor core opt is incorrectly included"
|
||||
|
||||
# check that TC is triggered for TC_OPT=2
|
||||
ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=2, ensure_triggered=True)
|
||||
|
||||
# check that TC is not triggered for TC_OPT<2
|
||||
ensure_uops_and_opts_count(tc.dims[0]+pad, tc.dims[2]+pad, tc.dims[1]+pad, tc_opt=1, ensure_triggered=False)
|
||||
|
||||
# check excessive padding doesn't trigger padded TC in TC_OPT=2
|
||||
ensure_uops_and_opts_count(tc.dims[0]//2, tc.dims[2], tc.dims[1], tc_opt=2, ensure_triggered=False)
|
||||
ensure_uops_and_opts_count(tc.dims[0], tc.dims[2]//2, tc.dims[1], tc_opt=2, ensure_triggered=False)
|
||||
ensure_uops_and_opts_count(tc.dims[0], tc.dims[2], tc.dims[1]//2, tc_opt=2, ensure_triggered=False)
|
||||
|
||||
# check correctness
|
||||
a, b = Tensor.rand(tc.dims[1]+pad, tc.dims[2]+pad, dtype=tc.dtype_in), Tensor.rand(tc.dims[2]+pad, tc.dims[0]+pad, dtype=tc.dtype_in)
|
||||
r = a.matmul(b, acc_dtype=tc.dtype_out)
|
||||
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
|
||||
helper_linearizer_opt(r, [
|
||||
[Opt(OptOps.TC, axis=0, amt=2)],
|
||||
], atol=atol, rtol=rtol)
|
||||
|
||||
def test_limit_dims_to_max_5d_global(self):
|
||||
t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
|
||||
|
|
|
@ -347,21 +347,25 @@ class Kernel:
|
|||
return None
|
||||
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue
|
||||
|
||||
buf0_strides, buf1_strides, reduce_sz = self.sts[buf0].real_strides(), self.sts[buf1].real_strides(), self.full_shape[self.first_reduce]
|
||||
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[0] == 0] # noqa: E501
|
||||
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0 and self.full_shape[i]%tc.dims[1] == 0] # noqa: E501
|
||||
if not(axis_buf0 and axis_buf1 and reduce_sz%tc.dims[2] == 0 and reduce_sz >= tc.dims[2]): continue
|
||||
if not((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1)): continue
|
||||
buf0_strides, buf1_strides = self.sts[buf0].real_strides(), self.sts[buf1].real_strides()
|
||||
axis_buf0 = [(i,self.full_shape[i],buf1_strides[i]) for i,s in enumerate(buf0_strides[:self.first_reduce]) if s == 0]
|
||||
axis_buf1 = [(i,self.full_shape[i],buf0_strides[i]) for i,s in enumerate(buf1_strides[:self.first_reduce]) if s == 0]
|
||||
if not(axis_buf0 and axis_buf1 and ((self.shape_len-self.first_reduce) == 1 or (opt_level >= 1))): continue
|
||||
|
||||
axis_choices = list(itertools.product(axis_buf0, axis_buf1))
|
||||
if not(axis < len(axis_choices)): continue
|
||||
|
||||
s0, s1 = axis_choices[-(axis+1)][0][0], axis_choices[-(axis+1)][1][0] # s0 is n, s1 is m
|
||||
assert s0 != s1 and self.full_shape[s0]%tc.dims[0] == 0 and self.full_shape[s1]%tc.dims[1] == 0
|
||||
axis_pads = [(x, tc.dims[i]) for i, x in enumerate([s0, s1, self.first_reduce]) if self.full_shape[x]%tc.dims[i] != 0]
|
||||
if axis_pads and (opt_level < 2): continue
|
||||
|
||||
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
|
||||
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
|
||||
self.tensor_core_opts = (tc_opts:=TensorCoreOptions(bufs=(buf0, buf1), axes=[s0, s1], axes_exist=[True, True]))
|
||||
|
||||
# attempt to pad the tensor axes that require it
|
||||
try:
|
||||
for axis, dim in axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
|
||||
except KernelOptError: continue
|
||||
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]), append_opt=False)
|
||||
for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
|
||||
if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
|
||||
|
@ -369,14 +373,30 @@ class Kernel:
|
|||
self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
|
||||
|
||||
# assert tensor core
|
||||
if DEBUG >= 3: print("TENSOR CORES", axis_buf0, axis_buf1, tc)
|
||||
if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
|
||||
return True
|
||||
return False
|
||||
|
||||
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None) -> bool:
|
||||
|
||||
def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[List[Opt]]=None, tc_opt:Optional[int]=getenv("TC_OPT")) -> bool:
|
||||
""" Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
|
||||
Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
|
||||
|
||||
Keyword arguments:
|
||||
use_tensor_cores -- controls how tensor cores are applied (default 1)
|
||||
0: will disable any tensor core matching
|
||||
1: enable tensor cores
|
||||
2: apply tensor core shape but don't use UOp.WMMA
|
||||
extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
|
||||
tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
|
||||
0: applies to only kernels with a single reduce axis and direct BufferOps.LOAD into BinaryOps.MUL
|
||||
1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
|
||||
2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
|
||||
"""
|
||||
if not self.opts.has_tensor_cores and use_tensor_cores != 2: return False
|
||||
try: # check TC first and apply hand-coded opts if successful
|
||||
self.apply_opt(Opt(OptOps.TC, 0, 0))
|
||||
self.apply_opt(Opt(OptOps.TC, 0, tc_opt))
|
||||
|
||||
if (tc_opts:=self.tensor_core_opts) is not None:
|
||||
if extra_opts is not None:
|
||||
|
|
|
@ -17,7 +17,7 @@ actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,29,32,
|
|||
actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)]
|
||||
actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)]
|
||||
actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.UPCASTMID, axis=1, amt=4), Opt(op=OptOps.TC, axis=0, amt=0)]
|
||||
actions += [Opt(op=OptOps.TC, axis=axis, amt=1) for axis in range(4)]
|
||||
actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(4)]
|
||||
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
|
||||
|
||||
def _get_test_global_size(global_size, max_global_size, var_vals):
|
||||
|
|
Loading…
Reference in New Issue