mirror of https://github.com/commaai/tinygrad.git
move tc to renderers (#4631)
* move tc to renderers
* missed import
* fix typo
* fix
* fix imports
* remove from tests
* fix 4607
* nv emulate timestamp
* time is int
* correct time
parent d70988dddf
commit daf57af3eb
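In short: the per-device tensor core tables move out of tinygrad/codegen/kernel.py and onto each Renderer, and the has_tensor_cores flag becomes a (possibly empty) tensor_cores list. A minimal sketch of the access pattern this diff establishes, with the old pattern shown in comments:

from tinygrad.device import Device

renderer = Device[Device.DEFAULT].renderer
# before: from tinygrad.codegen.kernel import tensor_cores
#         if renderer.has_tensor_cores: tcs = tensor_cores[renderer.device]
# after: the renderer owns its own list, empty meaning no tensor core support
for tc in renderer.tensor_cores:
  print(tc)  # TensorCore.__str__, e.g. WMMA_8_8_8_half_half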
@@ -114,7 +114,10 @@ class GPFIFO:
     val = self._state64_le(nv_gpu.NVC56F_SEM_PAYLOAD_LO)
     flags = self._next_dword()
     typ = (flags >> 0) & 0b111
-    if typ == 1: to_mv(signal, 8).cast('Q')[0] = val
+    timestamp = (flags & (1 << 25)) == (1 << 25)
+    if typ == 1:
+      to_mv(signal, 8).cast('Q')[0] = val
+      if timestamp: to_mv(signal + 8, 8).cast('Q')[0] = int(time.perf_counter() * 1e9)
     elif typ == 3:
       mval = to_mv(signal, 8).cast('Q')[0]
       return SchedResult.CONT if mval >= val else SchedResult.YIELD
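The hunk above teaches the NV mock GPU to emulate timestamped semaphore releases: bits 0-2 of the flags dword select the release type, and bit 25 requests that a nanosecond timestamp be written 8 bytes past the 64-bit payload. A self-contained sketch of that memory layout, using plain struct instead of tinygrad's to_mv helper:

import struct, time

def sem_release(buf: bytearray, val: int, flags: int) -> None:
  typ = (flags >> 0) & 0b111                    # release type lives in bits 0-2
  timestamp = (flags & (1 << 25)) == (1 << 25)  # bit 25: also write a timestamp
  if typ == 1:
    struct.pack_into("<Q", buf, 0, val)         # u64 payload at signal+0
    if timestamp: struct.pack_into("<Q", buf, 8, int(time.perf_counter() * 1e9))  # ns at signal+8

buf = bytearray(16)
sem_release(buf, 42, (1 << 25) | 1)
val, ts = struct.unpack_from("<QQ", buf)  # 42, plus a monotonic timestamp in ns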
@@ -3,7 +3,7 @@ import numpy as np
 import unittest
 from dataclasses import replace
 
-from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, tensor_cores
+from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
 from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node, expand_idxs
 from tinygrad.device import Device, Buffer
 from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps, ReduceOps, UnaryOps
@@ -261,16 +261,16 @@ class TestLinearizer(unittest.TestCase):
     helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
 
   def test_tensor_cores(self):
-    if not Device[Device.DEFAULT].renderer.has_tensor_cores:
+    if not Device[Device.DEFAULT].renderer.tensor_cores:
       self.skipTest("device doesn't have tensor cores")
-    for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
+    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
       if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
       helper_tc_allclose(tc.dims[0], tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
 
   def test_tensor_cores_padded(self):
-    if not Device[Device.DEFAULT].renderer.has_tensor_cores:
+    if not Device[Device.DEFAULT].renderer.tensor_cores:
       self.skipTest("device doesn't have tensor cores")
-    for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
+    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
       if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
       pad = 1
@@ -294,9 +294,9 @@ class TestLinearizer(unittest.TestCase):
 
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
   def test_tensor_cores_multi_reduce(self):
-    if not Device[Device.DEFAULT].renderer.has_tensor_cores:
+    if not Device[Device.DEFAULT].renderer.tensor_cores:
       self.skipTest("device doesn't have tensor cores")
-    for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
+    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
       if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
       # this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
       golden_result = None
@@ -786,10 +786,8 @@ class TestKernelOpts(unittest.TestCase):
     ])
 
   def test_invalid_tensor_core_extra_opts(self):
-    if not Device[Device.DEFAULT].renderer.has_tensor_cores:
+    if not Device[Device.DEFAULT].renderer.tensor_cores:
       self.skipTest("device doesn't have tensor cores")
-    if Device.DEFAULT not in tensor_cores:
-      self.skipTest("No tensor cores for device")
     N = 128
     Tensor.manual_seed(1552)
@@ -807,10 +805,8 @@ class TestKernelOpts(unittest.TestCase):
       assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners
 
   def test_buf_index_not_found_tensor_core(self):
-    if not Device[Device.DEFAULT].renderer.has_tensor_cores:
+    if not Device[Device.DEFAULT].renderer.tensor_cores:
       self.skipTest("device doesn't have tensor cores")
-    if Device.DEFAULT not in tensor_cores:
-      self.skipTest("No tensor cores for device")
     ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
     k = Linearizer(ast, opts=Device[Device.DEFAULT].renderer)
@@ -818,14 +814,12 @@ class TestKernelOpts(unittest.TestCase):
       k.apply_opt(Opt(OptOps.TC, 0, 1))
 
   def test_tensor_core_opts(self):
-    if not Device[Device.DEFAULT].renderer.has_tensor_cores:
+    if not Device[Device.DEFAULT].renderer.tensor_cores:
       self.skipTest("device doesn't have tensor cores")
-    if Device.DEFAULT not in tensor_cores:
-      self.skipTest("No tensor cores for device")
     N = 128
     Tensor.manual_seed(1552)
-    for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
+    for tc in Device[Device.DEFAULT].renderer.tensor_cores:
       # bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
       if tc.dtype_in == dtypes.bfloat16: continue
       a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
@@ -14,7 +14,6 @@ from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 
 class TestTimeLinearizer(unittest.TestCase):
-  @unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
   def test_reasonable_time(self):
     si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
     out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
@@ -23,7 +22,6 @@ class TestTimeLinearizer(unittest.TestCase):
     tm = time_linearizer(Linearizer(*si.ast), rawbufs, allow_test_size=False, cnt=10)
     assert tm > 0 and tm != float('inf')
 
-  @unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
   def test_bufs_from_lin(self):
     si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
     rawbufs = bufs_from_lin(lin:=Linearizer(*si.ast))
@@ -33,7 +31,6 @@ class TestTimeLinearizer(unittest.TestCase):
     assert all(r.size > 0 for r in rawbufs)
 
 class TestBEAM(unittest.TestCase):
-  @unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
   def test_dynamic_beam(self):
     # TODO: make this infra globally usable
     class Capture:
@@ -69,7 +66,6 @@ class TestBEAM(unittest.TestCase):
     if Opt(OptOps.GROUPTOP, 0, 0) in actions:
       assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, amt=3)]) == 0, "did not de-dup GROUPTOP"
 
-  @unittest.skipIf(Device.DEFAULT in {"NV"}, "Tries to open CUDA. #4607")
  def test_filter_global_buffer(self):
    # taken from https://github.com/tinygrad/tinygrad/issues/4612
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=6, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.4285714285714286, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
@@ -3,7 +3,7 @@ import math, itertools
 from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
 from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS
 from tinygrad.device import Device
-from tinygrad.renderer import Renderer
+from tinygrad.renderer import Renderer, TensorCore
 from tinygrad.dtype import dtypes, ImageDType, DType
 from tinygrad.helpers import colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -34,17 +34,6 @@ class Opt:
     if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
     return self.axis
 
-@dataclass(frozen=True)
-class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
-  dims: Tuple[int,int,int] # N, M, K
-  dtype_in: DType # dtype for A and B
-  dtype_out: DType # dtype for C and D
-  threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
-  thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
-  thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
-  def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
-  def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)
-
 class TensorCoreOptions(NamedTuple):
   bufs: Tuple[int, int] # the local aliased buffers for A and B
   axes: List[int] # the location of the original N and M axes if still in the shape
@@ -54,12 +43,6 @@ class TensorCoreOptions(NamedTuple):
       if removed_axis < self.axes[tc_dim]: self.axes[tc_dim] -= 1
       elif removed_axis == self.axes[tc_dim]: self.axes_exist[tc_dim] = False
 
-tensor_cores: Dict[str, List[TensorCore]] = {
-  "METAL": [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[0],[2],[0],[4],[-1, 1, 3],[0]], [[1],[0],[3],[0],[2, 4],[-1]], [[1],[2],[3],[4],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
-  "HSA": [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
-  "CUDA": [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)] if getenv("PTX") else [(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])], # noqa: E501
-}
-
 class LocalBuffer(NamedTuple):
   name: str
   size: int
@@ -340,8 +323,8 @@ class Kernel:
   # ******************** high level optimizers ********************
 
   def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
-    if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op is ReduceOps.SUM and self.opts.device in tensor_cores:
-      for tc in tensor_cores[self.opts.device]:
+    if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op is ReduceOps.SUM:
+      for tc in self.opts.tensor_cores:
         has_cast = tc.dtype_in != tc.dtype_out
         if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg == tc.dtype_out): continue
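Since _apply_tc_opt now walks self.opts.tensor_cores instead of a module-level dict, the renderer alone decides whether a kernel may use tensor cores. A hedged end-to-end sketch in the spirit of the tests above; the create_schedule import path and the matmul scheduling into a single last item are assumptions, not confirmed by this diff:

from tinygrad import Tensor, Device
from tinygrad.engine.schedule import create_schedule  # assumed path at this commit
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.kernel import Opt, OptOps

renderer = Device[Device.DEFAULT].renderer
if renderer.tensor_cores:  # an empty list now means "no tensor cores on this backend"
  tc = renderer.tensor_cores[0]
  a, b = Tensor.rand(64, 64, dtype=tc.dtype_in), Tensor.rand(64, 64, dtype=tc.dtype_in)
  si = create_schedule([(a @ b).lazydata])[-1]  # assumed: the last item is the matmul kernel
  k = Linearizer(*si.ast, opts=renderer)
  k.apply_opt(Opt(OptOps.TC, 0, 1))  # axis 0, tc_opt=1; raises KernelOptError if no TC matches
  print(k.applied_opts)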
@@ -403,7 +386,7 @@ class Kernel:
       1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
       2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
     """
-    if not self.opts.has_tensor_cores and use_tensor_cores != 2: return False
+    if not self.opts.tensor_cores and use_tensor_cores != 2: return False
     try: # check TC first and apply hand-coded opts if successful
       self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
@@ -435,7 +418,7 @@ class Kernel:
     if opt.op is OptOps.TC:
       check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
       check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
-      check((use_tensor_cores:=getenv("TC", 1)) == 2 or self.opts.has_tensor_cores, "must have tensor cores or TC=2")
+      check((use_tensor_cores:=getenv("TC", 1)) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
       check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
       self.applied_opts.append(opt)
       return
@@ -4,6 +4,18 @@ from dataclasses import dataclass
 from tinygrad.helpers import to_function_name
 from tinygrad.codegen.uops import UOpGraph
 from tinygrad.shape.symbolic import sym_infer, sint, Variable
+from tinygrad.dtype import DType
 
+@dataclass(frozen=True)
+class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
+  dims: Tuple[int,int,int] # N, M, K
+  dtype_in: DType # dtype for A and B
+  dtype_out: DType # dtype for C and D
+  threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
+  thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
+  thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
+  def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
+  def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)
+
 @dataclass(frozen=True)
 class Program:
@@ -40,10 +52,10 @@ class Renderer:
   supports_float4: bool = True
   has_local: bool = True
   has_shared: bool = True
-  has_tensor_cores: bool = False
   # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
   global_max: Optional[List[int]] = None
   local_max: Optional[List[int]] = None
   shared_max: int = 32768
+  tensor_cores: List[TensorCore] = []
 
   def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")
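With tensor_cores on the base Renderer, a backend advertises WMMA support simply by populating the list. A hypothetical sketch; MyRenderer and its device name are invented for illustration, and the dims/threads/alias values are copied from the METAL entry elsewhere in this diff:

from tinygrad.renderer import Renderer, TensorCore
from tinygrad.dtype import dtypes

class MyRenderer(Renderer):
  device = "MYDEV"  # hypothetical backend name, not a real tinygrad device
  tensor_cores = [TensorCore(dims=(8,8,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
    threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]],
    thread_local_aliases=[[[0],[2],[0],[4],[-1,1,3],[0]], [[1],[0],[3],[0],[2,4],[-1]], [[1],[2],[3],[4],[0],[-1]]])]

print([str(tc) for tc in MyRenderer.tensor_cores])  # ['WMMA_8_8_8_half_float']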
@@ -5,7 +5,7 @@ from tinygrad.codegen.linearizer import UOps, UOp
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
 from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
 from tinygrad.codegen.uops import UOpGraph, PatternMatcher
-from tinygrad.renderer import Renderer
+from tinygrad.renderer import Renderer, TensorCore
 
 def render_val(x, dtype):
   if dtypes.is_float(dtype):
@@ -48,11 +48,11 @@ def optimize_gated_loads(uops: UOpGraph):
 class PTXRenderer(Renderer):
   device = "CUDA"
   suffix = "PTX"
-  global_max=[65535, 65535, 2147483647]
-  local_max=[64, 1024, 1024]
-  shared_max=49152
-  has_tensor_cores = False
-  def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
+  global_max = [65535, 65535, 2147483647]
+  local_max = [64, 1024, 1024]
+  shared_max = 49152
+  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
+  def __init__(self, arch:str): self.tensor_cores = PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
 
   # language options
   kernel_prefix = """.version VERSION
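The __init__ gate keeps the class attribute as the single definition and empties it per instance on older parts: int(arch[3:]) strips the "sm_" prefix, so only sm_80 and newer keep the WMMA path. A small sketch, assuming the PTX renderer's module path:

from tinygrad.renderer.assembly import PTXRenderer  # assumed module path

assert PTXRenderer("sm_89").tensor_cores      # 89 >= 80: class-level list kept
assert not PTXRenderer("sm_75").tensor_cores  # 75 < 80: emptied per instance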
@@ -6,7 +6,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
 from tinygrad.helpers import strip_parens, getenv, prod
 from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
 from tinygrad.codegen.uops import UOpGraph
-from tinygrad.renderer import Renderer
+from tinygrad.renderer import Renderer, TensorCore
 
 class CStyleLanguage(Renderer):
   kernel_prefix: str = ""
@@ -206,8 +206,9 @@ class OpenCLRenderer(CStyleLanguage):
 
 class MetalRenderer(CStyleLanguage):
   device = "METAL"
-  has_tensor_cores=os.uname().machine == "arm64"
-  shared_max=32768
+  shared_max = 32768
+  tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[0],[2],[0],[4],[-1, 1, 3],[0]], [[1],[0],[3],[0],[2, 4],[-1]], [[1],[2],[3],[4],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
+  def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else []
 
   # language options
   kernel_prefix = "kernel "
@@ -251,11 +252,11 @@ def _make_cuda_dtype(base_type, name, cnt):
 
 class CUDARenderer(CStyleLanguage):
   device = "CUDA"
-  global_max=[65535, 65535, 2147483647]
-  local_max=[64, 1024, 1024]
-  shared_max=49152
-  has_tensor_cores = False
-  def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
+  global_max = [65535, 65535, 2147483647]
+  local_max = [64, 1024, 1024]
+  shared_max = 49152
+  tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
+  def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
 
   # language options
   kernel_prefix = "extern \"C\" __global__ "
@@ -313,8 +314,8 @@ def _make_hip_dtype(base_type, name, cnt):
 
 class HIPRenderer(CStyleLanguage):
   device = "HSA"
-  has_tensor_cores = True
   shared_max = 65536
+  tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
 
   # language options
   kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
@@ -378,3 +379,6 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
     # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
     # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
     return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
+
+class NVRenderer(CUDARenderer): device = "NV"
+class AMDRenderer(HIPRenderer): device = "AMD"
@@ -3,7 +3,7 @@ from typing import Tuple, List, Any, cast
 import os, fcntl, ctypes, ctypes.util, functools, re, pathlib, mmap, struct, errno, subprocess, time
 from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator
 from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG
-from tinygrad.renderer.cstyle import HIPRenderer
+from tinygrad.renderer.cstyle import AMDRenderer
 from tinygrad.runtime.driver.hip_comgr import compile_hip
 from tinygrad.runtime.ops_hsa import HSACompiler
 import tinygrad.runtime.autogen.kfd as kfd
@@ -549,7 +549,7 @@ class AMDDevice(Compiled):
     self.pm4_doorbell = to_mv(self.doorbells + self.pm4_queue.doorbell_offset - self.doorbells_base, 8).cast("Q")
 
     from tinygrad.runtime.graph.hcq import HCQGraph
-    super().__init__(device, AMDAllocator(self), HIPRenderer(), HSACompiler(self.arch),
+    super().__init__(device, AMDAllocator(self), AMDRenderer(), HSACompiler(self.arch),
                      functools.partial(AMDProgram, self),
                      functools.partial(HCQGraph, AMDDevice, HWPM4Queue, HWCopyQueue))
@@ -3,7 +3,7 @@ import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashli
 from typing import Tuple, List, Any, cast
 from tinygrad.device import Compiled, Compiler, LRUAllocator, BufferOptions
 from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, to_char_p_p, DEBUG, prod
-from tinygrad.renderer.cstyle import CUDARenderer
+from tinygrad.renderer.cstyle import NVRenderer
 from tinygrad.runtime.ops_cuda import check as cuda_check, _get_bytes, CUDACompiler
 import tinygrad.runtime.autogen.cuda as cuda
 import tinygrad.runtime.autogen.nv_gpu as nv_gpu
@@ -519,7 +519,7 @@ class NVDevice(Compiled):
     self.arch: str = "sm_89" if not MOCKGPU else "sm_35" # TODO: fix
 
     from tinygrad.runtime.graph.hcq import HCQGraph
-    super().__init__(device, NVAllocator(self), CUDARenderer(self.arch), CUDACompiler(self.arch) if MOCKGPU else NVCompiler(self.arch),
+    super().__init__(device, NVAllocator(self), NVRenderer(self.arch), CUDACompiler(self.arch) if MOCKGPU else NVCompiler(self.arch),
                      functools.partial(NVProgram, self), functools.partial(HCQGraph, NVDevice, HWComputeQueue, HWCopyQueue))
 
     self._cmdq_setup_compute_gpfifo()
@@ -9,6 +9,7 @@ from tinygrad.device import Compiled, Compiler, Allocator
 from tinygrad.codegen.uops import UOpGraph, UOps
 from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
 from tinygrad.renderer import Renderer
+from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, HIPRenderer
 
 def _load(m, i):
   if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
@@ -181,9 +182,9 @@ class PythonProgram:
 class PythonRenderer(Renderer):
   device = "PYTHON"
   def __init__(self):
-    if getenv("EMULATE_METAL"): self.device, self.has_tensor_cores = "METAL", True
-    if getenv("EMULATE_HSA"): self.device, self.has_tensor_cores = "HSA", True
-    if getenv("EMULATE_CUDA"): self.device, self.has_tensor_cores = "CUDA", True
+    if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
+    if getenv("EMULATE_HSA"): self.device, self.tensor_cores = "HSA", HIPRenderer.tensor_cores
+    if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
 
   def render(self, name:str, uops:UOpGraph) -> str:
     lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
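The Python emulator now borrows the concrete renderer's class-level tensor core list instead of flipping a boolean, so emulated devices see real TensorCore descriptors. A hedged usage sketch; setting the variables via os.environ before tinygrad first reads them is assumed, since getenv caches:

import os
os.environ["PYTHON"] = "1"         # select the pure-Python device
os.environ["EMULATE_METAL"] = "1"  # emulate METAL, inheriting MetalRenderer.tensor_cores

from tinygrad.device import Device
renderer = Device["PYTHON"].renderer
print(renderer.device)                            # METAL
print([str(tc) for tc in renderer.tensor_cores])  # WMMA_8_8_8_* entries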