diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d7ece0cb..e3d82597 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,12 @@
 repos:
 - repo: local
   hooks:
+  - id: mypy
+    name: mypy
+    entry: mypy tinygrad/ extra/helpers.py
+    language: system
+    always_run: true
+    pass_filenames: false
   - id: ruff
     name: ruff
    entry: ruff .
@@ -19,12 +25,6 @@ repos:
     language: system
     always_run: true
     pass_filenames: false
-  - id: mypy
-    name: mypy
-    entry: mypy tinygrad/ extra/helpers.py
-    language: system
-    always_run: true
-    pass_filenames: false
   - id: tests
     name: subset of TORCH tests
     entry: env PYTHONPATH="." TORCH=1 python3 -m pytest -n=4 test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
diff --git a/examples/llama.py b/examples/llama.py
index e1aedc44..3890da5a 100755
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -15,11 +15,11 @@ from tinygrad.tensor import Tensor
 from tinygrad.nn import Embedding, Linear
 from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
 from tinygrad.helpers import GlobalCounters
-from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
+from tinygrad.jit import TinyJit
 from tinygrad.shape.symbolic import Variable
 
 MAX_CONTEXT = 1024
-JIT = getenv("JIT", 0 if CI else int(Device.DEFAULT in JIT_SUPPORTED_DEVICE))
+JIT = getenv("JIT", 0 if CI else 1)
 
 # https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
diff --git a/test/external/external_test_jit_on_models.py b/test/external/external_test_jit_on_models.py
index f03615b9..bb12698a 100644
--- a/test/external/external_test_jit_on_models.py
+++ b/test/external/external_test_jit_on_models.py
@@ -2,7 +2,7 @@ import unittest
 import numpy as np
 
 from tinygrad.tensor import Tensor
-from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
+from tinygrad.jit import TinyJit
 from tinygrad.helpers import dtypes, CI
 from tinygrad.ops import Device
 from test.helpers import derandomize_model
@@ -14,7 +14,6 @@ def helper_test_jitted_correctness(gen, train, train_jit):
   for _ in range(5): jit = train_jit(*gen()).numpy()
   np.testing.assert_allclose(nojit, jit, rtol=1e-3, atol=1e-5)
 
-@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
 class TestJittedModels(unittest.TestCase):
   def test_jitted_tiny_llama(self):
     old_type = Tensor.default_type
diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index 507241fc..ae7ca5f7 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -3,7 +3,7 @@ import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.nn import optim
 from tinygrad.nn.state import get_parameters
-from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
+from tinygrad.jit import TinyJit
 from tinygrad.ops import Device, GlobalCounters
 from tinygrad.helpers import CI, dtypes, getenv, prod
 from test.helpers import derandomize_model
@@ -47,7 +47,7 @@ class TestRealWorld(unittest.TestCase):
     def test(t, t2): return model(t, 801, t2).realize()
     helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 967)
 
-  @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["LLVM"], "needs JIT, too long on CI LLVM")
+  @unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CPU"] or not CI), "needs JIT, too long on CI LLVM")
   def test_llama(self):
     Tensor.default_type = dtypes.float16
 
@@ -59,7 +59,7 @@ class TestRealWorld(unittest.TestCase):
     # NOTE: only test one pass, not testing the dynamic shape autoregressive part
     helper_test("test_llama", lambda: (Tensor([[1,]]),), test, 0.22 if CI else 13.5, 137 if CI else 521, all_jitted=True)
 
-  @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM"] or not CI), "needs JIT, too long on CI LLVM")
+  @unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CPU"] or not CI), "needs JIT, too long on CI LLVM")
   def test_gpt2(self):
     Tensor.default_type = dtypes.float16
 
@@ -70,7 +70,7 @@ class TestRealWorld(unittest.TestCase):
     def test(t): return model(t, 0).realize()
     helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 140 if CI else 396, all_jitted=True)
 
