mirror of https://github.com/commaai/tinygrad.git
put restrict back
This commit is contained in:
parent
201d9a2d58
commit
d062cc82b8
|
@@ -4,7 +4,7 @@ import struct
|
|||
import json
|
||||
import traceback
|
||||
import numpy as np
|
||||
from tinygrad.runtime.ops_gpu import CLProgram, CLImage, CLBuffer
|
||||
from tinygrad.runtime.ops_gpu import CLProgram
|
||||
from tinygrad.helpers import prod, getenv
|
||||
from collections import defaultdict
|
||||
import pyopencl as cl
|
||||
|
|
|
@@ -37,7 +37,7 @@ def helper_test_speed(f1, *args):
|
|||
args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args] # cache defeats
|
||||
|
||||
# force syncing
|
||||
[x.numpy() if isinstance(x, Tensor) else x.cpu().numpy() for x in args if x is not None]
|
||||
[x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None]
|
||||
|
||||
GlobalCounters.global_ops = 0
|
||||
GlobalCounters.global_mem = 0
|
||||
|
@@ -45,7 +45,7 @@ def helper_test_speed(f1, *args):
|
|||
st = time.monotonic()
|
||||
ret = f1(*args)
|
||||
# not ideal, it's copying (sometimes). why is this so slow in tinygrad?
|
||||
if isinstance(ret, Tensor): ret.numpy()
|
||||
if isinstance(ret, Tensor) or str(torch_device) == "cpu": ret.numpy()
|
||||
else: ret.cpu().numpy()
|
||||
et = (time.monotonic() - st) * 1000
|
||||
ets.append(et)
|
||||
|
|
|
@@ -17,6 +17,7 @@ NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the
|
|||
class GPULanguage(NamedTuple):
|
||||
kernel_prefix : str = ""
|
||||
buffer_prefix : str = ""
|
||||
buffer_suffix : str = ""
|
||||
smem_prefix : str = ""
|
||||
barrier : str = ""
|
||||
gid : List[str] = []
|
||||
|
@@ -321,7 +322,7 @@ class GPUCodegen(ASTKernel):
|
|||
if GPUCodegen.kernel_cnt[function_name]:
|
||||
function_name = f"{function_name}{'_N'+str(GPUCodegen.kernel_cnt[function_name])}"
|
||||
|
||||
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype() for i,x in enumerate(self.bufs)]
|
||||
buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype()+self.lang.buffer_suffix for i,x in enumerate(self.bufs)]
|
||||
self.kernel = list(self.prekernel) + [f"{self.lang.kernel_prefix} void {function_name}(",] + \
|
||||
[', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] + \
|
||||
[") {\n"] + self.kernel
|
||||
|
|
|
@@ -1,5 +1,6 @@
|
|||
from typing import Callable, List, Tuple, Any, Dict, cast
|
||||
import itertools
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.lazy import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import GlobalCounters, DeviceBuffer
|
||||
|
@@ -26,6 +27,7 @@ class TinyJit:
|
|||
self.jit_cache = GlobalCounters.cache
|
||||
GlobalCounters.cache = None
|
||||
assert len(self.jit_cache) != 0, "didn't JIT anything!"
|
||||
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_tensors)} inputs")
|
||||
|
||||
# get the inputs for replacement
|
||||
for prg, args in self.jit_cache: # pylint: disable=E1133
|
||||
|
|
|
@@ -120,8 +120,9 @@ class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
|
|||
assert backing is None or GlobalCounters.cache is None, f"can't copy in {backing.shape} while caching"
|
||||
return cls.raw_buffer_type(4*prod(shape)) if backing is None else cls.raw_buffer_type.fromCPU(backing)
|
||||
def raw(self) -> RawBuffer:
|
||||
if self._buf is None: self._buf = self.create_raw_buffer(self._base_shape, self._backing)
|
||||
self._backing = None
|
||||
if self._buf is None:
|
||||
self._buf = self.create_raw_buffer(self._base_shape, self._backing)
|
||||
self._backing = None
|
||||
return self._buf
|
||||
|
||||
@classmethod
|
||||
|
|
|
@@ -6,7 +6,7 @@ import subprocess
|
|||
from collections import defaultdict
|
||||
from typing import Final, Dict
|
||||
from tinygrad.ops import CompiledBuffer, RawBufferCopyIn
|
||||
from tinygrad.codegen.gpu import GPUCodegen
|
||||
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
|
||||
import platform
|
||||
OSX = platform.system() == "Darwin"
|
||||
|
||||
|
@@ -31,7 +31,10 @@ class ClangProgram:
|
|||
self.fxn(*[x._buf for x in args[2:]])
|
||||
if wait: return time.monotonic()-st
|
||||
|
||||
class ClangCodegen(GPUCodegen):
|
||||
lang = GPULanguage(buffer_suffix="restrict")
|
||||
|
||||
class ClangBuffer(CompiledBuffer):
|
||||
raw_buffer_type = RawMallocBuffer
|
||||
codegen_type = GPUCodegen # clang is the default
|
||||
codegen_type = ClangCodegen
|
||||
runtime_type = ClangProgram
|
||||
|
|
Loading…
Reference in New Issue