TC=2 still sets tensor cores (and TC=3 support for locals) (#5780)

* TC=2 still sets tensor cores

* add TC=3 support for using locals

* bugfix

* lines + TC=3 tests

* CUDA can use threads, fix fuzz linearizer
George Hotz 2024-07-28 16:16:53 -07:00 committed by GitHub
parent 71a64d8252
commit 0392123e6e
GPG Key ID: B5690EEEBB952194
5 changed files with 43 additions and 21 deletions
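Read together, the hunks below define a small ladder of modes for the TC environment variable. A rough summary as a sketch (the per-level descriptions are drawn from this diff's own comments; the helper function is illustrative, not tinygrad code):

# What the TC levels mean after this commit (summarized from the diff's comments):
#   TC=0: tensor cores off
#   TC=1: apply the tensor-core shape opts and emit the real WMMA op
#   TC=2: apply the same shape opts, but lower to MUL + SUM instead of WMMA
#   TC=3: like TC=2, and additionally emulate the warp addressing with local buffers
def emits_real_wmma(use_tensor_cores: int) -> bool:  # illustrative helper only
  return use_tensor_cores == 1  # use_tensor_cores >= 2 takes the MUL/SUM path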


@@ -50,6 +50,11 @@ jobs:
         PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
         PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
         PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
+    - name: Test tensor cores (TC=3)
+      run: |
+        TC=3 DEBUG=3 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
+        TC=3 PYTHONPATH=. DEBUG=3 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+        TC=3 DEBUG=3 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
     - name: Test dtype with Python emulator
       run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py
     - name: Test ops with Python emulator


@@ -2,7 +2,7 @@ from __future__ import annotations
 import itertools, functools
 from dataclasses import replace
 from collections import defaultdict
-from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict
+from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict, Any
 from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, verify_lazyop, KernelInfo
 from tinygrad.device import Device
@@ -104,6 +104,7 @@ class Kernel:
     self.local_dims: int = 0
     self.tensor_core: Optional[TensorCore] = None
     self.tensor_core_opts: Optional[TensorCoreOptions] = None
+    self.use_tensor_cores: int = 0
     # the local aliased buffers for A and B
     self.bufs_for_tensor_core: Dict[LazyOp, Tuple[int, int]] = {}
     self.dont_use_locals: bool = False
@@ -126,7 +127,8 @@ class Kernel:
     # parameters for optimizations
     ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
       self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
-    ret.tensor_core, ret.tensor_core_opts, ret.bufs_for_tensor_core = self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core
+    ret.tensor_core, ret.tensor_core_opts, ret.bufs_for_tensor_core, ret.use_tensor_cores = \
+      self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core, self.use_tensor_cores
     return ret
@@ -343,13 +345,10 @@ class Kernel:
       # NOTE: LOCALS and UPCAST can be swapped here. it doesn't seem faster
       self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[1], 2), append_opt=False)
       self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[0], 2), append_opt=False)
-      self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], 2), append_opt=False)
-      self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], 2), append_opt=False)
-      self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[1], 2), append_opt=False)
-      self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[1], 2), append_opt=False)
-      self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[1], 2), append_opt=False)
-      # assert tensor core
-      if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
+      for (tc_dim, tc_amt) in tc.threads:
+        self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
+      self.tensor_core = tc
+      self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
       return True
     return False
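The five hard-coded LOCAL doublings above baked in one particular warp layout; the replacement loop derives it from tc.threads, which is what lets CUDA (whose layout is fixed in the renderer hunk further down) go through the same path. A small sketch of the expansion, using the CUDA threads list from this diff and made-up axis names standing in for tc_opts.axes:

threads = [(0,2), (0,2), (1,2), (1,2), (1,2)]  # CUDA layout after this commit
axes = {0: "first matmul axis", 1: "second matmul axis"}  # stand-ins for tc_opts.axes
for tc_dim, tc_amt in threads:
  # mirrors: self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
  print(f"LOCAL {axes[tc_dim]} x{tc_amt}")
# five doublings -> 2**5 = 32 threads, one full CUDA warp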
@@ -644,8 +643,12 @@ class Kernel:
     @functools.lru_cache(None)
     def fixup_ast(op:LazyOp, apply_to_st=None) -> LazyOp:
       if op.op in BufferOps:
-        idx = self.bufs.index(op.arg)
-        arg = replace(op.arg, st=self.sts[idx] if apply_to_st is None else apply_to_st(self.sts[idx]))
+        if isinstance(op.arg, MemBuffer) and op.arg.idx < 0:
+          # for locals, we use the ShapeTracker that's in the MemBuffer
+          arg:Any = replace(op.arg, st=apply_to_st(op.arg.st)) if apply_to_st is not None else op.arg
+        else:
+          idx = self.bufs.index(op.arg)
+          arg = replace(op.arg, st=self.sts[idx] if apply_to_st is None else apply_to_st(self.sts[idx]))
       elif op.op in ReduceOps:
         reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
         arg = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
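The new branch is needed because the TC=3 path (a later hunk in this file) creates MemBuffers with negative idxs that never enter self.bufs and instead carry their own ShapeTracker. A toy illustration of that convention with made-up shapes (MemBuffer, ShapeTracker, and dtypes are the real tinygrad classes; the shapes and values are not):

from tinygrad.ops import MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.dtype import dtypes

# locals for the two WMMA operands get idx -1 and -2 and keep their st inline
a_local = MemBuffer(-1, dtypes.half, ShapeTracker.from_shape((1, 16, 16)))
b_local = MemBuffer(-2, dtypes.half, ShapeTracker.from_shape((1, 16, 16)))
assert a_local.idx < 0 and b_local.idx < 0  # the check fixup_ast now performs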
@@ -661,11 +664,8 @@ class Kernel:
         assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st1.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
         assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st1.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
         new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):] # expand the tcd
-        permaxis = list(range(wd))
-        permaxis += [y + (wd if x == 0 else tcd) for x,y in pattern_1]
-        permaxis += list(range(wd+len(warp_dims), tcd))
-        permaxis += [y + (wd if x == 0 else tcd) for x,y in pattern_2]
-        permaxis += list(range(tcd+len(tcd_expand), len(new_shape)))
+        permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in pattern_1] + list(range(wd+len(warp_dims), tcd)) + \
+          [y + (wd if x == 0 else tcd) for x,y in pattern_2] + list(range(tcd+len(tcd_expand), len(new_shape)))
         return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape).simplify()
       if self.opts.device in {"AMD", "HIP"}:
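permaxis interleaves the warp dims and the expanded tensor-core dims according to pattern_1/pattern_2. A toy run with hypothetical values (wd, tcd, the patterns, and the shape length are all made up; only the permaxis expression itself is taken from the code above):

wd, tcd = 1, 3                                  # hypothetical dim positions
warp_dims, tcd_expand = (2, 2), (2, 2)
pattern_1, pattern_2 = [(0, 0), (1, 0)], [(0, 1), (1, 1)]  # made-up (block, offset) pairs
new_shape_len = tcd + len(tcd_expand) + 1       # pretend there is one trailing dim

permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in pattern_1] + \
  list(range(wd+len(warp_dims), tcd)) + [y + (wd if x == 0 else tcd) for x,y in pattern_2] + \
  list(range(tcd+len(tcd_expand), new_shape_len))
print(permaxis)  # [0, 1, 3, 2, 4, 5] -- a permutation interleaving warp and tc dims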
@@ -691,7 +691,24 @@ class Kernel:
         wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device,
           tuple(tuple((self.first_upcast+ax, sz) for ax, sz in up) for up in upcast_axis),
           tuple(self.first_upcast+ax for ax in reduce_axes))
-        ret = LazyOp(ReduceOps.WMMA, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
+        if self.use_tensor_cores >= 2:
+          if self.use_tensor_cores == 3:
+            # TC=3, emulate the warp addressing with locals
+            ex_shape = tuple(1 if i < self.global_dims or (i >= self.first_reduce and i < self.first_upcast) else s \
+              for i,s in enumerate(self.full_shape))
+            srcs = []
+            for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
+              st_load = [self.sts[self.bufs.index(op.arg)].real_strides() for op in src.lazyops if op.op is BufferOps.LOAD]
+              local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
+              membuf = MemBuffer(-1-i, tc.dtype_in, ShapeTracker.from_shape(local_shape).expand(ex_shape))
+              srcs.append(LazyOp(BufferOps.LOAD, (fixup_ast(LazyOp(BufferOps.STORE, (src,), membuf), fix_st_fxn),), membuf))
+          else:
+            # for TC=2, we can't do the shapetracker fixup
+            srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
+          # MUL/SUM instead of WMMA
+          ret = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, tuple(srcs)),), wmma_arg[-1])
+        else:
+          ret = LazyOp(ReduceOps.WMMA, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
         return LazyOp(op.op, (ret,), new_reduce_axes) if (new_reduce_axes:=tuple(i for i in arg if i-self.first_upcast not in reduce_axes)) else ret
       if self.group_for_reduces:
         start = LazyOp(op.op, tuple(fixup_ast(x, apply_to_st) for x in op.src), arg)
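What the use_tensor_cores >= 2 branch computes: WMMA is a small per-warp matmul, and a broadcast MUL followed by a SUM over the reduce axis is the same contraction. A standalone numpy check with illustrative shapes (not tinygrad's actual warp layout):

import numpy as np

N, M, K = 8, 16, 16  # the CUDA tensor core dims from this diff
A = np.random.randn(M, K).astype(np.float32)
B = np.random.randn(K, N).astype(np.float32)
wmma_like = A @ B                                      # what the real WMMA computes
mul_sum = (A[:, :, None] * B[None, :, :]).sum(axis=1)  # MUL, then SUM over K
assert np.allclose(wmma_like, mul_sum, atol=1e-4)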


@@ -158,8 +158,8 @@ class IndependentLowerer:
     if x.op is BufferOps.LOAD:
       barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
       return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
-    # NOTE: only store the local reduceop in the first thread
-    if x.arg.idx != -1:
+    # NOTE: only store the local reduceop in the first thread (this is wrong for non group for reduces!)
+    if x.arg.idx >= 0:
       for oidx, ridx in zip(self.idxs, self.ridxs):
         if oidx != ridx: valid = valid * oidx.eq(0)
       has_valid = valid.op is not UOps.CONST or valid.arg is not True
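The predicate flips from "anything but the one local" to "any global": with TC=3 there are now two local buffers (idx -1 and -2), and idx != -1 would have treated the second one as a global store. A one-liner showing the difference (idx values chosen for illustration):

for idx in (0, 1, -1, -2):  # globals are >= 0; TC=3 locals are -1 and -2
  print(idx, "old check:", idx != -1, "new check:", idx >= 0)
# -2 slipped through the old test (treated like a global store); the new test keeps it local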


@@ -265,7 +265,7 @@ class CUDARenderer(CStyleLanguage):
   global_max = (2147483647, 65535, 65535)
   local_max = (1024, 1024, 64)
   shared_max = 49152
-  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
+  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
   def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
   # language options
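Both thread lists describe a 32-lane warp; the fix moves the last doubling from the first matmul dimension to the second, presumably so the lane layout matches the WMMA instruction's lane-to-element mapping (per the commit note "CUDA can use threads"). A quick check of the two layouts:

from math import prod

old = [(0,2), (0,2), (1,2), (1,2), (0,2)]
new = [(0,2), (0,2), (1,2), (1,2), (1,2)]
for name, threads in [("old", old), ("new", new)]:
  per_dim = {d: prod(a for dim, a in threads if dim == d) for d in (0, 1)}
  print(name, per_dim, "warp size:", prod(a for _, a in threads))
# old {0: 8, 1: 4} warp size: 32
# new {0: 4, 1: 8} warp size: 32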


@@ -23,7 +23,7 @@ class ShapeTracker:
     return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
   @staticmethod
-  def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
+  def from_shape(shape:Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous
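The annotation change is purely cosmetic; for reference, from_shape builds a single contiguous View, so the contiguous property right below it holds for a fresh tracker:

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((4, 4))
assert st.contiguous  # one View, created contiguous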