openpilot compile2 (#1977)

* start compile2

* tweak

* why are there two more kernels?

* minor cleanups

* don't break onnx tests

* add __metadata__ support to safetensors

* no early realize in onnx

* cleanups

* bugfix

* clean up image type, add optimize

* opt to match old

* try that

* opt work

* run compile2

* optimizer

* prt more

* prerealize

* imp

* NOLOCALS works

* no locals means no locals

* support fractional globals

* all locals welcome

* int that

* cleanups

* show gemv regression

* clean up diff

* use idx for the cond

* nolocals

---------

Co-authored-by: Comma Device <device@comma.ai>
This commit is contained in:
George Hotz 2023-10-15 20:39:46 -07:00 committed by GitHub
parent 566660675c
commit 5472a14544
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 206 additions and 26 deletions

82
extra/gemm/gemv_845.py Normal file
View File

@ -0,0 +1,82 @@
# OpenCL source of the earlier kernel `re_S256_16_8`, kept as the baseline for the timing comparison below.
# Each work-item accumulates an 8-step float4 dot product of data1 and data2 into acc0 and stores it in the
# 64-entry __local array `temp`. After the barrier, only the work-item with ((idx1*4)+idx2) == 0 sums the
# 16 float4 partials, adds the float4 read from data3, clamps at 0 with max() and writes one pixel at (idx0, 0).
old = """__kernel void re_S256_16_8( write_only image2d_t data0, read_only image2d_t data1, read_only image2d_t data2, __global float* data3 ) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int idx2 = get_global_id(0); /* 4 */
int idx1 = get_global_id(1); /* 16 */
int idx0 = get_global_id(2); /* 256 */
float acc0 = 0.0f;
for (int idx3 = 0; idx3 < 8; idx3++) {
float4 val1_0 = read_imagef(data1, smp, (int2)(((idx1*8)+idx3), 0)) /* (1, 128, 4) */;
float4 val2_0 = read_imagef(data2, smp, (int2)(((idx1*32)+(idx3*4)+idx2), idx0)) /* (256, 512, 4) */;
acc0+=(val1_0.x*val2_0.x);
acc0+=(val1_0.y*val2_0.y);
acc0+=(val1_0.z*val2_0.z);
acc0+=(val1_0.w*val2_0.w);
}
__local float temp[64];
temp[((idx1*4)+idx2)] = acc0;
barrier(CLK_LOCAL_MEM_FENCE);
if (((idx1*4)+idx2) == 0) {
float4 output0 = (float4)(0.0f,0.0f,0.0f,0.0f);
for (int mid = 0; mid < 16; mid++) {
float4 val5_0 = ((__local float4*)temp)[mid];
output0.x+=val5_0.x;
output0.y+=val5_0.y;
output0.z+=val5_0.z;
output0.w+=val5_0.w;
}
float4 val3_0 = ((__global float4*)data3)[idx0];
write_imagef(data0, (int2)(idx0, 0), (float4)(max((output0.x+val3_0.x),(0.0f)),max((output0.y+val3_0.y),(0.0f)),max((output0.z+val3_0.z),(0.0f)),max((output0.w+val3_0.w),(0.0f)))); /* (1, 256, 4) */
}
}"""
# OpenCL source of the newer kernel `r_256_16_4_8_4`, timed against `old` below.
# The per-work-item accumulation matches `old` (8-step float4 dot product stored into the 64-entry __local
# array `temp`), but indices come from get_group_id/get_local_id and there is no `if` guard after the
# barrier: every work-item sums the 16 float4 partials, adds the float4 from data3, clamps at 0 with max()
# and writes the pixel at (gidx0, 0).
# NOTE(review): the unguarded final reduction is presumably the "gemv regression" this file demonstrates - confirm.
new = """__kernel void r_256_16_4_8_4(write_only image2d_t data0, read_only image2d_t data1, read_only image2d_t data2, const __global float* data3) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__attribute__ ((aligned (16))) __local float temp[64];
int gidx0 = get_group_id(0); /* 256 */
int lidx1 = get_local_id(1); /* 16 */
int lidx2 = get_local_id(0); /* 4 */
float acc0 = 0.0f;
for (int ridx0 = 0; ridx0 < 8; ++ridx0) {
float4 val0 = read_imagef(data1, smp, (int2)(((lidx1*8)+ridx0),0));
float4 val1 = read_imagef(data2, smp, (int2)(((lidx1*32)+lidx2+(ridx0*4)),gidx0));
acc0 = (((val0).x*(val1).x)+acc0);
acc0 = (((val0).y*(val1).y)+acc0);
acc0 = (((val0).z*(val1).z)+acc0);
acc0 = (((val0).w*(val1).w)+acc0);
}
temp[(lidx1*4)+lidx2] = acc0;
barrier(CLK_LOCAL_MEM_FENCE);
float4 acc1 = (float4)(0.0f,0.0f,0.0f,0.0f);
for (int ridx1 = 0; ridx1 < 16; ++ridx1) {
float4 val2 = (float4)(*((__local float4*)(temp+ridx1*4)));
(acc1).x = ((val2).x+(acc1).x);
(acc1).y = ((val2).y+(acc1).y);
(acc1).z = ((val2).z+(acc1).z);
(acc1).w = ((val2).w+(acc1).w);
}
float4 val3 = (float4)(*((__global float4*)(data3+gidx0*4)));
write_imagef(data0, (int2)(gidx0,0), (float4)(max(((acc1).x+(val3).x),0.0f),max(((acc1).y+(val3).y),0.0f),max(((acc1).z+(val3).z),0.0f),max(((acc1).w+(val3).w),0.0f)));
}"""
from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram
from tinygrad.helpers import dtypes, prod
if __name__ == "__main__":
    # Micro-benchmark: compile both kernel sources, launch each five times on the
    # same buffers and report the best wall time of each in microseconds.

    # Both kernels write data0 at x = 0..255 (idx0 / gidx0 spans 256), so the output
    # image has to be 256 pixels wide. It was previously allocated as (1, 128, 4),
    # which left half of the writes outside the image.
    out = CLBuffer(prod((1, 256, 4)), dtypes.imageh((1, 256, 4)))
    # data1 is read at x = idx1*8 + idx3, i.e. 0..127, on row 0.
    x = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1, 128, 4)))
    # data2 is read at x = 0..511 and y = 0..255.
    w = CLBuffer(prod((256, 512, 4)), dtypes.imageh((256, 512, 4)))
    # data3 is read as 256 float4 values, i.e. 1024 floats.
    b = CLBuffer(1024, dtypes.float)

    # The source strings are rebound to their compiled programs, as before.
    old = CLProgram("re_S256_16_8", old)
    new = CLProgram("r_256_16_4_8_4", new)

    old_tms = []
    new_tms = []
    for _ in range(5):
        # Arguments are (group counts, local size, *buffers); wait=True makes the
        # call block and return the measured time.
        old_tms.append(old([1,1,256], [4,16,1], out, x, w, b, wait=True))
        new_tms.append(new([256,1,1], [4,16,1], out, x, w, b, wait=True))
    print(f"old: {min(old_tms)*1e6:.2f} us new: {min(new_tms)*1e6:.2f} us")

