Update251203 (#233)

This commit is contained in:
carrot
2025-12-03 10:28:27 +09:00
committed by GitHub
parent d6899edd97
commit c5ebcbcb97
347 changed files with 8678 additions and 13489 deletions

View File

@@ -2,27 +2,39 @@
# a python uops emulator
# works to test the tensor cores, and all the uops in general
# this is the (living) definition of uops
from typing import Any, TYPE_CHECKING
from typing import Any, TYPE_CHECKING, cast
import pickle, base64, itertools, time, struct, sys
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.helpers import all_same, getenv, flatten, get_single_element
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate, float_to_bf16, float_to_fp8, fp8_to_float
from tinygrad.helpers import all_same, getenv, flatten, get_single_element, EMULATE
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import exec_alu, Ops, UOp, GroupOp
from tinygrad.uop.ops import exec_alu, python_alu, Ops, UOp, GroupOp
from tinygrad.renderer import Renderer
def _load(m, i):
def storage_fmt_for_dtype(dtype: DType):
  """Return the format character used to store values of `dtype` in emulator buffers.

  bfloat16 maps to 'H' (unsigned 16-bit) and every dtype in `dtypes.fp8s` maps to
  'B' (unsigned 8-bit); any other dtype uses its own `dtype.fmt`, which may be None.
  """
  if dtype == dtypes.bfloat16: return 'H'
  if dtype in dtypes.fp8s: return 'B'
  return dtype.fmt
def to_storage_scalar(x, dtype: DType):
  """Convert the value `x` into the representation written to a buffer of `dtype`.

  bfloat16 values become a 16-bit integer pattern, fp8 values are converted with
  `float_to_fp8`, and every other dtype is passed through unchanged.
  """
  if dtype == dtypes.bfloat16:
    # pack the float_to_bf16 result as a float32 and reinterpret those bytes as a uint32
    f32_bits, = struct.unpack('I', struct.pack('f', float_to_bf16(x)))
    # only the upper 16 bits of the float32 pattern are stored
    return (f32_bits >> 16) & 0xFFFF
  if dtype in dtypes.fp8s: return float_to_fp8(float(x), dtype)
  return x
def from_storage_scalar(x, dtype: DType):
  """Convert the stored value `x` read from a buffer of `dtype` back into a Python value.

  bfloat16 storage is a 16-bit integer pattern that is widened to a float32, fp8
  storage is converted with `fp8_to_float`, and every other dtype is passed through unchanged.
  """
  if dtype == dtypes.bfloat16:
    # place the 16 stored bits in the upper half of a 32-bit word, lower half zero
    widened = (x & 0xFFFF) << 16
    # reinterpret that uint32 pattern as a float32
    return struct.unpack('f', struct.pack('I', widened))[0]
  if dtype in dtypes.fp8s: return fp8_to_float(int(x), dtype)
  return x
def _load(m, i, dtype: DType):
if i is None: return 0.0
if i < 0 or i >= len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
return m[i]
return from_storage_scalar(m[i], dtype)
def load(inp, j=0):
if len(inp) == 2: return [_load(m, x+j if x is not None else None) if gate else default for (m,x,gate),default in zip(*inp)]
return [_load(m, x+j if x is not None else None) for m,x,_ in inp[0]]
def load(inp, j, dtype: DType):
if len(inp) == 2: return [_load(m, x+j if x is not None else None, dtype) if gate else default for (m,x,gate),default in zip(*inp)]
return [_load(m, x+j if x is not None else None, dtype) for m,x,_ in inp[0]]
def _store(m, i, v):
def _store(m, i, v, dtype: DType):
if i < 0 or i >= len(m): raise IndexError(f"store out of bounds, size is {len(m)}, access is {i}, value is {v}")
m[i] = v
m[i] = to_storage_scalar(v, dtype)
class PythonProgram:
def __init__(self, name:str, lib:bytes):
@@ -57,24 +69,25 @@ class PythonProgram:
if uop is Ops.STORE:
for j,val in enumerate(inp[1] if dtp[1].count > 1 else [inp[1]]):
for (m,o,g),v in zip(inp[0], val):
if g: _store(m, o+j, v)
if g: _store(m, o+j, v, dtp[1].scalar())
i += 1
continue
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
assert isinstance(dtype, PtrDType), dtype
if dtype.fmt is None: raise RuntimeError(f"{dtype=} is not supported")
if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
storage_fmt = storage_fmt_for_dtype(dtype.base.scalar())
if storage_fmt is None: raise RuntimeError(f"{dtype=} is not supported")
if TYPE_CHECKING or sys.version_info < (3, 12): assert storage_fmt != "e"
if uop is Ops.DEFINE_REG:
# REGs are per thread
ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(dtype.fmt) for _ in range(warp_size)]
ul[i] = [memoryview(bytearray(dtype.size*dtype.itemsize)).cast(storage_fmt) for _ in range(warp_size)]
else:
buf = memoryview(bytearray(dtype.size*dtype.itemsize)) if uop is not Ops.DEFINE_GLOBAL else pbufs.pop(0)
ul[i] = [buf.cast(dtype.fmt)] * warp_size
ul[i] = [buf.cast(storage_fmt)] * warp_size
elif uop is Ops.DEFINE_VAR:
ul[i] = [pvals.pop(0)] * warp_size
elif uop is Ops.SPECIAL:
if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
if arg[0] == 'g': ul[i] = [idxs[2-int(arg[-1])]] * warp_size
elif arg[0] == 'l': ul[i] = [x[2-int(arg[-1])] for x in warp]
elif uop is Ops.CONST: ul[i] = [arg] * warp_size
elif uop is Ops.INDEX:
ret:list = []
@@ -98,16 +111,17 @@ class PythonProgram:
continue
elif uop is Ops.VECTORIZE: ul[i] = inp
elif uop is Ops.BITCAST:
assert dtp[0].fmt and dtype.fmt
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
packed = struct.pack(str(warp_size) + storage_fmt_for_dtype(dtp[0].scalar()), *[to_storage_scalar(x, dtp[0].scalar()) for x in inp[0]])
ul[i] = list(struct.unpack(str(warp_size) + storage_fmt_for_dtype(dtype.scalar()), packed))
ul[i] = [from_storage_scalar(x, dtype.scalar()) for x in ul[i]]
elif uop is Ops.CAST:
ul[i] = [truncate.get(dtype, lambda dt: dt)(dtypes.as_const(x, dtype)) for x in inp[0]]
elif uop is Ops.LOAD:
if dtype.count > 1:
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j, dtype.scalar()) \
for j in range(dtype.count)]
else:
ul[i] = load(inp)
ul[i] = load(inp, 0, dtype)
elif uop is Ops.GEP: ul[i] = inp[0][get_single_element(arg)]
elif uop is Ops.WMMA:
# here are the models for the WMMA instruction on the different hardware
@@ -188,7 +202,7 @@ class PythonProgram:
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop in GroupOp.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}"
assert all_same([dtype] + dtp) or uop in {Ops.CMPNE, Ops.CMPLT, Ops.WHERE}, f"dtype mismatch on {uop}"
assert all_same([dtype] + dtp) or uop in {*GroupOp.Comparison, Ops.WHERE}, f"dtype mismatch on {uop}"
ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)]
assert i in ul, (uop, dtype, idp, arg)
i += 1
@@ -196,18 +210,23 @@ class PythonProgram:
class PythonRenderer(Renderer):
device = "PYTHON"
code_for_op = python_alu
def __init__(self):
if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", tc.metal
if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", tc.amd_rdna3
if getenv("EMULATE_AMD_MFMA"): self.device, self.tensor_cores = "AMD", tc.amd_cdna
if getenv("EMULATE_AMD_RDNA4"): self.device, self.tensor_cores = "AMD", tc.amd_rdna4
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel
if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", tc.amx
match cast(str, EMULATE.value):
case "METAL": self.device, self.tensor_cores = "METAL", tc.metal
case "AMD": self.device, self.tensor_cores = "AMD", tc.amd_rdna3
case "AMD_MFMA": self.device, self.tensor_cores = "AMD", tc.amd_cdna
case "AMD_RDNA4": self.device, self.tensor_cores = "AMD", tc.amd_rdna4
case "CUDA": self.device, self.tensor_cores = "CUDA", tc.cuda_sm80
case "CUDA_SM75": self.device, self.tensor_cores = "CUDA", tc.cuda_sm75
case "INTEL": self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", tc.intel
case "AMX": self.device, self.tensor_cores = "CPU", tc.amx
case "": pass
case _: raise RuntimeError(f"can't EMULATE device: {EMULATE.value}")
def render(self, uops:list[UOp]) -> str:
lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops]
# the value of SPECIAL comes from local/global_size, not from its source
lops = [(u.op, u.dtype, [uops.index(v) for v in u.src if u.op is not Ops.SPECIAL], u.arg) for u in uops]
return base64.b64encode(pickle.dumps(lops)).decode()
class PythonCompiler(Compiler):
@@ -219,4 +238,4 @@ class PythonAllocator(Allocator['PythonDevice']):
def _copyout(self, dest:memoryview, src): dest[:] = src
class PythonDevice(Compiled):
def __init__(self, device:str): super().__init__(device, PythonAllocator(self), PythonRenderer(), PythonCompiler(), PythonProgram)
def __init__(self, device:str): super().__init__(device, PythonAllocator(self), [(PythonRenderer, PythonCompiler)], PythonProgram)