move tc to renderers (#4631)

* move tc to renderers

* missed import

* fix typo

* fix

* fix imports

* remove from tests

* fix 4607

* nv emulate timestamp

* time is int

* correct time
This commit is contained in:
nimlgen 2024-05-18 00:36:29 +03:00 committed by GitHub
parent d70988dddf
commit daf57af3eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 60 additions and 67 deletions

View File

@ -114,7 +114,10 @@ class GPFIFO:
val = self._state64_le(nv_gpu.NVC56F_SEM_PAYLOAD_LO)
flags = self._next_dword()
typ = (flags >> 0) & 0b111
if typ == 1: to_mv(signal, 8).cast('Q')[0] = val
timestamp = (flags & (1 << 25)) == (1 << 25)
if typ == 1:
to_mv(signal, 8).cast('Q')[0] = val
if timestamp: to_mv(signal + 8, 8).cast('Q')[0] = int(time.perf_counter() * 1e9)
elif typ == 3:
mval = to_mv(signal, 8).cast('Q')[0]
return SchedResult.CONT if mval >= val else SchedResult.YIELD

View File

@ -3,7 +3,7 @@ import numpy as np
import unittest
from dataclasses import replace
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, tensor_cores
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
from tinygrad.codegen.linearizer import Linearizer, UOp, UOps, expand_node, expand_idxs
from tinygrad.device import Device, Buffer
from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps, ReduceOps, UnaryOps
@ -261,16 +261,16 @@ class TestLinearizer(unittest.TestCase):
helper_arg_acc_dtype(d.conv2d(w, acc_dtype=acc_dtype), expected_dtype)
def test_tensor_cores(self):
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.tensor_cores:
self.skipTest("device doesn't have tensor cores")
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
helper_tc_allclose(tc.dims[0], tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
def test_tensor_cores_padded(self):
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.tensor_cores:
self.skipTest("device doesn't have tensor cores")
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
pad = 1
@ -294,9 +294,9 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI is really slow here")
def test_tensor_cores_multi_reduce(self):
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.tensor_cores:
self.skipTest("device doesn't have tensor cores")
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if getenv("EMULATE_CUDA") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
# this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
golden_result = None
@ -786,10 +786,8 @@ class TestKernelOpts(unittest.TestCase):
])
def test_invalid_tensor_core_extra_opts(self):
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.tensor_cores:
self.skipTest("device doesn't have tensor cores")
if Device.DEFAULT not in tensor_cores:
self.skipTest("No tensor cores for device")
N = 128
Tensor.manual_seed(1552)
@ -807,10 +805,8 @@ class TestKernelOpts(unittest.TestCase):
assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners
def test_buf_index_not_found_tensor_core(self):
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.tensor_cores:
self.skipTest("device doesn't have tensor cores")
if Device.DEFAULT not in tensor_cores:
self.skipTest("No tensor cores for device")
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
k = Linearizer(ast, opts=Device[Device.DEFAULT].renderer)
@ -818,14 +814,12 @@ class TestKernelOpts(unittest.TestCase):
k.apply_opt(Opt(OptOps.TC, 0, 1))
def test_tensor_core_opts(self):
if not Device[Device.DEFAULT].renderer.has_tensor_cores:
if not Device[Device.DEFAULT].renderer.tensor_cores:
self.skipTest("device doesn't have tensor cores")
if Device.DEFAULT not in tensor_cores:
self.skipTest("No tensor cores for device")
N = 128
Tensor.manual_seed(1552)
for tc in tensor_cores[Device[Device.DEFAULT].renderer.device]:
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
# bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
if tc.dtype_in == dtypes.bfloat16: continue
a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)

View File