-  @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM", "CLANG"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
+  @unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CLANG", "CPU"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
   def test_train_cifar(self):
     # TODO: with default device
     #old_default = Device.DEFAULT
diff --git a/test/test_custom_function.py b/test/test_custom_function.py
index b9312a47..92968b84 100644
--- a/test/test_custom_function.py
+++ b/test/test_custom_function.py
@@ -9,7 +9,7 @@ from tinygrad.helpers import prod, dtypes
 # *** first, we implement the atan2 op at the lowest level ***
 # `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
 from tinygrad.lazy import LazyBuffer, create_lazybuffer
-from tinygrad.ops import ASTRunner, Device
+from tinygrad.ops import CompiledASTRunner, Device
 from tinygrad.shape.shapetracker import ShapeTracker
 import pytest
 
@@ -20,7 +20,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
   assert a.device == "GPU" and b.device == "GPU", "gpu function requires GPUBuffers"
   assert a.dtype == b.dtype and a.dtype == dtypes.float32, "gpu function only supports float32"
   ret.realized = Device[ret.device].buffer(prod(ret.shape), ret.dtype)
-  ASTRunner("atan2_gpu", """
+  CompiledASTRunner(None, "atan2_gpu", """
 __kernel void atan2_gpu(global float *c, global float *a, global float *b) {
   int idx = get_global_id(0);
   c[idx] = atan2(a[idx], b[idx]);
@@ -89,6 +89,7 @@ class TestCustomFunction(unittest.TestCase):
     np.testing.assert_allclose(a.grad.numpy(), ta.grad.numpy(), atol=1e-5)
     np.testing.assert_allclose(b.grad.numpy(), tb.grad.numpy(), atol=1e-5)
 
+  @unittest.skipIf(Device.DEFAULT in ["CPU"], "atan2_cpu not jittable")
   def test_atan2_jit(self):
     # custom ops even work in the JIT!
     from tinygrad.jit import TinyJit
diff --git a/test/test_jit.py b/test/test_jit.py
index bb2bafec..99dd3dee 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -2,13 +2,13 @@ import unittest
 import numpy as np
 
 from tinygrad.tensor import Tensor, Device
-from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
+from tinygrad.jit import TinyJit
 import pytest
 
 pytestmark = pytest.mark.webgpu
 
 # NOTE: METAL fails, might be platform and optimization options dependent.
-@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
+@unittest.skipUnless(Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
 class TestJit(unittest.TestCase):
   def test_simple_jit(self):
     @TinyJit
diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py
index 83a4f382..ac987762 100644
--- a/test/test_symbolic_jit.py
+++ b/test/test_symbolic_jit.py
@@ -6,7 +6,6 @@ from tinygrad.tensor import Tensor, Device
 import numpy as np
 
 @unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
-@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA", "LLVM"], f"{Device.DEFAULT} is not supported")
 class TestSymbolicJit(unittest.TestCase):
   def test_plus1(self):
     def f(a): return (a+1).realize()
diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py
index cd930d0b..b36fb225 100644
--- a/test/test_symbolic_ops.py
+++ b/test/test_symbolic_ops.py
@@ -1,5 +1,4 @@
 import unittest
-from tinygrad.jit import JIT_SUPPORTED_DEVICE
 from tinygrad.shape.symbolic import Variable
 from tinygrad.helpers import getenv
 from tinygrad.tensor import Tensor, Device
diff --git a/test/test_uops.py b/test/test_uops.py
index 2d055453..bbe29756 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -3,14 +3,14 @@ import unittest, math
 import numpy as np
 from tinygrad.helpers import dtypes, getenv, DType, PtrDType
 from tinygrad.tensor import Device
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
+from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, CompiledASTRunner, Compiled
 from tinygrad.codegen.linearizer import UOps, UOp
 
 def _uops_to_prg(uops):
   src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
-  return ASTRunner("test", src,
-    [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None,
-    runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
+  return CompiledASTRunner(None, "test", src,
+    [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None,
+    runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
 
 def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(vin), arg))
diff --git a/tinygrad/jit.py b/tinygrad/jit.py
index ad549f8c..f58b4daa 100644
--- a/tinygrad/jit.py
+++ b/tinygrad/jit.py
@@ -8,8 +8,6 @@ from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable
 from weakref import ref, WeakKeyDictionary
 
-JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
-
 class TinyJit:
   def __init__(self, fxn:Callable):
     self.fxn: Callable = fxn
@@ -28,8 +26,6 @@ class TinyJit:
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
 
   def __call__(self, *args, **kwargs) -> Any:
-    if Device.DEFAULT.split(":")[0] not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device
-
     # all inputs are realized
     input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
     expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()])
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 672c6e1f..ad2636cf 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import importlib, inspect, functools, pathlib
+import importlib, inspect, functools, pathlib, time
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
@@ -26,6 +26,8 @@ OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOp
 if TYPE_CHECKING:
   from tinygrad.shape.shapetracker import ShapeTracker
   from tinygrad.lazy import LazyBuffer
+  from tinygrad.codegen.linearizer import Linearizer
+  from tinygrad.codegen.kernel import LinearizerOptions
 
 @dataclass(frozen=True)
 class MemBuffer:
@@ -104,48 +106,6 @@ class _Device:
     return "CPU"
 Device = _Device()
 
-# **************** batch executor ****************
-
-@dataclass(frozen=True)
-class JitItem:
-  prg: ASTRunner
-  rawbufs: List[Optional[RawBuffer]]
-
-class BatchExecutor:
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
-    self.jit_cache: List[JitItem] = jit_cache
-    self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
-    self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
-    for j,ji in enumerate(jit_cache):
-      if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
-        self.op_estimate += ji.prg.op_estimate
-        self.mem_estimate += ji.prg.mem_estimate
-      for i,a in enumerate(ji.rawbufs):
-        if a in [v for v in input_rawbuffers.values()]:
-          self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
-    assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
-    self.clear_jit_inputs()
-
-  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
-    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
-    for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), var_vals, jit=True)
-    self.clear_jit_inputs()
-
-  def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
-    # TODO: this is mostly copied from ASTRunner
-    op_estimate = sym_infer(self.op_estimate, var_vals)
-    mem_estimate = sym_infer(self.mem_estimate, var_vals)
-    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-    GlobalCounters.kernel_count += len(self.jit_cache)
-    GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
-    GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
-    if et is not None: GlobalCounters.time_sum_s += et
-
-  def clear_jit_inputs(self):
-    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
-
 # **************** independent FlopCounter ****************
 
 @dataclass
@@ -174,114 +134,159 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
   def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
   return run_ast(ast)
 
+# **************** batch executor ****************
+
+@dataclass(frozen=True)
+class JitItem:
+  prg: ASTRunner
+  rawbufs: List[Optional[RawBuffer]]
+
+class BatchExecutor:
+  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
+    self.jit_cache: List[JitItem] = jit_cache
+    self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
+    self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
+    for j,ji in enumerate(jit_cache):
+      if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
+        self.op_estimate += ji.prg.op_estimate
+        self.mem_estimate += ji.prg.mem_estimate
+      for i,a in enumerate(ji.rawbufs):
+        if a in [v for v in input_rawbuffers.values()]:
+          self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
+    assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
+    self.clear_jit_inputs()
+
+  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
+    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
+    for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True)
+    self.clear_jit_inputs()
+
+  def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
+    # TODO: this is mostly copied from ASTRunner
+    op_estimate = sym_infer(self.op_estimate, var_vals)
+    mem_estimate = sym_infer(self.mem_estimate, var_vals)
+    if DEBUG >= 2:
+      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
+    GlobalCounters.kernel_count += len(self.jit_cache)
+    GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
+    GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
+    if et is not None: GlobalCounters.time_sum_s += et
+
+  def clear_jit_inputs(self):
+    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+
+class ASTRunner:
+  def __init__(self, ast:Optional[LazyOp]):
+    if ast is None:
+      self.op_estimate, self.mem_estimate, self.vars = 0, 0, []
+    else:
+      info = get_lazyop_info(ast)
+      self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
+      from tinygrad.lazy import vars_from_ast
+      self.vars = vars_from_ast(ast)
+      assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
+
+  def exec(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
+    from tinygrad.jit import CacheCollector
+    et = self(rawbufs, var_vals, force_wait=force_wait)
+    CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
+    return et
+
+  def update_stats(self, name, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, lra, jit):
+    if var_vals is None: var_vals = {}
+    op_estimate = sym_infer(self.op_estimate, var_vals)
+    mem_estimate = sym_infer(self.mem_estimate, var_vals)
+    if DEBUG >= 2:
+      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '')):18s} {str(lra.get('local_size', '')):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
    GlobalCounters.kernel_count += 1
+    GlobalCounters.global_ops += op_estimate
+    GlobalCounters.global_mem += mem_estimate
+
+  def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
+    raise NotImplementedError("override this")
+
 # **************** for Interpreted Buffers ****************
 
+class InterpretedASTRunner(ASTRunner):
+  def __init__(self, ast:LazyOp, fxn:Callable):
+    self.fxn = fxn
+    super().__init__(ast)
+
+  def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> float:
+    st = time.perf_counter()
+    ret: RawBuffer = self.fxn(rawbufs[1:], var_vals)
+    et = time.perf_counter() - st
+    self.update_stats(f"", var_vals, et, len(rawbufs), {}, jit)
+    if rawbufs[0] is not None:
+      assert rawbufs[0].dtype == ret.dtype
+      rawbufs[0].size = ret.size # NOTE: for symbolic this can change
+      rawbufs[0]._buf = ret._buf
+    else: rawbufs[0] = ret
+    return et
+
+from tinygrad.runtime.interpreted import interpret_ast
 class Interpreted:
-  def __init__(self, buffer: Type[RawBuffer], compiler: Callable[[LazyOp], Callable]):
-    self.buffer, self.compiler = buffer, compiler
+  def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable]=None):
+    self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
     self.synchronize = lambda: None
     self.batch_executor = BatchExecutor
     self.codegen = None
