put restrict back

George Hotz 2023-03-01 21:34:45 -08:00
parent 201d9a2d58
commit d062cc82b8
6 changed files with 15 additions and 8 deletions

View File

@@ -4,7 +4,7 @@ import struct
 import json
 import traceback
 import numpy as np
-from tinygrad.runtime.ops_gpu import CLProgram, CLImage, CLBuffer
+from tinygrad.runtime.ops_gpu import CLProgram
 from tinygrad.helpers import prod, getenv
 from collections import defaultdict
 import pyopencl as cl

View File

@@ -37,7 +37,7 @@ def helper_test_speed(f1, *args):
     args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args] # cache defeats
     # force syncing
-    [x.numpy() if isinstance(x, Tensor) else x.cpu().numpy() for x in args if x is not None]
+    [x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None]
     GlobalCounters.global_ops = 0
     GlobalCounters.global_mem = 0
@@ -45,7 +45,7 @@ def helper_test_speed(f1, *args):
     st = time.monotonic()
     ret = f1(*args)
     # not ideal, it's copying (sometimes). why is this so slow in tinygrad?
-    if isinstance(ret, Tensor): ret.numpy()
+    if isinstance(ret, Tensor) or str(torch_device) == "cpu": ret.numpy()
     else: ret.cpu().numpy()
     et = (time.monotonic() - st) * 1000
     ets.append(et)
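
The edit above lets the benchmark treat `torch_device == "cpu"` like tinygrad tensors: call `.numpy()` directly as the sync point instead of going through `.cpu()`. The underlying pattern is a sync-then-time loop: materializing a result to numpy blocks until the backend has finished, so the wall-clock measurement brackets real work. A minimal sketch of that pattern (illustrative names, not the actual test helper):

import time
from tinygrad.tensor import Tensor

def timeit_ms(fn, *args, reps=8):
  ets = []
  for _ in range(reps):
    st = time.monotonic()
    ret = fn(*args)
    if isinstance(ret, Tensor): ret.numpy()  # .numpy() blocks until the device finishes
    ets.append((time.monotonic() - st) * 1000)
  return min(ets)

# example: time a matmul on whatever backend tinygrad selected
a, b = Tensor.randn(256, 256).realize(), Tensor.randn(256, 256).realize()
print(f"{timeit_ms(lambda x, y: (x @ y).realize(), a, b):.2f} ms")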

View File

@@ -17,6 +17,7 @@ NATIVE_EXPLOG = getenv("NATIVE_EXPLOG", 0) # this is needed as a switch for the
 class GPULanguage(NamedTuple):
   kernel_prefix : str = ""
   buffer_prefix : str = ""
+  buffer_suffix : str = ""
   smem_prefix : str = ""
   barrier : str = ""
   gid : List[str] = []
@@ -321,7 +322,7 @@ class GPUCodegen(ASTKernel):
     if GPUCodegen.kernel_cnt[function_name]:
       function_name = f"{function_name}{'_N'+str(GPUCodegen.kernel_cnt[function_name])}"
-    buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype() for i,x in enumerate(self.bufs)]
+    buftypes = [f"{'read_only' if i > 0 else 'write_only'} image2d_t" if hasattr(x._buf, "IMAGE") else self.lang.buffer_prefix+self.buftokens[i].decltype()+self.lang.buffer_suffix for i,x in enumerate(self.bufs)]
     self.kernel = list(self.prekernel) + [f"{self.lang.kernel_prefix} void {function_name}(",] + \
       [', '.join([f'{t} data{i}' for i,t in enumerate(buftypes) if i not in self.bufs_to_delete] + self.lang.extra_args)] + \
       [") {\n"] + self.kernel

View File

@@ -1,5 +1,6 @@
 from typing import Callable, List, Tuple, Any, Dict, cast
 import itertools
+from tinygrad.helpers import DEBUG
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import GlobalCounters, DeviceBuffer
@@ -26,6 +27,7 @@ class TinyJit:
       self.jit_cache = GlobalCounters.cache
       GlobalCounters.cache = None
       assert len(self.jit_cache) != 0, "didn't JIT anything!"
+      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_tensors)} inputs")
       # get the inputs for replacement
       for prg, args in self.jit_cache: # pylint: disable=E1133
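
The new `DEBUG >= 1` print makes it easy to confirm that the JIT actually captured kernels. A hypothetical usage sketch, assuming a compiled backend (e.g. `GPU=1` or `CLANG=1`) and run with `DEBUG=1` in the environment:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x.relu() * 2).sum().realize()

x = Tensor.randn(8, 8).realize()
for _ in range(3):
  step(x)  # the second call is captured; with DEBUG=1 it prints "JIT captured ... kernels with ... inputs"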

View File

@@ -120,8 +120,9 @@ class CompiledBuffer(DeviceBuffer): # pylint: disable=abstract-method
     assert backing is None or GlobalCounters.cache is None, f"can't copy in {backing.shape} while caching"
     return cls.raw_buffer_type(4*prod(shape)) if backing is None else cls.raw_buffer_type.fromCPU(backing)
   def raw(self) -> RawBuffer:
-    if self._buf is None: self._buf = self.create_raw_buffer(self._base_shape, self._backing)
-    self._backing = None
+    if self._buf is None:
+      self._buf = self.create_raw_buffer(self._base_shape, self._backing)
+      self._backing = None
     return self._buf
   @classmethod
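
The reshaped `raw()` now drops the host-side `_backing` only on the branch that actually consumes it to create the device buffer, rather than unconditionally. A stripped-down sketch of that lazy-allocation pattern (illustrative class, with a numpy array standing in for the real RawBuffer):

import numpy as np

class LazyBacked:
  def __init__(self, shape, backing=None):
    self._base_shape = shape
    self._buf = None          # device buffer, allocated on first use
    self._backing = backing   # optional host data to upload at allocation time

  def raw(self):
    if self._buf is None:
      # create the buffer (copying in the backing if present), then release the host copy
      self._buf = np.array(self._backing) if self._backing is not None else np.zeros(self._base_shape, dtype=np.float32)
      self._backing = None
    return self._buf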

View File

@@ -6,7 +6,7 @@ import subprocess
 from collections import defaultdict
 from typing import Final, Dict
 from tinygrad.ops import CompiledBuffer, RawBufferCopyIn
-from tinygrad.codegen.gpu import GPUCodegen
+from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
 import platform
 OSX = platform.system() == "Darwin"
@@ -31,7 +31,10 @@ class ClangProgram:
     self.fxn(*[x._buf for x in args[2:]])
     if wait: return time.monotonic()-st
+class ClangCodegen(GPUCodegen):
+  lang = GPULanguage(buffer_suffix="restrict")
 class ClangBuffer(CompiledBuffer):
   raw_buffer_type = RawMallocBuffer
-  codegen_type = GPUCodegen # clang is the default
+  codegen_type = ClangCodegen
   runtime_type = ClangProgram
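
This is the point of the commit: the clang backend gets its own `GPULanguage` so buffer arguments in the generated C carry the `restrict` qualifier, a promise to the compiler that the pointers never alias, which lets it optimize more aggressively. A hypothetical way to eyeball the effect, assuming tinygrad's usual environment switches (`CLANG=1` to select the backend, `DEBUG=4` to dump generated source):

# CLANG=1 DEBUG=4 python inspect_restrict.py
from tinygrad.tensor import Tensor

a, b = Tensor.randn(64), Tensor.randn(64)
(a + b).realize()
# with ClangCodegen above, the dumped kernel signature should now read roughly like
#   void E_64(float* restrict data0, float* restrict data1, float* restrict data2) { ... }
# rather than plain float* arguments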