@ -14,7 +14,6 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
class TestTimeLinearizer(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
def test_reasonable_time(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
@ -23,7 +22,6 @@ class TestTimeLinearizer(unittest.TestCase):
tm = time_linearizer(Linearizer(*si.ast), rawbufs, allow_test_size=False, cnt=10)
assert tm > 0 and tm != float('inf')
@unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
def test_bufs_from_lin(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
rawbufs = bufs_from_lin(lin:=Linearizer(*si.ast))
@ -33,7 +31,6 @@ class TestTimeLinearizer(unittest.TestCase):
assert all(r.size > 0 for r in rawbufs)
class TestBEAM(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT in {"AMD", "NV"}, "Tries to open HSA/CUDA. #4607")
def test_dynamic_beam(self):
# TODO: make this infra globally usable
class Capture:
@ -69,7 +66,6 @@ class TestBEAM(unittest.TestCase):
if Opt(OptOps.GROUPTOP, 0, 0) in actions:
assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, amt=3)]) == 0, "did not de-dup GROUPTOP"
@unittest.skipIf(Device.DEFAULT in {"NV"}, "Tries to open CUDA. #4607")
def test_filter_global_buffer(self):
# taken from https://github.com/tinygrad/tinygrad/issues/4612
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.MAX, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=6, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, 
mask=None, contiguous=False)))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.4285714285714286, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501

View File