View File

@ -281,12 +281,12 @@ class Thneed:
if DEBUGCL >= 2: if DEBUGCL >= 2:
for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)): for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
print(f"{i:3d} {prg.name:20s} " + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms" % tuple((x*OSX_TIMING_RATIO - st*1e9)/1e6 for x in [e.profile.queued, e.profile.submit, e.profile.start, e.profile.end])) print(f"{i:3d} {prg.name:25s} " + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms" % tuple((x*OSX_TIMING_RATIO - st*1e9)/1e6 for x in [e.profile.queued, e.profile.submit, e.profile.start, e.profile.end]))
if DEBUGCL >= 1: if DEBUGCL >= 1:
total_runtime = 0 total_runtime = 0
for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)): for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}") print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:25s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
if hasattr(prg, 'prg') and ((DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3): if hasattr(prg, 'prg') and ((DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3):
print(prg.prg) print(prg.prg)
total_runtime += runtime total_runtime += runtime

View File

@ -79,7 +79,7 @@ def compile(dat, output_fn):
global_size = prg.global_size + [1]*(3-len(prg.global_size)) global_size = prg.global_size + [1]*(3-len(prg.global_size))
local_size = prg.local_size + [1]*(3-len(prg.local_size)) local_size = prg.local_size + [1]*(3-len(prg.local_size))
cl_cache.append((prg.clprg, [[g*l for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in args]])) cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in args]]))
used_ops += prg.op_estimate used_ops += prg.op_estimate
from extra.thneed import Thneed from extra.thneed import Thneed

71
openpilot/compile2.py Normal file
View File

@ -0,0 +1,71 @@
import os
# Environment defaults for this script: a value already present in the
# environment wins, otherwise the default below is installed.
# NOTE(review): presumably placed ahead of the tinygrad imports so the settings
# are visible when those modules are imported - confirm.
os.environ.setdefault("FLOAT16", "1")
os.environ.setdefault("IMAGE", "2")
os.environ.setdefault("NOLOCALS", "1")
os.environ.setdefault("OPT", "99")
# PREREALIZE is always forced to "0", overriding any caller-supplied value.
os.environ["PREREALIZE"] = "0"