-    self.method_cache: Dict[LazyOp, Callable] = {}
+    self.method_cache: Dict[LazyOp, InterpretedASTRunner] = {}
 
   def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
-    if ast not in self.method_cache: self.method_cache[ast] = self.compiler(ast)
-    output.realized = output.output_buffer # NOTE: assuming this is the right size and dtype from assign
-    ret: RawBuffer = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
-    assert output.dtype == ret.dtype, f"expected {output.dtype}, got {ret.dtype}"
-    if output.realized is not None: output.realized._buf = ret._buf
-    else: output.realized = ret
+    if ast not in self.method_cache: self.method_cache[ast] = InterpretedASTRunner(ast, interpret_ast(self.fxn_for_op, self.from_underlying, ast))
+    rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs]
+    self.method_cache[ast].exec(rawbufs, var_vals)
+    output.realized = rawbufs[0]
 
 # **************** for Compiled Buffers ****************
 
-class ASTRunner:
-  def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
+class CompiledASTRunner(ASTRunner):
+  def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
     if DEBUG >= 4: print(prg)
-    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
-    self.vars:List[Variable] = []
+    self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = \
+      name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
+    super().__init__(ast)
 
   def build(self, compiler, runtime):
     self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
     self.clprg = runtime(self.name, self.lib)
     return self
 
