mirror of https://github.com/commaai/tinygrad.git
JIT support in Interpreted (#2314)
* factor that out
* jit is supported everywhere
* fix some tests
* there's no jit supported device, the jit is everywhere
* fix test uops
This commit is contained in:
parent 9a20bc08d6
commit 70a65c201e
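The practical effect of this change, sketched with a hedged example: with JIT_SUPPORTED_DEVICE gone, a TinyJit-wrapped function is captured and replayed on whatever backend Device.DEFAULT resolves to, including the interpreted ones (CPU, TORCH). The snippet is illustrative only and assumes the tinygrad revision from this commit; pick an interpreted backend with e.g. CPU=1 or TORCH=1 in the environment.

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x: Tensor, y: Tensor) -> Tensor:
  # jitted functions should return realized tensors
  return x.matmul(y).relu().realize()

# first call runs eagerly, the second records the cache, later calls replay it
for _ in range(5):
  out = step(Tensor.randn(4, 4), Tensor.randn(4, 4))
print(out.numpy().shape)  # (4, 4)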
@@ -1,6 +1,12 @@
repos:
- repo: local
hooks:
- id: mypy
name: mypy
entry: mypy tinygrad/ extra/helpers.py
language: system
always_run: true
pass_filenames: false
- id: ruff
name: ruff
entry: ruff .
@@ -19,12 +25,6 @@ repos:
language: system
always_run: true
pass_filenames: false
- id: mypy
name: mypy
entry: mypy tinygrad/ extra/helpers.py
language: system
always_run: true
pass_filenames: false
- id: tests
name: subset of TORCH tests
entry: env PYTHONPATH="." TORCH=1 python3 -m pytest -n=4 test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
@@ -15,11 +15,11 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad.helpers import GlobalCounters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
from tinygrad.shape.symbolic import Variable

MAX_CONTEXT = 1024
JIT = getenv("JIT", 0 if CI else int(Device.DEFAULT in JIT_SUPPORTED_DEVICE))
JIT = getenv("JIT", 0 if CI else 1)

# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
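A hedged reading of the new default above: JIT no longer depends on a device allow-list; it is simply on outside CI and can still be overridden from the environment. A minimal sketch of the same pattern follows (the CI definition here is an assumption for illustration, not copied from tinygrad.helpers):

import os
from tinygrad.helpers import getenv

CI = os.getenv("CI", "") != ""        # assumed stand-in for tinygrad's CI flag
JIT = getenv("JIT", 0 if CI else 1)   # same default as the llama.py line above
print("JIT", "on" if JIT else "off")  # running with JIT=0 still disables it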
@@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
from tinygrad.helpers import dtypes, CI
from tinygrad.ops import Device
from test.helpers import derandomize_model
@@ -14,7 +14,6 @@ def helper_test_jitted_correctness(gen, train, train_jit):
for _ in range(5): jit = train_jit(*gen()).numpy()
np.testing.assert_allclose(nojit, jit, rtol=1e-3, atol=1e-5)

@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
class TestJittedModels(unittest.TestCase):
def test_jitted_tiny_llama(self):
old_type = Tensor.default_type
@@ -3,7 +3,7 @@ import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
from tinygrad.ops import Device, GlobalCounters
from tinygrad.helpers import CI, dtypes, getenv, prod
from test.helpers import derandomize_model
@@ -47,7 +47,7 @@ class TestRealWorld(unittest.TestCase):
def test(t, t2): return model(t, 801, t2).realize()
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 967)

@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["LLVM"], "needs JIT, too long on CI LLVM")
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CPU"] or not CI), "needs JIT, too long on CI LLVM")
def test_llama(self):
Tensor.default_type = dtypes.float16

@@ -59,7 +59,7 @@ class TestRealWorld(unittest.TestCase):
# NOTE: only test one pass, not testing the dynamic shape autoregressive part
helper_test("test_llama", lambda: (Tensor([[1,]]),), test, 0.22 if CI else 13.5, 137 if CI else 521, all_jitted=True)

@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM"] or not CI), "needs JIT, too long on CI LLVM")
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CPU"] or not CI), "needs JIT, too long on CI LLVM")
def test_gpt2(self):
Tensor.default_type = dtypes.float16

@@ -70,7 +70,7 @@ class TestRealWorld(unittest.TestCase):
def test(t): return model(t, 0).realize()
helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 140 if CI else 396, all_jitted=True)

@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM", "CLANG"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CLANG", "CPU"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
def test_train_cifar(self):
# TODO: with default device
#old_default = Device.DEFAULT
@@ -9,7 +9,7 @@ from tinygrad.helpers import prod, dtypes
# *** first, we implement the atan2 op at the lowest level ***
# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
from tinygrad.lazy import LazyBuffer, create_lazybuffer
from tinygrad.ops import ASTRunner, Device
from tinygrad.ops import CompiledASTRunner, Device
from tinygrad.shape.shapetracker import ShapeTracker
import pytest

@@ -20,7 +20,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
assert a.device == "GPU" and b.device == "GPU", "gpu function requires GPUBuffers"
assert a.dtype == b.dtype and a.dtype == dtypes.float32, "gpu function only supports float32"
ret.realized = Device[ret.device].buffer(prod(ret.shape), ret.dtype)
ASTRunner("atan2_gpu", """
CompiledASTRunner(None, "atan2_gpu", """
__kernel void atan2_gpu(global float *c, global float *a, global float *b) {
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
@@ -89,6 +89,7 @@ class TestCustomFunction(unittest.TestCase):
np.testing.assert_allclose(a.grad.numpy(), ta.grad.numpy(), atol=1e-5)
np.testing.assert_allclose(b.grad.numpy(), tb.grad.numpy(), atol=1e-5)

@unittest.skipIf(Device.DEFAULT in ["CPU"], "atan2_cpu not jittable")
def test_atan2_jit(self):
# custom ops even work in the JIT!
from tinygrad.jit import TinyJit
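The hunk above cuts off inside the kernel string, so here is a hedged reconstruction of the updated custom-op pattern: the new CompiledASTRunner takes an explicit ast argument first (None for a hand-written kernel). It assumes an OpenCL-style "GPU" device is available, and the exact exec call and global size are an assumption for illustration rather than a copy of the test file.

from tinygrad.helpers import prod, dtypes
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import CompiledASTRunner, Device

def atan2_gpu(ret: LazyBuffer, a: LazyBuffer, b: LazyBuffer):
  assert a.dtype == b.dtype == dtypes.float32, "gpu function only supports float32"
  ret.realized = Device[ret.device].buffer(prod(ret.shape), ret.dtype)
  CompiledASTRunner(None, "atan2_gpu", """
  __kernel void atan2_gpu(global float *c, global float *a, global float *b) {
    int idx = get_global_id(0);
    c[idx] = atan2(a[idx], b[idx]);
  }""", global_size=[prod(ret.shape)]).build(
    Device[ret.device].compiler, Device[ret.device].runtime
  ).exec([ret.realized, a.realized, b.realized])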
@@ -2,13 +2,13 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
import pytest

pytestmark = pytest.mark.webgpu

# NOTE: METAL fails, might be platform and optimization options dependent.
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
@unittest.skipUnless(Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
class TestJit(unittest.TestCase):
def test_simple_jit(self):
@TinyJit
@@ -6,7 +6,6 @@ from tinygrad.tensor import Tensor, Device
import numpy as np

@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA", "LLVM"], f"{Device.DEFAULT} is not supported")
class TestSymbolicJit(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
@@ -1,5 +1,4 @@
import unittest
from tinygrad.jit import JIT_SUPPORTED_DEVICE
from tinygrad.shape.symbolic import Variable
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor, Device
@@ -3,14 +3,14 @@ import unittest, math
import numpy as np
from tinygrad.helpers import dtypes, getenv, DType, PtrDType
from tinygrad.tensor import Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, CompiledASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, UOp

def _uops_to_prg(uops):
src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
return ASTRunner("test", src,
[1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None,
runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
return CompiledASTRunner(None, "test", src,
[1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None,
runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)

def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg))
@@ -8,8 +8,6 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
from weakref import ref, WeakKeyDictionary

JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]

class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
@@ -28,8 +26,6 @@ class TinyJit:
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)

def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT.split(":")[0] not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device

# all inputs are realized
input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()])
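One detail worth a hedged illustration: the __get__ shown above is what lets TinyJit decorate methods, binding self through functools.partial before __call__ runs. A small sketch, assuming the tinygrad revision from this commit:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

class TinyModel:
  def __init__(self): self.w = Tensor.randn(8, 8)

  @TinyJit
  def forward(self, x: Tensor) -> Tensor:
    # self is not a Tensor, so only x is tracked as a jit input
    return (x @ self.w).relu().realize()

m = TinyModel()
for _ in range(3): out = m.forward(Tensor.randn(2, 8))
print(out.shape)  # (2, 8)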
tinygrad/ops.py
@@ -1,5 +1,5 @@
from __future__ import annotations
import importlib, inspect, functools, pathlib
import importlib, inspect, functools, pathlib, time
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
@@ -26,6 +26,8 @@ OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOp
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.lazy import LazyBuffer
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.kernel import LinearizerOptions

@dataclass(frozen=True)
class MemBuffer:
@@ -104,48 +106,6 @@ class _Device:
return "CPU"
Device = _Device()

# **************** batch executor ****************

@dataclass(frozen=True)
class JitItem:
prg: ASTRunner
rawbufs: List[Optional[RawBuffer]]

class BatchExecutor:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
self.jit_cache: List[JitItem] = jit_cache
self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
self.op_estimate += ji.prg.op_estimate
self.mem_estimate += ji.prg.mem_estimate
for i,a in enumerate(ji.rawbufs):
if a in [v for v in input_rawbuffers.values()]:
self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
self.clear_jit_inputs()

def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), var_vals, jit=True)
self.clear_jit_inputs()

def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
# TODO: this is mostly copied from ASTRunner
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += len(self.jit_cache)
GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
if et is not None: GlobalCounters.time_sum_s += et

def clear_jit_inputs(self):
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None

# **************** independent FlopCounter ****************

@dataclass
@@ -174,114 +134,159 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
return run_ast(ast)

# **************** batch executor ****************

@dataclass(frozen=True)
class JitItem:
prg: ASTRunner
rawbufs: List[Optional[RawBuffer]]

class BatchExecutor:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
self.jit_cache: List[JitItem] = jit_cache
self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
self.op_estimate += ji.prg.op_estimate
self.mem_estimate += ji.prg.mem_estimate
for i,a in enumerate(ji.rawbufs):
if a in [v for v in input_rawbuffers.values()]:
self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
self.clear_jit_inputs()

def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True)
self.clear_jit_inputs()

def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
# TODO: this is mostly copied from ASTRunner
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += len(self.jit_cache)
GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
if et is not None: GlobalCounters.time_sum_s += et

def clear_jit_inputs(self):
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None

class ASTRunner:
def __init__(self, ast:Optional[LazyOp]):
if ast is None:
self.op_estimate, self.mem_estimate, self.vars = 0, 0, []
else:
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
from tinygrad.lazy import vars_from_ast
self.vars = vars_from_ast(ast)
assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"

def exec(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
from tinygrad.jit import CacheCollector
et = self(rawbufs, var_vals, force_wait=force_wait)
CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
return et

def update_stats(self, name, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, lra, jit):
if var_vals is None: var_vals = {}
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '')):18s} {str(lra.get('local_size', '')):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += mem_estimate

def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
raise NotImplementedError("override this")

# **************** for Interpreted Buffers ****************

class InterpretedASTRunner(ASTRunner):
def __init__(self, ast:LazyOp, fxn:Callable):
self.fxn = fxn
super().__init__(ast)

def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> float:
st = time.perf_counter()
ret: RawBuffer = self.fxn(rawbufs[1:], var_vals)
et = time.perf_counter() - st
self.update_stats(f"<interpreted {ret.size}>", var_vals, et, len(rawbufs), {}, jit)
if rawbufs[0] is not None:
assert rawbufs[0].dtype == ret.dtype
rawbufs[0].size = ret.size # NOTE: for symbolic this can change
rawbufs[0]._buf = ret._buf
else: rawbufs[0] = ret
return et

from tinygrad.runtime.interpreted import interpret_ast
class Interpreted:
def __init__(self, buffer: Type[RawBuffer], compiler: Callable[[LazyOp], Callable]):
self.buffer, self.compiler = buffer, compiler
def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable]=None):
self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
self.synchronize = lambda: None
self.batch_executor = BatchExecutor
self.codegen = None
self.method_cache: Dict[LazyOp, Callable] = {}
self.method_cache: Dict[LazyOp, InterpretedASTRunner] = {}

def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
if ast not in self.method_cache: self.method_cache[ast] = self.compiler(ast)
output.realized = output.output_buffer # NOTE: assuming this is the right size and dtype from assign
ret: RawBuffer = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
assert output.dtype == ret.dtype, f"expected {output.dtype}, got {ret.dtype}"
if output.realized is not None: output.realized._buf = ret._buf
else: output.realized = ret
if ast not in self.method_cache: self.method_cache[ast] = InterpretedASTRunner(ast, interpret_ast(self.fxn_for_op, self.from_underlying, ast))
rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs]
self.method_cache[ast].exec(rawbufs, var_vals)
output.realized = rawbufs[0]

# **************** for Compiled Buffers ****************

class ASTRunner:
def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
class CompiledASTRunner(ASTRunner):
def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
self.vars:List[Variable] = []
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = \
name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
super().__init__(ast)

def build(self, compiler, runtime):
self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
self.clprg = runtime(self.name, self.lib)
return self

def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
from tinygrad.jit import CacheCollector
CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
return self(rawbufs, var_vals, force_wait=force_wait)

def launch_dims(self, var_vals):
global_size = ([sym_infer(sz, var_vals) for sz in self.global_size] + [1]*(3-len(self.global_size))) if self.global_size is not None else self.global_size
local_size = ([sym_infer(sz, var_vals) for sz in self.local_size] + [1]*(3-len(self.local_size))) if self.local_size is not None else self.local_size
return global_size, local_size

def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
if var_vals is None: var_vals = {}
var_vals = {k:var_vals[k] for k in self.vars} # filter the var_vals
global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
# TODO: this is copied from get_program
from tinygrad.features.search import optimize_local_size
local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs)
local_size = self.local_size = optimize_local_size(self.clprg, global_size, cast(List[RawBuffer], rawbufs))
global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
lra = self.runtime_args.copy()
if global_size: lra['global_size'] = global_size
if local_size and 'local_size' not in lra: lra['local_size'] = local_size
if et := self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(37-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += mem_estimate
self.update_stats(self.display_name if self.display_name is not None else self.name, var_vals, et, len(rawbufs), lra, jit)
return et

class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor):
def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
self.batch_executor = BatchExecutor if getenv("JIT") == 2 else batch_executor
self.method_cache: Dict[LazyOp, ASTRunner] = {}
self.method_cache: Dict[LazyOp, CompiledASTRunner] = {}

def to_program(self, k) -> ASTRunner:
def to_program(self, k:Linearizer) -> CompiledASTRunner:
k.linearize()
src, runtime_args = self.renderer(k.function_name, k.uops)
return ASTRunner(k.function_name, src, k.global_size, k.local_size,
op_estimate=k.info.flops, mem_estimate=k.info.mem_estimate,
display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)

def get_optimized_program(self, ast:LazyOp, rawbuffers:List[RawBuffer]) -> ASTRunner:
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.lazy import vars_from_ast
k = Linearizer(ast, self.linearizer_opts)
assert k.info.dtype == rawbuffers[0].dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {rawbuffers[0].dtype}"
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1:
lins = [(("tc" if used_tensor_cores else "hc"), k)]
# allocate a scratch buffer if output buffer is also input
test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
kb = Linearizer(ast, self.linearizer_opts)
kb.required_optimizations()
from tinygrad.features.search import beam_search, time_linearizer
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
if used_tensor_cores:
lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
lins[-1][1].hand_coded_optimizations()
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
k = timed[0][1]
else:
k.required_optimizations()
prg = self.to_program(k)
# extract real vars used in ast
prg.vars = vars_from_ast(ast)
assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
return prg
return CompiledASTRunner(k.ast, k.function_name, src, k.global_size, k.local_size,
display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)

def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
# check if we can reuse the output buffer
@@ -305,5 +310,32 @@ class Compiled:
# all the rawbuffers
rawbuffers = [output.realized] + [x.realized for x in inputs]

if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
if ast not in self.method_cache: self.method_cache[ast] = get_optimized_program(self.linearizer_opts, self.to_program, ast, rawbuffers)
self.method_cache[ast].exec(rawbuffers, var_vals)

def get_optimized_program(linearizer_opts:LinearizerOptions, to_program, ast:LazyOp, rawbuffers:List[RawBuffer]) -> CompiledASTRunner:
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, linearizer_opts)
assert k.info.dtype == rawbuffers[0].dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {rawbuffers[0].dtype}"
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1:
lins = [(("tc" if used_tensor_cores else "hc"), k)]
# allocate a scratch buffer if output buffer is also input
test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
kb = Linearizer(ast, linearizer_opts)
kb.required_optimizations()
from tinygrad.features.search import beam_search, time_linearizer
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
if used_tensor_cores:
lins.append(("hc", Linearizer(ast, linearizer_opts)))
lins[-1][1].hand_coded_optimizations()
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
k = timed[0][1]
else:
k.required_optimizations()
return to_program(k)
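To make the shape of the ops.py change easier to see, here is a toy sketch (deliberately not tinygrad code) of the pattern the new InterpretedASTRunner.exec enables: every runner, interpreted or compiled, registers itself with a cache collector, so a JIT can later replay the recorded (runner, buffers) pairs without going through exec_ast again.

from typing import Any, Callable, List, Optional, Tuple
import numpy as np

class ToyCacheCollector:
  def __init__(self): self.cache: Optional[List[Tuple[Any, List[Any]]]] = None
  def start(self): self.cache = []
  def add(self, runner, rawbufs):
    if self.cache is not None: self.cache.append((runner, rawbufs))
  def finish(self) -> List[Tuple[Any, List[Any]]]:
    cache, self.cache = self.cache, None
    return cache or []

COLLECTOR = ToyCacheCollector()

class ToyInterpretedRunner:
  def __init__(self, fxn: Callable): self.fxn = fxn
  def __call__(self, rawbufs): return self.fxn(*rawbufs)
  def exec(self, rawbufs):
    COLLECTOR.add(self, rawbufs)  # the key move: interpreted runners get recorded too
    return self(rawbufs)

COLLECTOR.start()
ToyInterpretedRunner(np.add).exec([np.ones(4), np.ones(4)])
for runner, bufs in COLLECTOR.finish():
  print(runner(bufs))  # "JIT replay": rerun the recorded work without re-tracing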
@@ -1,10 +1,9 @@
import numpy as np
import operator, functools
import operator
from typing import Callable, Dict, Tuple, Optional
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
from tinygrad.runtime.lib import RawBuffer
from tinygrad.runtime.interpreted import interpret_ast

def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
@@ -52,4 +51,4 @@ class RawNumpyBuffer(RawBuffer):
@classmethod
def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
def toCPU(self): return self._buf
CPUBuffer = Interpreted(RawNumpyBuffer, functools.partial(interpret_ast, numpy_fxn_for_op, RawNumpyBuffer.fromCPU))
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, RawNumpyBuffer.fromCPU)
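For orientation: the new Interpreted(buffer, fxn_for_op, from_underlying) constructor takes the op table directly instead of a pre-bound interpret_ast partial. A hypothetical, heavily reduced table in the same shape as numpy_fxn_for_op (the real one in ops_cpu.py covers many more ops) might look like this:

from typing import Callable, Dict
import numpy as np
from tinygrad.ops import Op, UnaryOps, BinaryOps

# illustrative subset only, not the real numpy_fxn_for_op table
example_fxn_for_op: Dict[Op, Callable] = {
  UnaryOps.EXP2: np.exp2,
  UnaryOps.LOG2: np.log2,
  BinaryOps.ADD: np.add,
  BinaryOps.MUL: np.multiply,
}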
@@ -1,9 +1,8 @@
import os, mmap, functools
import os, mmap
from typing import Optional
from typing import Callable, Dict, Tuple
from tinygrad.helpers import prod, DType
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.runtime.interpreted import interpret_ast
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps

class RawDiskBuffer(RawBufferMapped):
@@ -39,4 +38,4 @@ class RawDiskBuffer(RawBufferMapped):
self._buf[0].readinto(buf)

disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
DiskBuffer = Interpreted(RawDiskBuffer, functools.partial(interpret_ast, disk_fxn_for_op, None))
DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op)
@@ -1,10 +1,10 @@
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
import os, subprocess, pathlib, ctypes, tempfile
import Metal, Cocoa, libdispatch
from typing import List, Any, Tuple, Dict, Union, Set
from typing import List, Any, Tuple, Dict, Union, Set, cast
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup, CI
from tinygrad.ops import Compiled, BatchExecutor, JitItem
from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner
from tinygrad.renderer.metal import MetalRenderer
from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
from tinygrad.shape.symbolic import Variable, Node
@@ -98,8 +98,9 @@ class MetalBatchExecutor(BatchExecutor):
self.input_has_variable_dims: Set[int] = set()
read_resources, write_resources = [], []
for j,ji in enumerate(self.jit_cache):
prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
descriptor = Metal.MTLComputePipelineDescriptor.new()
descriptor.setComputeFunction_(ji.prg.clprg.fxn)
descriptor.setComputeFunction_(prg.clprg.fxn)
descriptor.setSupportIndirectCommandBuffers_(True)
pipeline_state = unwrap(METAL.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None))
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
@@ -110,11 +111,11 @@ class MetalBatchExecutor(BatchExecutor):
if i == 0: write_resources.append(b._buf)
else: read_resources.append(b._buf)
var_vals_keys = list(var_vals.keys())
for i,v in enumerate(getattr(ji.prg,"vars",[])):
for i,v in enumerate(prg.vars):
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
global_size, local_size = ji.prg.launch_dims(var_vals)
assert ji.prg.global_size and ji.prg.local_size, "need global and local size to JIT"
if any(isinstance(x, Node) for x in ji.prg.global_size) or any(isinstance(x, Node) for x in ji.prg.local_size):
global_size, local_size = prg.launch_dims(var_vals)
assert prg.global_size and prg.local_size, "need global and local size to JIT"
if any(isinstance(x, Node) for x in prg.global_size) or any(isinstance(x, Node) for x in prg.local_size):
self.input_has_variable_dims.add(j)
else:
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
@@ -130,7 +131,7 @@ class MetalBatchExecutor(BatchExecutor):
for (j,i),input_name in self.input_replace.items():
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_name]._buf, 0, i)
for j in self.input_has_variable_dims:
global_size, local_size = self.jit_cache[j].prg.launch_dims(var_vals)
global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
self.int_buf_view[:] = list(var_vals.values())
command_buffer = METAL.mtl_queue.commandBuffer()
@@ -1,11 +1,10 @@
import os, mmap, functools
import os, mmap
try: import _posixshmem
except Exception: pass
from typing import Callable, Dict
from tinygrad.helpers import DType, OSX
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.ops import Interpreted, Op, UnaryOps, MovementOps, BufferOps
from tinygrad.runtime.interpreted import interpret_ast

class RawShmBuffer(RawBufferMapped):
def __init__(self, size, dtype:DType, device:str):
@@ -26,4 +25,4 @@ class RawShmBuffer(RawBufferMapped):

# TODO: is this wrong?
shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x, MovementOps.AS_STRIDED: lambda x,_:x }
ShmBuffer = Interpreted(RawShmBuffer, functools.partial(interpret_ast, shm_fxn_for_op, None))
ShmBuffer = Interpreted(RawShmBuffer, shm_fxn_for_op)
@@ -1,11 +1,10 @@
import torch, functools
import torch
import numpy as np
from typing import Dict, Callable, Optional
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, Op, Interpreted
from tinygrad.helpers import getenv, dtypes, prod, DType
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
from tinygrad.runtime.lib import RawBuffer
from tinygrad.runtime.interpreted import interpret_ast

device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
@@ -49,4 +48,4 @@ class RawTorchBuffer(RawBuffer):
buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
return cls(prod(x.shape), type_map[buf.dtype], buf)
def toCPU(self): return self._buf.cpu().numpy()
TorchBuffer = Interpreted(RawTorchBuffer, functools.partial(interpret_ast, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x)))
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x))
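As a quick sanity check of the torch backend wiring, a hedged example (not part of the diff; it assumes torch is installed and that passing device="TORCH" routes the ops through this Interpreted backend):

import numpy as np
from tinygrad.tensor import Tensor

a = Tensor([[1.0, 2.0], [3.0, 4.0]], device="TORCH")
b = Tensor([[10.0, 20.0], [30.0, 40.0]], device="TORCH")
np.testing.assert_allclose((a @ b).numpy(), [[70.0, 100.0], [150.0, 220.0]])
print("TORCH backend matmul ok")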