# Model used when no path/URL is given on the command line.
OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
import sys
import onnx
import io
from typing import Tuple, List
from extra.utils import fetch
from extra.onnx import get_run_onnx
from tinygrad.graph import print_tree
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, DEBUG, getenv
from tinygrad.realize import run_schedule
from tinygrad.ops import LoadOps, Device, ScheduleItem
Device.DEFAULT = "GPU"
def get_schedule(fn:str) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
    """Schedule the ONNX model at *fn* and split the schedule by input dependence.

    The model is fetched, run once on empty input tensors, and the schedule of
    its first output (cast to float32 and made contiguous) is taken.

    Returns a pair ``(dependent, independent)``: the schedule items whose output
    depends, directly or transitively, on a model input, and the remaining items.
    Asserts that the dependent part contains no LoadOps other than those that
    produce the input buffers themselves.
    """
    Tensor.no_grad = True
    Tensor.training = False

    # load the model
    model_bytes = fetch(fn)
    onnx_model = onnx.load(io.BytesIO(model_bytes))
    run_onnx = get_run_onnx(onnx_model)
    input_shapes = {}
    for graph_input in onnx_model.graph.input:
        dims = graph_input.type.tensor_type.shape.dim
        input_shapes[graph_input.name] = tuple(d.dim_value for d in dims)

    # run the model on empty tensors of the declared input shapes
    inputs = {name: Tensor.empty(*shape) for name, shape in input_shapes.items()}
    outputs = run_onnx(inputs)
    ret: Tensor = next(iter(outputs.values())).cast(dtypes.float32).contiguous()
    schedule = ret.lazydata.schedule()

    # grow the set of buffers that depend on the inputs, walking the schedule in order
    input_lb = [tensor.lazydata.base for tensor in inputs.values()]
    depends = set(input_lb)
    for item in schedule:
        if any(buf in depends for buf in item.inputs):
            depends.add(item.out)

    # split off the kernels that don't depend on the inputs
    # NOTE: there's two extra kernels due to fusions that now happen since the weights aren't realized
    schedule, schedule_independent = partition(schedule, lambda si: si.out in depends)
    print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")

    # the input-dependent schedule may only contain loadops that load the input buffers
    assert not any(si.ast.op in LoadOps and si.out not in input_lb for si in schedule), "has loadops, can't compile to Thneed"
    return schedule, schedule_independent
def lb_to_numbers(schedule):
    """Replace the buffers in *schedule* with small integers in first-seen order.

    *schedule* is an iterable of ``(op, out, buffers)`` triples. Every distinct
    (hashable) buffer gets the next free integer the first time it is seen,
    with ``out`` visited before the entries of ``buffers``.

    ``buffers`` may be any iterable; previously only a tuple worked because it
    was concatenated onto ``(out,)``.

    Returns a list of ``(op, out_number, tuple_of_buffer_numbers)``.
    """
    numbered = []
    ids = {}
    for op, out, buffers in schedule:
        # Materialise once so lists and one-shot iterators can be walked twice.
        buffers = tuple(buffers)
        for lb in (out, *buffers):
            # len(ids) is evaluated before the insert, so it is the next free number.
            ids.setdefault(lb, len(ids))
        numbered.append((op, ids[out], tuple(ids[buf] for buf in buffers)))
    return numbered
if __name__ == "__main__":
    # The model location comes from the first CLI argument, falling back to the default.
    if len(sys.argv) > 1:
        schedule, schedule_independent = get_schedule(sys.argv[1])
    else:
        schedule, schedule_independent = get_schedule(OPENPILOT_MODEL)

    # Run everything that does not depend on the model inputs first.
    run_schedule(schedule_independent)

    print("**** running real kernels ****")
    # Run the input-dependent kernels with DEBUG=2 and freshly reset counters.
    with Context(DEBUG=2):
        GlobalCounters.reset()
        run_schedule(schedule)

View File

@ -1,2 +1,2 @@
#!/bin/bash #!/bin/bash
FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 openpilot/compile.py NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 openpilot/compile.py

View File

@ -79,6 +79,7 @@ class Kernel:
self.use_tensor_cores: bool = False self.use_tensor_cores: bool = False
self.exclude_local_upcast: int = 0 self.exclude_local_upcast: int = 0
self.reverse_upcast_dir: bool = False self.reverse_upcast_dir: bool = False
self.dont_use_locals: bool = False
self.global_size: Optional[List[int]] = None self.global_size: Optional[List[int]] = None
self.local_size: Optional[List[int]] = None self.local_size: Optional[List[int]] = None