-  def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
-    from tinygrad.jit import CacheCollector
-    CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
-    return self(rawbufs, var_vals, force_wait=force_wait)
-
   def launch_dims(self, var_vals):
     global_size = ([sym_infer(sz, var_vals) for sz in self.global_size] + [1]*(3-len(self.global_size))) if self.global_size is not None else self.global_size
     local_size = ([sym_infer(sz, var_vals) for sz in self.local_size] + [1]*(3-len(self.local_size))) if self.local_size is not None else self.local_size
     return global_size, local_size
 
-  def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
+  def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
     if var_vals is None: var_vals = {}
     var_vals = {k:var_vals[k] for k in self.vars} # filter the var_vals
     global_size, local_size = self.launch_dims(var_vals)
     if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
       # TODO: this is copied from get_program
       from tinygrad.features.search import optimize_local_size
-      local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
+      local_size = self.local_size = optimize_local_size(self.clprg, global_size, cast(List[RawBuffer], rawbufs))
       global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
     lra = self.runtime_args.copy()
     if global_size: lra['global_size'] = global_size
     if local_size and 'local_size' not in lra: lra['local_size'] = local_size
     if et := self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
-    op_estimate = sym_infer(self.op_estimate, var_vals)
-    mem_estimate = sym_infer(self.mem_estimate, var_vals)
-    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(37-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-    GlobalCounters.kernel_count += 1
-    GlobalCounters.global_ops += op_estimate
-    GlobalCounters.global_mem += mem_estimate
+    self.update_stats(self.display_name if self.display_name is not None else self.name, var_vals, et, len(rawbufs), lra, jit)
     return et
 
 class Compiled:
-  def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor):
+  def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor):
     self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
     self.batch_executor = BatchExecutor if getenv("JIT") == 2 else batch_executor