@ -3,7 +3,7 @@ import math, itertools
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, UNSAFE_PAD_OPS
from tinygrad.device import Device
from tinygrad.renderer import Renderer
from tinygrad.renderer import Renderer, TensorCore
from tinygrad.dtype import dtypes, ImageDType, DType
from tinygrad.helpers import colored, ansilen, dedup, flatten, getenv, prod, DEBUG, round_up, all_int, get_contraction
from tinygrad.shape.shapetracker import ShapeTracker
@ -34,17 +34,6 @@ class Opt:
if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
return self.axis
@dataclass(frozen=True)
class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
dims: Tuple[int,int,int] # N, M, K
dtype_in: DType # dtype for A and B
dtype_out: DType # dtype for C and D
threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
thread_local_sizes: List[List[int]] # in each thread, the number of elements stored in registers for each TC dim
def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
def num_upcasts(self): return len(self.thread_local_aliases[0]) - len(self.threads)
class TensorCoreOptions(NamedTuple):
bufs: Tuple[int, int] # the local aliased buffers for A and B
axes: List[int] # the location of the original N and M axes if still in the shape
@ -54,12 +43,6 @@ class TensorCoreOptions(NamedTuple):
if removed_axis < self.axes[tc_dim]: self.axes[tc_dim] -= 1
elif removed_axis == self.axes[tc_dim]: self.axes_exist[tc_dim] = False
tensor_cores: Dict[str, List[TensorCore]] = {
"METAL": [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[0],[2],[0],[4],[-1, 1, 3],[0]], [[1],[0],[3],[0],[2, 4],[-1]], [[1],[2],[3],[4],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
"HSA": [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
"CUDA": [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)] if getenv("PTX") else [(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])], # noqa: E501
}
class LocalBuffer(NamedTuple):
name: str
size: int
@ -340,8 +323,8 @@ class Kernel:
# ******************** high level optimizers ********************
def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op is ReduceOps.SUM and self.opts.device in tensor_cores:
for tc in tensor_cores[self.opts.device]:
if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op is ReduceOps.SUM:
for tc in self.opts.tensor_cores:
has_cast = tc.dtype_in != tc.dtype_out
if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg == tc.dtype_out): continue
@ -403,7 +386,7 @@ class Kernel:
1: allows kernels with multiple reduce axes and also multiplication of UnaryOps.CAST'd buffers
2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by padding those axes as needed
"""
if not self.opts.has_tensor_cores and use_tensor_cores != 2: return False
if not self.opts.tensor_cores and use_tensor_cores != 2: return False
try: # check TC first and apply hand-coded opts if successful
self.apply_opt(Opt(OptOps.TC, axis, tc_opt))
@ -435,7 +418,7 @@ class Kernel:
if opt.op is OptOps.TC:
check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
check((use_tensor_cores:=getenv("TC", 1)) == 2 or self.opts.has_tensor_cores, "must have tensor cores or TC=2")
check((use_tensor_cores:=getenv("TC", 1)) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2")
check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available")
self.applied_opts.append(opt)
return

View File

@ -4,6 +4,18 @@ from dataclasses import dataclass
from tinygrad.helpers import to_function_name
from tinygrad.codegen.uops import UOpGraph
from tinygrad.shape.symbolic import sym_infer, sint, Variable
from tinygrad.dtype import DType
@dataclass(frozen=True)
class TensorCore:
  # One matrix-multiply-accumulate configuration: D = A * B + C,
  # where A is (M x K), B is (K x N) and both C and D are (M x N).
  dims: Tuple[int,int,int]  # sizes listed in N, M, K order
  dtype_in: DType  # element dtype of the A and B operands
  dtype_out: DType  # element dtype of the C and D operands
  threads: List[Tuple[int,int]]  # (TC dim, amt) pairs that build up the warp thread structure
  # per TC dim: a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] giving the alias
  # (-1 is upcast, 1-n is warp threads)
  thread_local_aliases: List[List[List[int]]]
  thread_local_sizes: List[List[int]]  # per TC dim: number of elements each thread keeps in registers

  def __str__(self):
    # "WMMA", then each dim, then the input and output dtype names, all joined by underscores
    parts = ["WMMA", *(str(d) for d in self.dims), self.dtype_in.name, self.dtype_out.name]
    return "_".join(parts)

  def num_upcasts(self):
    # alias entries for the first TC dim that are left over after the thread entries
    alias_slots = len(self.thread_local_aliases[0])
    return alias_slots - len(self.threads)
@dataclass(frozen=True)
class Program:
@ -40,10 +52,10 @@ class Renderer:
supports_float4: bool = True
has_local: bool = True
has_shared: bool = True
has_tensor_cores: bool = False
# NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
global_max: Optional[List[int]] = None
local_max: Optional[List[int]] = None
shared_max: int = 32768
tensor_cores: List[TensorCore] = []
def render(self, name:str, uops:UOpGraph) -> str: raise NotImplementedError("needs a renderer")

View File

@ -5,7 +5,7 @@ from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOpGraph, PatternMatcher
from tinygrad.renderer import Renderer
from tinygrad.renderer import Renderer, TensorCore
def render_val(x, dtype):
if dtypes.is_float(dtype):
@ -48,11 +48,11 @@ def optimize_gated_loads(uops: UOpGraph):
class PTXRenderer(Renderer):
device = "CUDA"
suffix = "PTX"
global_max=[65535, 65535, 2147483647]
local_max=[64, 1024, 1024]
shared_max=49152
has_tensor_cores = False
def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
global_max = [65535, 65535, 2147483647]
local_max = [64, 1024, 1024]
shared_max = 49152
tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)])] # noqa: E501
def __init__(self, arch:str): self.tensor_cores = PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else []
# language options
kernel_prefix = """.version VERSION

View File