View File

@ -206,7 +206,10 @@ class Linearizer(OptimizedKernel):
loop_uop = self.loop_uops[x.expr] loop_uop = self.loop_uops[x.expr]
if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,)) if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))
if self.opts.has_local: if self.dont_use_locals:
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
elif self.opts.has_local:
self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1] self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
self.global_size += [1]*(3-len(self.global_size)) self.global_size += [1]*(3-len(self.global_size))
self.local_size += [1]*(3-len(self.local_size)) self.local_size += [1]*(3-len(self.local_size))
@ -317,7 +320,9 @@ class Linearizer(OptimizedKernel):
self.uop(UOps.BARRIER, None, (), cachable=False) self.uop(UOps.BARRIER, None, (), cachable=False)
end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if? end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if?
if self.opts.has_local: if self.opts.has_local:
if_cond: UOp = Variable.ands([x<1 for x in local_idxs[self.local_dims:]]).render(self.render_ops, self) fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False) if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False)
# create new late reduce local loops and replace local_idxs that have been used # create new late reduce local loops and replace local_idxs that have been used

View File

@ -426,6 +426,9 @@ class OptimizedKernel(Kernel):
# **** local groups **** # **** local groups ****
if self.opts.has_local: if self.opts.has_local:
if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce:
self.dont_use_locals = True
else:
# prioritize making expand axes local # prioritize making expand axes local
local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
to_local: List[Tuple[int, int]] = [] to_local: List[Tuple[int, int]] = []

View File

@ -2,7 +2,7 @@ from typing import Dict, List, cast, DefaultDict, Optional
from copy import deepcopy from copy import deepcopy
from tinygrad.lazy import vars_from_ast from tinygrad.lazy import vars_from_ast
from tinygrad.ops import Device, Compiled, MemBuffer from tinygrad.ops import Device, Compiled, MemBuffer
from tinygrad.helpers import prod, getenv, flatten from tinygrad.helpers import prod, getenv, ImageDType, flatten
from tinygrad.codegen.linearizer import Linearizer from tinygrad.codegen.linearizer import Linearizer
from tinygrad.runtime.lib import RawBuffer from tinygrad.runtime.lib import RawBuffer
from collections import defaultdict from collections import defaultdict
@ -57,7 +57,7 @@ def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
for x in lin.membufs: bufsts[x.idx].append(x) for x in lin.membufs: bufsts[x.idx].append(x)
rawbufs:List[Optional[RawBuffer]] = [None]*len(bufsts) rawbufs:List[Optional[RawBuffer]] = [None]*len(bufsts)
for k,lx in bufsts.items(): for k,lx in bufsts.items():
rawbufs[k] = device.buffer(max(y.st.size() for y in lx), lx[0].dtype) rawbufs[k] = device.buffer(prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
assert all(r is not None for r in rawbufs) assert all(r is not None for r in rawbufs)
return cast(List[RawBuffer], rawbufs) return cast(List[RawBuffer], rawbufs)

View File

@ -60,7 +60,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
# contiguous creates the image, and early realize static weights (TODO: test for the static weight) # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
if IMAGE >= 2: x,w = x.cast(base_image_type(x.shape)), w.cast(base_image_type(w.shape)) if IMAGE >= 2: x,w = x.cast(base_image_type(x.shape)), w.cast(base_image_type(w.shape))
x, w = x.contiguous(), w.contiguous() x, w = x.contiguous(), w.contiguous()
if get_single_root(w.lazydata).realized: w.realize() if getenv("PREREALIZE", 1) and get_single_root(w.lazydata).realized: w.realize()
# expand out # expand out
rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1 rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
import time, importlib, inspect, functools, pathlib import time, importlib, inspect, functools, pathlib, itertools, random
import numpy as np import numpy as np
from enum import Enum, auto from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
@ -184,6 +184,18 @@ class ASTRunner:
op_estimate=k.info.flops, mem_estimate=k.mem_estimate, op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
display_name=k.display_name, runtime_args={"binary": False}) display_name=k.display_name, runtime_args={"binary": False})
def optimize_local_size(self, global_size, rawbufs) -> List[int]:
    """Pick a local size for *global_size* by timing candidate launches.

    Candidates per dimension are the dimension size itself, powers of two up
    to 256 and the maximum work-group size, limited to values no larger than
    the dimension. Combinations whose product exceeds the maximum work-group
    size are dropped. Each remaining combination is launched twice, in random
    order, on *rawbufs*; a launch that raises counts as infinitely slow.

    Returns the local size of the smallest (time, local_size) pair.
    """
    assert self.global_size is not None, "needs a global size to optimize local size"
    if hasattr(self.clprg, 'max_work_group_size'):
        max_workgroup = self.clprg.max_work_group_size()
    else:
        max_workgroup = 1024

    per_dim_options = []
    for sz in global_size:
        options = set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, max_workgroup])
        per_dim_options.append([opt for opt in options if opt <= sz])

    valid_sizes = [list(combo) for combo in itertools.product(*per_dim_options) if prod(combo) <= max_workgroup]
    local_sizes = valid_sizes * 2  # try each valid size twice

    def timed_launch(local_size):
        try:
            # group count per dimension; stays fractional when the local size does not divide evenly
            groups = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
            return self.clprg(groups, local_size, *rawbufs, wait=True)
        except Exception:
            return float('inf')

    shuffled = random.sample(local_sizes, len(local_sizes))
    timings = [(timed_launch(local_size), local_size) for local_size in shuffled]
    return min(timings)[1]