-    self.method_cache: Dict[LazyOp, ASTRunner] = {}
+    self.method_cache: Dict[LazyOp, CompiledASTRunner] = {}
 
-  def to_program(self, k) -> ASTRunner:
+  def to_program(self, k:Linearizer) -> CompiledASTRunner:
     k.linearize()
     src, runtime_args = self.renderer(k.function_name, k.uops)
-    return ASTRunner(k.function_name, src, k.global_size, k.local_size,
-      op_estimate=k.info.flops, mem_estimate=k.info.mem_estimate,
-      display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)
-
-  def get_optimized_program(self, ast:LazyOp, rawbuffers:List[RawBuffer]) -> ASTRunner:
-    if DEBUG >= 3:
-      from tinygrad.graph import print_tree
-      print_tree(ast)
-    from tinygrad.codegen.linearizer import Linearizer
-    from tinygrad.lazy import vars_from_ast
-    k = Linearizer(ast, self.linearizer_opts)
-    assert k.info.dtype == rawbuffers[0].dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {rawbuffers[0].dtype}"
-    if not NOOPT:
-      if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
-      if BEAM >= 1:
-        lins = [(("tc" if used_tensor_cores else "hc"), k)]
-        # allocate a scratch buffer if output buffer is also input
-        test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
-        kb = Linearizer(ast, self.linearizer_opts)
-        kb.required_optimizations()
-        from tinygrad.features.search import beam_search, time_linearizer
-        lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
-        if used_tensor_cores:
-          lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
-          lins[-1][1].hand_coded_optimizations()
-        timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
-        if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
-        k = timed[0][1]
-    else:
-      k.required_optimizations()
-    prg = self.to_program(k)
-    # extract real vars used in ast
-    prg.vars = vars_from_ast(ast)
-    assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
-    return prg
+    return CompiledASTRunner(k.ast, k.function_name, src, k.global_size, k.local_size,
+      display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)
 
   def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
     # check if we can reuse the output buffer
@@ -305,5 +310,32 @@ class Compiled:
     # all the rawbuffers
     rawbuffers = [output.realized] + [x.realized for x in inputs]
 