@ -6,7 +6,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import strip_parens, getenv, prod
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.codegen.uops import UOpGraph
from tinygrad.renderer import Renderer
from tinygrad.renderer import Renderer, TensorCore
class CStyleLanguage(Renderer):
kernel_prefix: str = ""
@ -206,8 +206,9 @@ class OpenCLRenderer(CStyleLanguage):
class MetalRenderer(CStyleLanguage):
device = "METAL"
has_tensor_cores=os.uname().machine == "arm64"
shared_max=32768
shared_max = 32768
tensor_cores = [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[0],[2],[0],[4],[-1, 1, 3],[0]], [[1],[0],[3],[0],[2, 4],[-1]], [[1],[2],[3],[4],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else []
# language options
kernel_prefix = "kernel "
@ -251,11 +252,11 @@ def _make_cuda_dtype(base_type, name, cnt):
class CUDARenderer(CStyleLanguage):
device = "CUDA"
global_max=[65535, 65535, 2147483647]
local_max=[64, 1024, 1024]
shared_max=49152
has_tensor_cores = False
def __init__(self, arch:str): self.has_tensor_cores=int(arch[3:]) >= 80
global_max = [65535, 65535, 2147483647]
local_max = [64, 1024, 1024]
shared_max = 49152
tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])] # noqa: E501
def __init__(self, arch:str): self.tensor_cores = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else []
# language options
kernel_prefix = "extern \"C\" __global__ "
@ -313,8 +314,8 @@ def _make_hip_dtype(base_type, name, cnt):
class HIPRenderer(CStyleLanguage):
device = "HSA"
has_tensor_cores = True
shared_max = 65536
tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]] # noqa: E501
# language options
kernel_prefix = """extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
@ -378,3 +379,6 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
class NVRenderer(CUDARenderer):
  # identical to CUDARenderer except for the device name
  device = "NV"
class AMDRenderer(HIPRenderer):
  # identical to HIPRenderer except for the device name
  device = "AMD"

View File

@ -3,7 +3,7 @@ from typing import Tuple, List, Any, cast
import os, fcntl, ctypes, ctypes.util, functools, re, pathlib, mmap, struct, errno, subprocess, time
from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.renderer.cstyle import AMDRenderer
from tinygrad.runtime.driver.hip_comgr import compile_hip
from tinygrad.runtime.ops_hsa import HSACompiler
import tinygrad.runtime.autogen.kfd as kfd
@ -549,7 +549,7 @@ class AMDDevice(Compiled):
self.pm4_doorbell = to_mv(self.doorbells + self.pm4_queue.doorbell_offset - self.doorbells_base, 8).cast("Q")
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, AMDAllocator(self), HIPRenderer(), HSACompiler(self.arch),
super().__init__(device, AMDAllocator(self), AMDRenderer(), HSACompiler(self.arch),
functools.partial(AMDProgram, self),
functools.partial(HCQGraph, AMDDevice, HWPM4Queue, HWCopyQueue))

View File

@ -3,7 +3,7 @@ import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashli
from typing import Tuple, List, Any, cast
from tinygrad.device import Compiled, Compiler, LRUAllocator, BufferOptions
from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, to_char_p_p, DEBUG, prod
from tinygrad.renderer.cstyle import CUDARenderer
from tinygrad.renderer.cstyle import NVRenderer
from tinygrad.runtime.ops_cuda import check as cuda_check, _get_bytes, CUDACompiler
import tinygrad.runtime.autogen.cuda as cuda
import tinygrad.runtime.autogen.nv_gpu as nv_gpu
@ -519,7 +519,7 @@ class NVDevice(Compiled):
self.arch: str = "sm_89" if not MOCKGPU else "sm_35" # TODO: fix
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, NVAllocator(self), CUDARenderer(self.arch), CUDACompiler(self.arch) if MOCKGPU else NVCompiler(self.arch),
super().__init__(device, NVAllocator(self), NVRenderer(self.arch), CUDACompiler(self.arch) if MOCKGPU else NVCompiler(self.arch),
functools.partial(NVProgram, self), functools.partial(HCQGraph, NVDevice, HWComputeQueue, HWCopyQueue))
self._cmdq_setup_compute_gpfifo()

View File

@ -9,6 +9,7 @@ from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.uops import UOpGraph, UOps
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, HIPRenderer
def _load(m, i):
if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
@ -181,9 +182,9 @@ class PythonProgram:
class PythonRenderer(Renderer):
device = "PYTHON"
def __init__(self):
if getenv("EMULATE_METAL"): self.device, self.has_tensor_cores = "METAL", True
if getenv("EMULATE_HSA"): self.device, self.has_tensor_cores = "HSA", True
if getenv("EMULATE_CUDA"): self.device, self.has_tensor_cores = "CUDA", True
if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
if getenv("EMULATE_HSA"): self.device, self.tensor_cores = "HSA", HIPRenderer.tensor_cores
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
def render(self, name:str, uops:UOpGraph) -> str:
lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]