def build(self, runtime, batch_exec=BasicBatchExecutor): def build(self, runtime, batch_exec=BasicBatchExecutor):
self.clprg, self.batch_exec = runtime(self.name, self.prg, **self.runtime_args), batch_exec self.clprg, self.batch_exec = runtime(self.name, self.prg, **self.runtime_args), batch_exec
return self return self
@ -201,6 +213,9 @@ class ASTRunner:
def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]: def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
if var_vals is None: var_vals = {} if var_vals is None: var_vals = {}
global_size, local_size = self.launch_dims(var_vals) global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None:
local_size = self.local_size = self.optimize_local_size(global_size, rawbufs)
global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
if et := self.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et if et := self.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
op_estimate = sym_infer(self.op_estimate, var_vals) op_estimate = sym_infer(self.op_estimate, var_vals)
if DEBUG >= 2: if DEBUG >= 2:

View File

@ -16,6 +16,7 @@ class CStyleLanguage(NamedTuple):
smem_prefix_for_cast: bool = True smem_prefix_for_cast: bool = True
arg_int_prefix: str = "" arg_int_prefix: str = ""
barrier: str = "" barrier: str = ""
xid: List[str] = []
gid: List[str] = [] gid: List[str] = []
lid: List[str] = [] lid: List[str] = []
global_max: List[int] = [] global_max: List[int] = []
@ -173,7 +174,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
assert dtype is not None assert dtype is not None
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};") kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
elif uop == UOps.SPECIAL: elif uop == UOps.SPECIAL:
xid = lang.gid if args[1].startswith("g") else lang.lid xid = lang.gid if args[1].startswith("g") else (lang.xid if args[1].startswith("i") else lang.lid)
kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */") kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
if args[1].startswith("l"): local_size.append(args[2]) if args[1].startswith("l"): local_size.append(args[2])
r[u] = args[1] r[u] = args[1]

View File

@ -14,6 +14,7 @@ class OpenCLLanguage(CStyleLanguage):
float4 = "(float4)" float4 = "(float4)"
gid = [f'get_group_id({i})' for i in range(3)] gid = [f'get_group_id({i})' for i in range(3)]
lid = [f'get_local_id({i})' for i in range(3)] lid = [f'get_local_id({i})' for i in range(3)]
xid = [f'get_global_id({i})' for i in range(3)]
uses_vload=True uses_vload=True
OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage()) OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())

View File

@ -89,6 +89,7 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", smem_prefix_for_cast=False, arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4", kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", smem_prefix_for_cast=False, arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)], gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)], lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)],
half_prekernel = """ half_prekernel = """
#include <cuda_fp16.h> #include <cuda_fp16.h>
struct __align__(8) half4 { struct __align__(8) half4 {

View File

@ -94,7 +94,7 @@ class CLProgram:
cl_bufs.append(x._buf) cl_bufs.append(x._buf)
if hasattr(x, "event"): wait_for.append(x.event) if hasattr(x, "event"): wait_for.append(x.event)
else: cl_bufs.append(x) else: cl_bufs.append(x)
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for) e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [int(g*l) for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for)
if wait: if wait:
e.wait() e.wait()
try: try:

View File

@ -327,7 +327,7 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}" def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx) def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int: def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
if isinstance(a, int): return a if isinstance(a, (int, float)): return a
ret = a.substitute({k:Variable.num(v) for k, v in var_vals.items()}) ret = a.substitute({k:Variable.num(v) for k, v in var_vals.items()})
assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}" assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
return ret.b return ret.b