-    if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
+    if ast not in self.method_cache: self.method_cache[ast] = get_optimized_program(self.linearizer_opts, self.to_program, ast, rawbuffers)
     self.method_cache[ast].exec(rawbuffers, var_vals)
+
+def get_optimized_program(linearizer_opts:LinearizerOptions, to_program, ast:LazyOp, rawbuffers:List[RawBuffer]) -> CompiledASTRunner:
+  if DEBUG >= 3:
+    from tinygrad.graph import print_tree
+    print_tree(ast)
+  from tinygrad.codegen.linearizer import Linearizer
+  k = Linearizer(ast, linearizer_opts)
+  assert k.info.dtype == rawbuffers[0].dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {rawbuffers[0].dtype}"
+  if not NOOPT:
+    if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
+    if BEAM >= 1:
+      lins = [(("tc" if used_tensor_cores else "hc"), k)]
+      # allocate a scratch buffer if output buffer is also input
+      test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
+      kb = Linearizer(ast, linearizer_opts)
+      kb.required_optimizations()
+      from tinygrad.features.search import beam_search, time_linearizer
+      lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
+      if used_tensor_cores:
+        lins.append(("hc", Linearizer(ast, linearizer_opts)))
+        lins[-1][1].hand_coded_optimizations()
+      timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
+      if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
+      k = timed[0][1]
+  else:
+    k.required_optimizations()
+  return to_program(k)
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index 43c75d53..2bc66024 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -1,10 +1,9 @@
 import numpy as np
-import operator, functools
+import operator
 from typing import Callable, Dict, Tuple, Optional
 from tinygrad.helpers import dtypes, DType
 from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
 from tinygrad.runtime.lib import RawBuffer
-from tinygrad.runtime.interpreted import interpret_ast
 
 def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
   assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
@@ -52,4 +51,4 @@ class RawNumpyBuffer(RawBuffer):
   @classmethod
   def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
   def toCPU(self): return self._buf
-CPUBuffer = Interpreted(RawNumpyBuffer, functools.partial(interpret_ast, numpy_fxn_for_op, RawNumpyBuffer.fromCPU))
+CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, RawNumpyBuffer.fromCPU)
diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py
index afef209a..4bff4e0a 100644
--- a/tinygrad/runtime/ops_disk.py
+++ b/tinygrad/runtime/ops_disk.py
@@ -1,9 +1,8 @@
-import os, mmap, functools
+import os, mmap
 from typing import Optional
 from typing import Callable, Dict, Tuple
 from tinygrad.helpers import prod, DType
 from tinygrad.runtime.lib import RawBufferMapped
-from tinygrad.runtime.interpreted import interpret_ast
 from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps
 
 class RawDiskBuffer(RawBufferMapped):
@@ -39,4 +38,4 @@ class RawDiskBuffer(RawBufferMapped):
     self._buf[0].readinto(buf)
 
 disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
-DiskBuffer = Interpreted(RawDiskBuffer, functools.partial(interpret_ast, disk_fxn_for_op, None))
+DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op)
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index f442c531..18c007fe 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -1,10 +1,10 @@
 # pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
 import os, subprocess, pathlib, ctypes, tempfile
 import Metal, Cocoa, libdispatch
-from typing import List, Any, Tuple, Dict, Union, Set
+from typing import List, Any, Tuple, Dict, Union, Set, cast
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup, CI
-from tinygrad.ops import Compiled, BatchExecutor, JitItem
+from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner
 from tinygrad.renderer.metal import MetalRenderer
 from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
 from tinygrad.shape.symbolic import Variable, Node
@@ -98,8 +98,9 @@ class MetalBatchExecutor(BatchExecutor):
     self.input_has_variable_dims: Set[int] = set()
     read_resources, write_resources = [], []
     for j,ji in enumerate(self.jit_cache):
+      prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
       descriptor = Metal.MTLComputePipelineDescriptor.new()
