mirror of https://github.com/commaai/tinygrad.git
TC=2 still sets tensor cores (and TC=3 support for locals) (#5780)

* TC=2 still sets tensor cores
* add TC=3 support for using locals
* bugfix
* lines + TC=3 tests
* CUDA can use threads, fix fuzz linearizer
parent 71a64d8252
commit 0392123e6e
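The TC knob this PR extends selects how tensor cores are applied. Going by the commit message and the diff below: TC=1 keeps the real WMMA op, TC=2 applies the tensor-core shape opts but lowers to MUL/SUM instead of WMMA, and TC=3 additionally emulates the warp addressing through local buffers. A minimal usage sketch, assuming TC is set before tinygrad reads its environment and a backend with tensor cores (or one of the EMULATE_* Python-emulator runs from the CI below):

    import os
    os.environ["TC"] = "3"   # assumption: must be set before tinygrad reads its env vars
    from tinygrad import Tensor, dtypes
    # a half-precision GEMM that can take the tensor-core path
    a, b = Tensor.rand(16, 16, dtype=dtypes.half), Tensor.rand(16, 16, dtype=dtypes.half)
    print((a @ b).numpy())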
@@ -50,6 +50,11 @@ jobs:
         PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
         PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
         PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
+    - name: Test tensor cores (TC=3)
+      run: |
+        TC=3 DEBUG=3 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
+        TC=3 PYTHONPATH=. DEBUG=3 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
+        TC=3 DEBUG=3 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
     - name: Test dtype with Python emulator
       run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py
     - name: Test ops with Python emulator
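For context on how the TC value in these commands reaches the kernel: tinygrad reads environment knobs through tinygrad.helpers.getenv, and the integer ends up as the use_tensor_cores field added in the next hunk. A minimal sketch; the exact call site and the default value are assumptions, not part of this diff:

    from tinygrad.helpers import getenv
    # assumed wiring: the TC env var is passed through to the kernel as use_tensor_cores
    use_tensor_cores = getenv("TC", 1)
    print(use_tensor_cores)  # 1 unless TC is set in the environment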
@@ -2,7 +2,7 @@ from __future__ import annotations
 import itertools, functools
 from dataclasses import replace
 from collections import defaultdict
-from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict
+from typing import Optional, List, Tuple, cast, Dict, Union, Final, DefaultDict, Any
 
 from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, MetaOps, UNSAFE_PAD_OPS, verify_lazyop, KernelInfo
 from tinygrad.device import Device
@@ -104,6 +104,7 @@ class Kernel:
     self.local_dims: int = 0
     self.tensor_core: Optional[TensorCore] = None
     self.tensor_core_opts: Optional[TensorCoreOptions] = None
+    self.use_tensor_cores: int = 0
     # the local aliased buffers for A and B
     self.bufs_for_tensor_core: Dict[LazyOp, Tuple[int, int]] = {}
     self.dont_use_locals: bool = False
@@ -126,7 +127,8 @@ class Kernel:
     # parameters for optimizations
     ret.applied_opts, ret.group_for_reduces, ret.upcasted, ret.local_dims, ret.dont_use_locals = \
       self.applied_opts[:], self.group_for_reduces, self.upcasted, self.local_dims, self.dont_use_locals
-    ret.tensor_core, ret.tensor_core_opts, ret.bufs_for_tensor_core = self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core
+    ret.tensor_core, ret.tensor_core_opts, ret.bufs_for_tensor_core, ret.use_tensor_cores = \
+      self.tensor_core, self.tensor_core_opts, self.bufs_for_tensor_core, self.use_tensor_cores
 
     return ret
 
@@ -343,13 +345,10 @@ class Kernel:
         # NOTE: LOCALS and UPCAST can be swapped here. it doesn't seem faster
         self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[1], 2), append_opt=False)
         self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[0], 2), append_opt=False)
-        self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], 2), append_opt=False)
-        self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[0], 2), append_opt=False)
-        self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[1], 2), append_opt=False)
-        self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[1], 2), append_opt=False)
-        self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[1], 2), append_opt=False)
-        # assert tensor core
-        if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
+        for (tc_dim, tc_amt) in tc.threads:
+          self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
+        self.tensor_core = tc
+        self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
         return True
     return False
 
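The hard-coded LOCAL sequence above is replaced by a loop over tc.threads, which is what lets CUDA declare its own thread layout later in this diff. A small illustration of what the loop issues for that CUDA layout; the axis positions are made up for illustration:

    # threads for the CUDA TensorCore defined later in this diff
    threads = [(0,2), (0,2), (1,2), (1,2), (1,2)]
    axes = [3, 4]                      # hypothetical positions of the two tensor-core axes
    opts = [("LOCAL", axes[tc_dim], tc_amt) for tc_dim, tc_amt in threads]
    print(opts)  # [('LOCAL', 3, 2), ('LOCAL', 3, 2), ('LOCAL', 4, 2), ('LOCAL', 4, 2), ('LOCAL', 4, 2)]
    assert 2*2*2*2*2 == 32             # the five factor-2 locals form a 32-lane warp/simdgroup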
@@ -644,8 +643,12 @@ class Kernel:
     @functools.lru_cache(None)
     def fixup_ast(op:LazyOp, apply_to_st=None) -> LazyOp:
       if op.op in BufferOps:
-        idx = self.bufs.index(op.arg)
-        arg = replace(op.arg, st=self.sts[idx] if apply_to_st is None else apply_to_st(self.sts[idx]))
+        if isinstance(op.arg, MemBuffer) and op.arg.idx < 0:
+          # for locals, we use the ShapeTracker that's in the MemBuffer
+          arg:Any = replace(op.arg, st=apply_to_st(op.arg.st)) if apply_to_st is not None else op.arg
+        else:
+          idx = self.bufs.index(op.arg)
+          arg = replace(op.arg, st=self.sts[idx] if apply_to_st is None else apply_to_st(self.sts[idx]))
       elif op.op in ReduceOps:
         reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
         arg = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
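The new branch handles MemBuffers with a negative idx, which is the convention the TC=3 path below uses for its local buffers: their ShapeTracker lives in the MemBuffer itself rather than in self.sts. A minimal sketch of constructing such a buffer (shape and dtype are arbitrary):

    from tinygrad import dtypes
    from tinygrad.ops import MemBuffer
    from tinygrad.shape.shapetracker import ShapeTracker
    # a negative idx marks a local buffer; fixup_ast then uses the st stored here
    local_buf = MemBuffer(idx=-1, dtype=dtypes.half, st=ShapeTracker.from_shape((1, 4, 1, 2)))
    print(local_buf.idx < 0, local_buf.st.shape)   # True (1, 4, 1, 2)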
@@ -661,11 +664,8 @@ class Kernel:
            assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st1.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
            assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st1.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
            new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):] # expand the tcd
-           permaxis = list(range(wd))
-           permaxis += [y + (wd if x == 0 else tcd) for x,y in pattern_1]
-           permaxis += list(range(wd+len(warp_dims), tcd))
-           permaxis += [y + (wd if x == 0 else tcd) for x,y in pattern_2]
-           permaxis += list(range(tcd+len(tcd_expand), len(new_shape)))
+           permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in pattern_1] + list(range(wd+len(warp_dims), tcd)) + \
+             [y + (wd if x == 0 else tcd) for x,y in pattern_2] + list(range(tcd+len(tcd_expand), len(new_shape)))
            return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape).simplify()
 
          if self.opts.device in {"AMD", "HIP"}:
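The permaxis change is a pure consolidation; the permutation is still assembled from the same five pieces. A runnable toy of the construction, with wd, tcd and the patterns entirely made up (they are defined elsewhere and are not part of this hunk):

    wd, tcd = 1, 5
    warp_dims, tcd_expand = (2, 2), (2, 2)
    pattern_1, pattern_2 = [(0, 0), (1, 0)], [(0, 1), (1, 1)]
    new_shape_len = 7   # pretend len(new_shape) after expanding the tcd dims
    permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x, y in pattern_1] + list(range(wd + len(warp_dims), tcd)) + \
      [y + (wd if x == 0 else tcd) for x, y in pattern_2] + list(range(tcd + len(tcd_expand), new_shape_len))
    print(permaxis)   # [0, 1, 5, 3, 4, 2, 6] -- a valid permutation of range(7)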
@@ -691,7 +691,24 @@ class Kernel:
          wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device,
                      tuple(tuple((self.first_upcast+ax, sz) for ax, sz in up) for up in upcast_axis),
                      tuple(self.first_upcast+ax for ax in reduce_axes))
-         ret = LazyOp(ReduceOps.WMMA, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
+         if self.use_tensor_cores >= 2:
+           if self.use_tensor_cores == 3:
+             # TC=3, emulate the warp addressing with locals
+             ex_shape = tuple(1 if i < self.global_dims or (i >= self.first_reduce and i < self.first_upcast) else s \
+               for i,s in enumerate(self.full_shape))
+             srcs = []
+             for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
+               st_load = [self.sts[self.bufs.index(op.arg)].real_strides() for op in src.lazyops if op.op is BufferOps.LOAD]
+               local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
+               membuf = MemBuffer(-1-i, tc.dtype_in, ShapeTracker.from_shape(local_shape).expand(ex_shape))
+               srcs.append(LazyOp(BufferOps.LOAD, (fixup_ast(LazyOp(BufferOps.STORE, (src,), membuf), fix_st_fxn),), membuf))
+           else:
+             # for TC=2, we can't do the shapetracker fixup
+             srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
+           # MUL/SUM instead of WMMA
+           ret = LazyOp(ReduceOps.SUM, (LazyOp(BinaryOps.MUL, tuple(srcs)),), wmma_arg[-1])
+         else:
+           ret = LazyOp(ReduceOps.WMMA, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
          return LazyOp(op.op, (ret,), new_reduce_axes) if (new_reduce_axes:=tuple(i for i in arg if i-self.first_upcast not in reduce_axes)) else ret
        if self.group_for_reduces:
          start = LazyOp(op.op, tuple(fixup_ast(x, apply_to_st) for x in op.src), arg)
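The TC=3 branch builds, per input, a local staging buffer whose shape keeps only the dimensions that input actually varies over. A self-contained sketch of the ex_shape / local_shape computation with invented dims and strides (the real code takes the max stride over all LOADs of the source):

    full_shape = (1, 4, 4, 16, 2, 2, 2)                  # made-up global / local / reduce / upcast dims
    global_dims, first_reduce, first_upcast = 1, 3, 4     # made-up boundaries
    # ex_shape: collapse global and reduce dims to 1, keep local and upcast extents
    ex_shape = tuple(1 if i < global_dims or (first_reduce <= i < first_upcast) else s
                     for i, s in enumerate(full_shape))
    print(ex_shape)      # (1, 4, 4, 1, 2, 2, 2)
    # local_shape: drop (set to 1) every dim this input is broadcast over, i.e. stride 0
    real_strides = (0, 8, 0, 1, 4, 0, 2)                  # invented strides for one input
    local_shape = tuple(s if real_strides[i] != 0 else 1 for i, s in enumerate(ex_shape))
    print(local_shape)   # (1, 4, 1, 1, 2, 1, 2)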
@@ -158,8 +158,8 @@ class IndependentLowerer:
     if x.op is BufferOps.LOAD:
       barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
       return UOp(UOps.LOAD, x.arg.dtype.scalar(), (buf, idx) + ((valid, UOp.const(x.arg.dtype.scalar(), 0)) if has_valid else ()) + barrier)
-    # NOTE: only store the local reduceop in the first thread
-    if x.arg.idx != -1:
+    # NOTE: only store the local reduceop in the first thread (this is wrong for non group for reduces!)
+    if x.arg.idx >= 0:
       for oidx, ridx in zip(self.idxs, self.ridxs):
         if oidx != ridx: valid = valid * oidx.eq(0)
       has_valid = valid.op is not UOps.CONST or valid.arg is not True
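The guard changes because TC=3 can now create more than one local buffer (idx -1 and -2 from the -1-i above), so "is this a store to a global buffer?" has to accept any non-negative index rather than just excluding -1. A toy comparison:

    for idx in (0, 1, -1, -2):
      print(idx, "old:", idx != -1, "new:", idx >= 0)
    # the old test wrongly treats the second local (idx=-2) like a global store; the new one does not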
@@ -265,7 +265,7 @@ class CUDARenderer(CStyleLanguage):
   global_max = (2147483647, 65535, 65535)
   local_max = (1024, 1024, 64)
   shared_max = 49152
-  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
+  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
   def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
 
   # language options
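Both the old and the new CUDA thread layouts multiply out to 32 lanes; the change moves one factor of 2 from dim 0 to dim 1 (8x4 -> 4x8), which, going by the commit message ("CUDA can use threads"), is what the generic threads loop above needs to address the fragment correctly. A quick check:

    from math import prod
    old = [(0,2), (0,2), (1,2), (1,2), (0,2)]
    new = [(0,2), (0,2), (1,2), (1,2), (1,2)]
    for name, threads in (("old", old), ("new", new)):
      per_dim = [prod(amt for dim, amt in threads if dim == d) for d in (0, 1)]
      print(name, per_dim, prod(per_dim))   # old [8, 4] 32, new [4, 8] 32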
@@ -23,7 +23,7 @@ class ShapeTracker:
     return ShapeTracker(tuple(inverted_views)).reshape(out_shape)
 
   @staticmethod
-  def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),))
+  def from_shape(shape:Tuple[sint, ...]) -> ShapeTracker: return ShapeTracker((View.create(shape),))
 
   @property
   def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous