diff --git a/extra/thneed.py b/extra/thneed.py
index c9c428c8..eaa02aeb 100644
--- a/extra/thneed.py
+++ b/extra/thneed.py
@@ -4,7 +4,7 @@ import struct
 import json
 import traceback
 import numpy as np
-from tinygrad.runtime.ops_gpu import CLProgram, CLImage, CLBuffer
+from tinygrad.runtime.ops_gpu import CLProgram
 from tinygrad.helpers import prod, getenv
 from collections import defaultdict
 import pyopencl as cl
diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py
index 47470bf1..a0e56136 100644
--- a/test/test_speed_v_torch.py
+++ b/test/test_speed_v_torch.py
@@ -37,7 +37,7 @@ def helper_test_speed(f1, *args):
     args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args]  # cache defeats
 
     # force syncing
-    [x.numpy() if isinstance(x, Tensor) else x.cpu().numpy() for x in args if x is not None]
+    [x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None]
 
     GlobalCounters.global_ops = 0
     GlobalCounters.global_mem = 0
@@ -45,7 +45,7 @@ def helper_test_speed(f1, *args):
     st = time.monotonic()
     ret = f1(*args)
     # not ideal, it's copying (sometimes). why is this so slow in tinygrad?
-    if isinstance(ret, Tensor): ret.numpy()
+    if isinstance(ret, Tensor) or str(torch_device) == "cpu": ret.numpy()
     else: ret.cpu().numpy()
     et = (time.monotonic() - st) * 1000
     ets.append(et)
diff --git a/tinygrad/codegen/gpu.py b/tinygrad/codegen/gpu.py
index e2119827..8c439cad 100644
--- a/tinygrad/codegen/gpu.py
+++ b/tinygrad/codegen/gpu.py
@@ -17,6 +17,7 @@ NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0)  # this is needed as a switch for the
 class GPULanguage(NamedTuple):
   kernel_prefix : str = ""
   buffer_prefix : str = ""
+  buffer_suffix : str = ""
   smem_prefix : str = ""
   barrier : str = ""
   gid : List[str] = []
@@ -321,7 +322,7 @@ class GPUCodegen(ASTKernel):
     if GPUCodegen.kernel_cnt[function_name]:
       function_name = f"{function_name}{'_N'+str(GPUCodegen.kernel_cnt[function_name])}"
 
-    buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype() for i,x in enumerate(self.bufs)]
+    buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype()+self.lang.buffer_suffix for i,x in enumerate(self.bufs)]
     self.kernel = list(self.prekernel) + [f"{self.lang.kernel_prefix} void {function_name}(",] + \
       [', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] + \
       [") {\n"] + self.kernel
diff --git a/tinygrad/jit.py b/tinygrad/jit.py
index b1a7e02c..118ef7db 100644
--- a/tinygrad/jit.py
+++ b/tinygrad/jit.py
@@ -1,5 +1,6 @@
 from typing import Callable, List, Tuple, Any, Dict, cast
 import itertools
+from tinygrad.helpers import DEBUG
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import GlobalCounters, DeviceBuffer
@@ -26,6 +27,7 @@ class TinyJit:
       self.jit_cache = GlobalCounters.cache
       GlobalCounters.cache = None
       assert len(self.jit_cache) != 0, "didn't JIT anything!"
+      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_tensors)} inputs")
 
       # get the inputs for replacement
       for prg, args in self.jit_cache:  # pylint: disable=E1133
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 50c8befb..bb1a9d88 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -120,8 +120,9 @@ class CompiledBuffer(DeviceBuffer):  # pylint: disable=abstract-method
     assert backing is None or GlobalCounters.cache is None, f"can't copy in {backing.shape} while caching"
     return cls.raw_buffer_type(4*prod(shape)) if backing is None else cls.raw_buffer_type.fromCPU(backing)
   def raw(self) -> RawBuffer:
-    if self._buf is None: self._buf = self.create_raw_buffer(self._base_shape, self._backing)
-    self._backing = None
+    if self._buf is None:
+      self._buf = self.create_raw_buffer(self._base_shape, self._backing)
+      self._backing = None
     return self._buf
 
   @classmethod
diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py
index 7d32b3ba..c33a6fe5 100644
--- a/tinygrad/runtime/ops_clang.py
+++ b/tinygrad/runtime/ops_clang.py
@@ -6,7 +6,7 @@ import subprocess
 from collections import defaultdict
 from typing import Final, Dict
 from tinygrad.ops import CompiledBuffer, RawBufferCopyIn
-from tinygrad.codegen.gpu import GPUCodegen
+from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
 import platform
 OSX = platform.system() == "Darwin"
 
@@ -31,7 +31,10 @@ class ClangProgram:
     self.fxn(*[x._buf for x in args[2:]])
     if wait: return time.monotonic()-st
 
+class ClangCodegen(GPUCodegen):
+  lang = GPULanguage(buffer_suffix="restrict")
+
 class ClangBuffer(CompiledBuffer):
   raw_buffer_type = RawMallocBuffer
-  codegen_type = GPUCodegen  # clang is the default
+  codegen_type = ClangCodegen
   runtime_type = ClangProgram
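
For context, a minimal sketch (not part of the diff) of what the new buffer_suffix hook does: GPUCodegen now appends self.lang.buffer_suffix after each non-image buffer's decltype(), so a ClangCodegen configured with buffer_suffix="restrict" should emit C parameter lists like "float*restrict data0" (which C parses the same as "float *restrict data0") rather than "float* data0". The snippet below imitates only that string construction and assumes decltype() yields plain C type strings such as "float*"; it is not the real codegen path.

# Illustration only: mimics how GPUCodegen builds the kernel parameter list,
# assuming each buffer's decltype() is a plain C type string like "float*".
from typing import List, NamedTuple

class Lang(NamedTuple):  # stand-in for GPULanguage
  buffer_prefix: str = ""
  buffer_suffix: str = ""

def param_list(decltypes: List[str], lang: Lang) -> str:
  # prefix + decltype + suffix, then "data{i}" names, joined with commas
  buftypes = [lang.buffer_prefix + d + lang.buffer_suffix for d in decltypes]
  return ', '.join(f"{t} data{i}" for i, t in enumerate(buftypes))

print(param_list(["float*", "float*"], Lang()))                          # float* data0, float* data1
print(param_list(["float*", "float*"], Lang(buffer_suffix="restrict")))  # float*restrict data0, float*restrict data1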