-      descriptor.setComputeFunction_(ji.prg.clprg.fxn)
+      descriptor.setComputeFunction_(prg.clprg.fxn)
       descriptor.setSupportIndirectCommandBuffers_(True)
       pipeline_state = unwrap(METAL.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None))
       icb_command = self.icb.indirectComputeCommandAtIndex_(j)
@@ -110,11 +111,11 @@ class MetalBatchExecutor(BatchExecutor):
         if i == 0: write_resources.append(b._buf)
         else: read_resources.append(b._buf)
       var_vals_keys = list(var_vals.keys())
-      for i,v in enumerate(getattr(ji.prg,"vars",[])):
+      for i,v in enumerate(prg.vars):
         icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
-      global_size, local_size = ji.prg.launch_dims(var_vals)
-      assert ji.prg.global_size and ji.prg.local_size, "need global and local size to JIT"
-      if any(isinstance(x, Node) for x in ji.prg.global_size) or any(isinstance(x, Node) for x in ji.prg.local_size):
+      global_size, local_size = prg.launch_dims(var_vals)
+      assert prg.global_size and prg.local_size, "need global and local size to JIT"
+      if any(isinstance(x, Node) for x in prg.global_size) or any(isinstance(x, Node) for x in prg.local_size):
         self.input_has_variable_dims.add(j)
       else:
         icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
@@ -130,7 +131,7 @@ class MetalBatchExecutor(BatchExecutor):
     for (j,i),input_name in self.input_replace.items():
       self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_name]._buf, 0, i)
     for j in self.input_has_variable_dims:
-      global_size, local_size = self.jit_cache[j].prg.launch_dims(var_vals)
+      global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
       self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
     self.int_buf_view[:] = list(var_vals.values())
     command_buffer = METAL.mtl_queue.commandBuffer()
diff --git a/tinygrad/runtime/ops_shm.py b/tinygrad/runtime/ops_shm.py
index a4274cd4..8e5419b7 100644
--- a/tinygrad/runtime/ops_shm.py
+++ b/tinygrad/runtime/ops_shm.py
@@ -1,11 +1,10 @@
-import os, mmap, functools
+import os, mmap
 try: import _posixshmem
 except Exception: pass
 from typing import Callable, Dict
 from tinygrad.helpers import DType, OSX
 from tinygrad.runtime.lib import RawBufferMapped
 from tinygrad.ops import Interpreted, Op, UnaryOps, MovementOps, BufferOps
-from tinygrad.runtime.interpreted import interpret_ast
 
 class RawShmBuffer(RawBufferMapped):
   def __init__(self, size, dtype:DType, device:str):
@@ -26,4 +25,4 @@ class RawShmBuffer(RawBufferMapped):
 
 # TODO: is this wrong?
 shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x, MovementOps.AS_STRIDED: lambda x,_:x }
-ShmBuffer = Interpreted(RawShmBuffer, functools.partial(interpret_ast, shm_fxn_for_op, None))
+ShmBuffer = Interpreted(RawShmBuffer, shm_fxn_for_op)
diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py
index 3f9ed86b..e489ba02 100644
--- a/tinygrad/runtime/ops_torch.py
+++ b/tinygrad/runtime/ops_torch.py
@@ -1,11 +1,10 @@
-import torch, functools
+import torch
 import numpy as np
 from typing import Dict, Callable, Optional
 from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, Op, Interpreted
 from tinygrad.helpers import getenv, dtypes, prod, DType
 from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
 from tinygrad.runtime.lib import RawBuffer
-from tinygrad.runtime.interpreted import interpret_ast
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
 type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
@@ -49,4 +48,4 @@ class RawTorchBuffer(RawBuffer):
     buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
     return cls(prod(x.shape), type_map[buf.dtype], buf)
   def toCPU(self): return self._buf.cpu().numpy()
-TorchBuffer = Interpreted(RawTorchBuffer, functools.partial(interpret_ast, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x)))
+TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x))