JIT support in Interpreted (#2314)

* factor that out

* jit is supported everywhere

* fix some tests

* there's no jit supported device, the jit is everywhere

* fix test uops
George Hotz 2023-11-15 11:13:38 -08:00 committed by GitHub
parent 9a20bc08d6
commit 70a65c201e
16 changed files with 181 additions and 158 deletions
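With JIT_SUPPORTED_DEVICE gone, TinyJit no longer falls back to eager execution on interpreted backends such as CPU or TORCH; the decorator behaves the same on every device. A minimal usage sketch (not part of the diff; the function name step is illustrative), assuming the tinygrad API at this revision and an interpreted backend selected with e.g. CPU=1 or TORCH=1:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x.T).relu().realize()  # jitted functions realize their outputs, as the tests in this diff do

for _ in range(5):
  out = step(Tensor.randn(4, 4))  # early calls trace the kernel schedule, later calls replay it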

View File

@@ -1,6 +1,12 @@
repos:
- repo: local
hooks:
- id: mypy
name: mypy
entry: mypy tinygrad/ extra/helpers.py
language: system
always_run: true
pass_filenames: false
- id: ruff
name: ruff
entry: ruff .
@@ -19,12 +25,6 @@ repos:
language: system
always_run: true
pass_filenames: false
- id: mypy
name: mypy
entry: mypy tinygrad/ extra/helpers.py
language: system
always_run: true
pass_filenames: false
- id: tests
name: subset of TORCH tests
entry: env PYTHONPATH="." TORCH=1 python3 -m pytest -n=4 test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py

View File

@@ -15,11 +15,11 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad.helpers import GlobalCounters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
from tinygrad.shape.symbolic import Variable
MAX_CONTEXT = 1024
JIT = getenv("JIT", 0 if CI else int(Device.DEFAULT in JIT_SUPPORTED_DEVICE))
JIT = getenv("JIT", 0 if CI else 1)
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:

View File

@@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
from tinygrad.helpers import dtypes, CI
from tinygrad.ops import Device
from test.helpers import derandomize_model
@@ -14,7 +14,6 @@ def helper_test_jitted_correctness(gen, train, train_jit):
for _ in range(5): jit = train_jit(*gen()).numpy()
np.testing.assert_allclose(nojit, jit, rtol=1e-3, atol=1e-5)
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
class TestJittedModels(unittest.TestCase):
def test_jitted_tiny_llama(self):
old_type = Tensor.default_type

View File

@@ -3,7 +3,7 @@ import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
from tinygrad.ops import Device, GlobalCounters
from tinygrad.helpers import CI, dtypes, getenv, prod
from test.helpers import derandomize_model
@@ -47,7 +47,7 @@ class TestRealWorld(unittest.TestCase):
def test(t, t2): return model(t, 801, t2).realize()
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 967)
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["LLVM"], "needs JIT, too long on CI LLVM")
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CPU"] or not CI), "needs JIT, too long on CI LLVM")
def test_llama(self):
Tensor.default_type = dtypes.float16
@@ -59,7 +59,7 @@ class TestRealWorld(unittest.TestCase):
# NOTE: only test one pass, not testing the dynamic shape autoregressive part
helper_test("test_llama", lambda: (Tensor([[1,]]),), test, 0.22 if CI else 13.5, 137 if CI else 521, all_jitted=True)
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM"] or not CI), "needs JIT, too long on CI LLVM")
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CPU"] or not CI), "needs JIT, too long on CI LLVM")
def test_gpt2(self):
Tensor.default_type = dtypes.float16
@@ -70,7 +70,7 @@ class TestRealWorld(unittest.TestCase):
def test(t): return model(t, 0).realize()
helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 140 if CI else 396, all_jitted=True)
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM", "CLANG"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CLANG", "CPU"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
def test_train_cifar(self):
# TODO: with default device
#old_default = Device.DEFAULT

View File

@@ -9,7 +9,7 @@ from tinygrad.helpers import prod, dtypes
# *** first, we implement the atan2 op at the lowest level ***
# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
from tinygrad.lazy import LazyBuffer, create_lazybuffer
from tinygrad.ops import ASTRunner, Device
from tinygrad.ops import CompiledASTRunner, Device
from tinygrad.shape.shapetracker import ShapeTracker
import pytest
@@ -20,7 +20,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
assert a.device == "GPU" and b.device == "GPU", "gpu function requires GPUBuffers"
assert a.dtype == b.dtype and a.dtype == dtypes.float32, "gpu function only supports float32"
ret.realized = Device[ret.device].buffer(prod(ret.shape), ret.dtype)
ASTRunner("atan2_gpu", """
CompiledASTRunner(None, "atan2_gpu", """
__kernel void atan2_gpu(global float *c, global float *a, global float *b) {
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
@@ -89,6 +89,7 @@ class TestCustomFunction(unittest.TestCase):
np.testing.assert_allclose(a.grad.numpy(), ta.grad.numpy(), atol=1e-5)
np.testing.assert_allclose(b.grad.numpy(), tb.grad.numpy(), atol=1e-5)
@unittest.skipIf(Device.DEFAULT in ["CPU"], "atan2_cpu not jittable")
def test_atan2_jit(self):
# custom ops even work in the JIT!
from tinygrad.jit import TinyJit

View File

@@ -2,13 +2,13 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
import pytest
pytestmark = pytest.mark.webgpu
# NOTE: METAL fails, might be platform and optimization options dependent.
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
@unittest.skipUnless(Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
class TestJit(unittest.TestCase):
def test_simple_jit(self):
@TinyJit

View File

@@ -6,7 +6,6 @@ from tinygrad.tensor import Tensor, Device
import numpy as np
@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA", "LLVM"], f"{Device.DEFAULT} is not supported")
class TestSymbolicJit(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()

View File

@@ -1,5 +1,4 @@
import unittest
from tinygrad.jit import JIT_SUPPORTED_DEVICE
from tinygrad.shape.symbolic import Variable
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor, Device

View File

@@ -3,14 +3,14 @@ import unittest, math
import numpy as np
from tinygrad.helpers import dtypes, getenv, DType, PtrDType
from tinygrad.tensor import Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, CompiledASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, UOp
def _uops_to_prg(uops):
src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
return ASTRunner("test", src,
[1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None,
runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
return CompiledASTRunner(None, "test", src,
[1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None,
runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg))

View File

@@ -8,8 +8,6 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
from weakref import ref, WeakKeyDictionary
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
@@ -28,8 +26,6 @@ class TinyJit:
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT.split(":")[0] not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
# all inputs are realized
input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()])
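With the early-return guard on unsupported devices removed, __call__ always goes through the capture path. One hedged way to observe the difference (illustrative code, not from the diff; the counter exists only to watch behavior): after the warm-up calls, the wrapped Python function stops being invoked because the cached JitItems are replayed instead, even on a backend the old JIT_SUPPORTED_DEVICE list excluded.

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

calls = 0
def f(a: Tensor, b: Tensor) -> Tensor:
  global calls
  calls += 1                      # counts how often the Python body actually runs
  return (a + b).realize()

jf = TinyJit(f)
for _ in range(5):
  jf(Tensor.randn(8), Tensor.randn(8))
# previously, a device outside JIT_SUPPORTED_DEVICE ran f eagerly on every call (calls == 5);
# now f should only run for the initial trace/capture calls, and the rest replay the jit cache
print(calls)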

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import importlib, inspect, functools, pathlib
import importlib, inspect, functools, pathlib, time
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
@@ -26,6 +26,8 @@ OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOp
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.lazy import LazyBuffer
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.kernel import LinearizerOptions
@dataclass(frozen=True)
class MemBuffer:
@@ -104,48 +106,6 @@ class _Device:
return "CPU"
Device = _Device()
# **************** batch executor ****************
@dataclass(frozen=True)
class JitItem:
prg: ASTRunner
rawbufs: List[Optional[RawBuffer]]
class BatchExecutor:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
self.jit_cache: List[JitItem] = jit_cache
self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
self.op_estimate += ji.prg.op_estimate
self.mem_estimate += ji.prg.mem_estimate
for i,a in enumerate(ji.rawbufs):
if a in [v for v in input_rawbuffers.values()]:
self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
self.clear_jit_inputs()
def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), var_vals, jit=True)
self.clear_jit_inputs()
def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
# TODO: this is mostly copied from ASTRunner
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += len(self.jit_cache)
GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
if et is not None: GlobalCounters.time_sum_s += et
def clear_jit_inputs(self):
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
# **************** independent FlopCounter ****************
@dataclass
@@ -174,114 +134,159 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
return run_ast(ast)
# **************** batch executor ****************
@dataclass(frozen=True)
class JitItem:
prg: ASTRunner
rawbufs: List[Optional[RawBuffer]]
class BatchExecutor:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
self.jit_cache: List[JitItem] = jit_cache
self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
self.op_estimate += ji.prg.op_estimate
self.mem_estimate += ji.prg.mem_estimate
for i,a in enumerate(ji.rawbufs):
if a in [v for v in input_rawbuffers.values()]:
self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
self.clear_jit_inputs()
def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True)
self.clear_jit_inputs()
def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
# TODO: this is mostly copied from ASTRunner
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += len(self.jit_cache)
GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
if et is not None: GlobalCounters.time_sum_s += et
def clear_jit_inputs(self):
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
class ASTRunner:
def __init__(self, ast:Optional[LazyOp]):
if ast is None:
self.op_estimate, self.mem_estimate, self.vars = 0, 0, []
else:
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
from tinygrad.lazy import vars_from_ast
self.vars = vars_from_ast(ast)
assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
def exec(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
from tinygrad.jit import CacheCollector
et = self(rawbufs, var_vals, force_wait=force_wait)
CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
return et
def update_stats(self, name, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, lra, jit):
if var_vals is None: var_vals = {}
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '')):18s} {str(lra.get('local_size', '')):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += mem_estimate
def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
raise NotImplementedError("override this")
# **************** for Interpreted Buffers ****************
class InterpretedASTRunner(ASTRunner):
def __init__(self, ast:LazyOp, fxn:Callable):
self.fxn = fxn
super().__init__(ast)
def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> float:
st = time.perf_counter()
ret: RawBuffer = self.fxn(rawbufs[1:], var_vals)
et = time.perf_counter() - st
self.update_stats(f"<interpreted {ret.size}>", var_vals, et, len(rawbufs), {}, jit)
if rawbufs[0] is not None:
assert rawbufs[0].dtype == ret.dtype
rawbufs[0].size = ret.size # NOTE: for symbolic this can change
rawbufs[0]._buf = ret._buf
else: rawbufs[0] = ret
return et
from tinygrad.runtime.interpreted import interpret_ast
class Interpreted:
def __init__(self, buffer: Type[RawBuffer], compiler: Callable[[LazyOp], Callable]):
self.buffer, self.compiler = buffer, compiler
def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable]=None):
self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
self.synchronize = lambda: None
self.batch_executor = BatchExecutor
self.codegen = None
self.method_cache: Dict[LazyOp, Callable] = {}
self.method_cache: Dict[LazyOp, InterpretedASTRunner] = {}
def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
if ast not in self.method_cache: self.method_cache[ast] = self.compiler(ast)
output.realized = output.output_buffer # NOTE: assuming this is the right size and dtype from assign
ret: RawBuffer = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
assert output.dtype == ret.dtype, f"expected {output.dtype}, got {ret.dtype}"
if output.realized is not None: output.realized._buf = ret._buf
else: output.realized = ret
if ast not in self.method_cache: self.method_cache[ast] = InterpretedASTRunner(ast, interpret_ast(self.fxn_for_op, self.from_underlying, ast))
rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs]
self.method_cache[ast].exec(rawbufs, var_vals)
output.realized = rawbufs[0]
# **************** for Compiled Buffers ****************
class ASTRunner:
def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
class CompiledASTRunner(ASTRunner):
def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
self.vars:List[Variable] = []
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = \
name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
super().__init__(ast)
def build(self, compiler, runtime):
self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
self.clprg = runtime(self.name, self.lib)
return self
def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
from tinygrad.jit import CacheCollector
CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
return self(rawbufs, var_vals, force_wait=force_wait)
def launch_dims(self, var_vals):
global_size = ([sym_infer(sz, var_vals) for sz in self.global_size] + [1]*(3-len(self.global_size))) if self.global_size is not None else self.global_size
local_size = ([sym_infer(sz, var_vals) for sz in self.local_size] + [1]*(3-len(self.local_size))) if self.local_size is not None else self.local_size
return global_size, local_size
def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
if var_vals is None: var_vals = {}
var_vals = {k:var_vals[k] for k in self.vars} # filter the var_vals
global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
# TODO: this is copied from get_program
from tinygrad.features.search import optimize_local_size
local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
local_size = self.local_size = optimize_local_size(self.clprg, global_size, cast(List[RawBuffer], rawbufs))
global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
lra = self.runtime_args.copy()
if global_size: lra['global_size'] = global_size
if local_size and 'local_size' not in lra: lra['local_size'] = local_size
if et := self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(37-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += mem_estimate
self.update_stats(self.display_name if self.display_name is not None else self.name, var_vals, et, len(rawbufs), lra, jit)
return et
class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor):
def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
self.batch_executor = BatchExecutor if getenv("JIT") == 2 else batch_executor
self.method_cache: Dict[LazyOp, ASTRunner] = {}
self.method_cache: Dict[LazyOp, CompiledASTRunner] = {}
def to_program(self, k) -> ASTRunner:
def to_program(self, k:Linearizer) -> CompiledASTRunner:
k.linearize()
src, runtime_args = self.renderer(k.function_name, k.uops)
return ASTRunner(k.function_name, src, k.global_size, k.local_size,
op_estimate=k.info.flops, mem_estimate=k.info.mem_estimate,
display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)
def get_optimized_program(self, ast:LazyOp, rawbuffers:List[RawBuffer]) -> ASTRunner:
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.lazy import vars_from_ast
k = Linearizer(ast, self.linearizer_opts)
assert k.info.dtype == rawbuffers[0].dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {rawbuffers[0].dtype}"
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1:
lins = [(("tc" if used_tensor_cores else "hc"), k)]
# allocate a scratch buffer if output buffer is also input
test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
kb = Linearizer(ast, self.linearizer_opts)
kb.required_optimizations()
from tinygrad.features.search import beam_search, time_linearizer
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
if used_tensor_cores:
lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
lins[-1][1].hand_coded_optimizations()
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
k = timed[0][1]
else:
k.required_optimizations()
prg = self.to_program(k)
# extract real vars used in ast
prg.vars = vars_from_ast(ast)
assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
return prg
return CompiledASTRunner(k.ast, k.function_name, src, k.global_size, k.local_size,
display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)
def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
# check if we can reuse the output buffer
@@ -305,5 +310,32 @@ class Compiled:
# all the rawbuffers
rawbuffers = [output.realized] + [x.realized for x in inputs]
if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
if ast not in self.method_cache: self.method_cache[ast] = get_optimized_program(self.linearizer_opts, self.to_program, ast, rawbuffers)
self.method_cache[ast].exec(rawbuffers, var_vals)
def get_optimized_program(linearizer_opts:LinearizerOptions, to_program, ast:LazyOp, rawbuffers:List[RawBuffer]) -> CompiledASTRunner:
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, linearizer_opts)
assert k.info.dtype == rawbuffers[0].dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {rawbuffers[0].dtype}"
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1:
lins = [(("tc" if used_tensor_cores else "hc"), k)]
# allocate a scratch buffer if output buffer is also input
test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
kb = Linearizer(ast, linearizer_opts)
kb.required_optimizations()
from tinygrad.features.search import beam_search, time_linearizer
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
if used_tensor_cores:
lins.append(("hc", Linearizer(ast, linearizer_opts)))
lins[-1][1].hand_coded_optimizations()
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
k = timed[0][1]
else:
k.required_optimizations()
return to_program(k)
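For backend authors, Interpreted now takes the op table (plus an optional from_underlying converter) directly instead of a pre-bound functools.partial(interpret_ast, ...), and exec_ast wraps the function produced by interpret_ast in an InterpretedASTRunner so JIT capture records interpreted kernels the same way as compiled ones. A rough sketch of the new constructor shape, modelled on the DiskBuffer and ShmBuffer definitions below; RawToyBuffer and toy_fxn_for_op are made-up names for illustration:

from typing import Callable, Dict
from tinygrad.ops import Interpreted, Op, BufferOps, UnaryOps, MovementOps
from tinygrad.runtime.lib import RawBuffer

class RawToyBuffer(RawBuffer): pass   # hypothetical buffer type, only to show the signature

# each supported Op maps to a plain Python callable over the underlying buffers
toy_fxn_for_op: Dict[Op, Callable] = {
  BufferOps.MEM: lambda x: x,
  UnaryOps.NOOP: lambda x: x,
  MovementOps.AS_STRIDED: lambda x, _: x,
}

# old style: Interpreted(RawToyBuffer, functools.partial(interpret_ast, toy_fxn_for_op, None))
# new style: pass the table itself; Interpreted calls interpret_ast inside exec_ast
ToyBuffer = Interpreted(RawToyBuffer, toy_fxn_for_op)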

View File

@@ -1,10 +1,9 @@
import numpy as np
import operator, functools
import operator
from typing import Callable, Dict, Tuple, Optional
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
from tinygrad.runtime.lib import RawBuffer
from tinygrad.runtime.interpreted import interpret_ast
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
@@ -52,4 +51,4 @@ class RawNumpyBuffer(RawBuffer):
@classmethod
def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
def toCPU(self): return self._buf
CPUBuffer = Interpreted(RawNumpyBuffer, functools.partial(interpret_ast, numpy_fxn_for_op, RawNumpyBuffer.fromCPU))
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, RawNumpyBuffer.fromCPU)

View File

@@ -1,9 +1,8 @@
import os, mmap, functools
import os, mmap
from typing import Optional
from typing import Callable, Dict, Tuple
from tinygrad.helpers import prod, DType
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.runtime.interpreted import interpret_ast
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps
class RawDiskBuffer(RawBufferMapped):
@@ -39,4 +38,4 @@ class RawDiskBuffer(RawBufferMapped):
self._buf[0].readinto(buf)
disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
DiskBuffer = Interpreted(RawDiskBuffer, functools.partial(interpret_ast, disk_fxn_for_op, None))
DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op)

View File

@@ -1,10 +1,10 @@
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
import os, subprocess, pathlib, ctypes, tempfile
import Metal, Cocoa, libdispatch
from typing import List, Any, Tuple, Dict, Union, Set
from typing import List, Any, Tuple, Dict, Union, Set, cast
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup, CI
from tinygrad.ops import Compiled, BatchExecutor, JitItem
from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner
from tinygrad.renderer.metal import MetalRenderer
from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
from tinygrad.shape.symbolic import Variable, Node
@@ -98,8 +98,9 @@ class MetalBatchExecutor(BatchExecutor):
self.input_has_variable_dims: Set[int] = set()
read_resources, write_resources = [], []
for j,ji in enumerate(self.jit_cache):
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
descriptor = Metal.MTLComputePipelineDescriptor.new()
descriptor.setComputeFunction_(ji.prg.clprg.fxn)
descriptor.setComputeFunction_(prg.clprg.fxn)
descriptor.setSupportIndirectCommandBuffers_(True)
pipeline_state = unwrap(METAL.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None))
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
@@ -110,11 +111,11 @@ class MetalBatchExecutor(BatchExecutor):
if i == 0: write_resources.append(b._buf)
else: read_resources.append(b._buf)
var_vals_keys = list(var_vals.keys())
for i,v in enumerate(getattr(ji.prg,"vars",[])):
for i,v in enumerate(prg.vars):
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
global_size, local_size = ji.prg.launch_dims(var_vals)
assert ji.prg.global_size and ji.prg.local_size, "need global and local size to JIT"
if any(isinstance(x, Node) for x in ji.prg.global_size) or any(isinstance(x, Node) for x in ji.prg.local_size):
global_size, local_size = prg.launch_dims(var_vals)
assert prg.global_size and prg.local_size, "need global and local size to JIT"
if any(isinstance(x, Node) for x in prg.global_size) or any(isinstance(x, Node) for x in prg.local_size):
self.input_has_variable_dims.add(j)
else:
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
@@ -130,7 +131,7 @@ class MetalBatchExecutor(BatchExecutor):
for (j,i),input_name in self.input_replace.items():
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_name]._buf, 0, i)
for j in self.input_has_variable_dims:
global_size, local_size = self.jit_cache[j].prg.launch_dims(var_vals)
global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
self.int_buf_view[:] = list(var_vals.values())
command_buffer = METAL.mtl_queue.commandBuffer()

View File

@@ -1,11 +1,10 @@
import os, mmap, functools
import os, mmap
try: import _posixshmem
except Exception: pass
from typing import Callable, Dict
from tinygrad.helpers import DType, OSX
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.ops import Interpreted, Op, UnaryOps, MovementOps, BufferOps
from tinygrad.runtime.interpreted import interpret_ast
class RawShmBuffer(RawBufferMapped):
def __init__(self, size, dtype:DType, device:str):
@@ -26,4 +25,4 @@ class RawShmBuffer(RawBufferMapped):
# TODO: is this wrong?
shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x, MovementOps.AS_STRIDED: lambda x,_:x }
ShmBuffer = Interpreted(RawShmBuffer, functools.partial(interpret_ast, shm_fxn_for_op, None))
ShmBuffer = Interpreted(RawShmBuffer, shm_fxn_for_op)

View File

@@ -1,11 +1,10 @@
import torch, functools
import torch
import numpy as np
from typing import Dict, Callable, Optional
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, Op, Interpreted
from tinygrad.helpers import getenv, dtypes, prod, DType
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
from tinygrad.runtime.lib import RawBuffer
from tinygrad.runtime.interpreted import interpret_ast
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
@@ -49,4 +48,4 @@ class RawTorchBuffer(RawBuffer):
buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
return cls(prod(x.shape), type_map[buf.dtype], buf)
def toCPU(self): return self._buf.cpu().numpy()
TorchBuffer = Interpreted(RawTorchBuffer, functools.partial(interpret_ast, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x)))
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x))