mirror of https://github.com/commaai/tinygrad.git
openpilot compile2 (#1977)
* start compile2
* tweak
* why are there two more kernels?
* minor cleanups
* don't break onnx tests
* add __metadata__ support to safetensors
* no early realize in onnx
* cleanups
* bugfix
* clean up image type, add optimize
* opt to match old
* try that
* opt work
* run compile2
* optimizer
* prt more
* prerealize
* imp
* NOLOCALS works
* no locals means no locals
* support fractional globals
* all locals welcome
* int that
* cleanups
* show gemv regression
* clean up diff
* use idx for the cond
* nolocals

---------

Co-authored-by: Comma Device <device@comma.ai>
This commit is contained in:
parent 566660675c
commit 5472a14544
@@ -0,0 +1,82 @@
old = """__kernel void re_S256_16_8( write_only image2d_t data0, read_only image2d_t data1, read_only image2d_t data2, __global float* data3 ) {
  const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
  int idx2 = get_global_id(0); /* 4 */
  int idx1 = get_global_id(1); /* 16 */
  int idx0 = get_global_id(2); /* 256 */
  float acc0 = 0.0f;
  for (int idx3 = 0; idx3 < 8; idx3++) {
    float4 val1_0 = read_imagef(data1, smp, (int2)(((idx1*8)+idx3), 0)) /* (1, 128, 4) */;
    float4 val2_0 = read_imagef(data2, smp, (int2)(((idx1*32)+(idx3*4)+idx2), idx0)) /* (256, 512, 4) */;
    acc0+=(val1_0.x*val2_0.x);
    acc0+=(val1_0.y*val2_0.y);
    acc0+=(val1_0.z*val2_0.z);
    acc0+=(val1_0.w*val2_0.w);
  }
  __local float temp[64];
  temp[((idx1*4)+idx2)] = acc0;
  barrier(CLK_LOCAL_MEM_FENCE);
  if (((idx1*4)+idx2) == 0) {
    float4 output0 = (float4)(0.0f,0.0f,0.0f,0.0f);
    for (int mid = 0; mid < 16; mid++) {
      float4 val5_0 = ((__local float4*)temp)[mid];
      output0.x+=val5_0.x;
      output0.y+=val5_0.y;
      output0.z+=val5_0.z;
      output0.w+=val5_0.w;
    }
    float4 val3_0 = ((__global float4*)data3)[idx0];
    write_imagef(data0, (int2)(idx0, 0), (float4)(max((output0.x+val3_0.x),(0.0f)),max((output0.y+val3_0.y),(0.0f)),max((output0.z+val3_0.z),(0.0f)),max((output0.w+val3_0.w),(0.0f)))); /* (1, 256, 4) */
  }
}"""

new = """__kernel void r_256_16_4_8_4(write_only image2d_t data0, read_only image2d_t data1, read_only image2d_t data2, const __global float* data3) {
  const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
  __attribute__ ((aligned (16))) __local float temp[64];
  int gidx0 = get_group_id(0); /* 256 */
  int lidx1 = get_local_id(1); /* 16 */
  int lidx2 = get_local_id(0); /* 4 */
  float acc0 = 0.0f;
  for (int ridx0 = 0; ridx0 < 8; ++ridx0) {
    float4 val0 = read_imagef(data1, smp, (int2)(((lidx1*8)+ridx0),0));
    float4 val1 = read_imagef(data2, smp, (int2)(((lidx1*32)+lidx2+(ridx0*4)),gidx0));
    acc0 = (((val0).x*(val1).x)+acc0);
    acc0 = (((val0).y*(val1).y)+acc0);
    acc0 = (((val0).z*(val1).z)+acc0);
    acc0 = (((val0).w*(val1).w)+acc0);
  }
  temp[(lidx1*4)+lidx2] = acc0;
  barrier(CLK_LOCAL_MEM_FENCE);
  float4 acc1 = (float4)(0.0f,0.0f,0.0f,0.0f);
  for (int ridx1 = 0; ridx1 < 16; ++ridx1) {
    float4 val2 = (float4)(*((__local float4*)(temp+ridx1*4)));
    (acc1).x = ((val2).x+(acc1).x);
    (acc1).y = ((val2).y+(acc1).y);
    (acc1).z = ((val2).z+(acc1).z);
    (acc1).w = ((val2).w+(acc1).w);
  }
  float4 val3 = (float4)(*((__global float4*)(data3+gidx0*4)));
  write_imagef(data0, (int2)(gidx0,0), (float4)(max(((acc1).x+(val3).x),0.0f),max(((acc1).y+(val3).y),0.0f),max(((acc1).z+(val3).z),0.0f),max(((acc1).w+(val3).w),0.0f)));
}"""

from tinygrad.runtime.ops_gpu import CLBuffer, CLProgram
from tinygrad.helpers import dtypes, prod

if __name__ == "__main__":
  out = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1,128,4)))
  x = CLBuffer(prod((1, 128, 4)), dtypes.imageh((1,128,4)))
  w = CLBuffer(prod((256, 512, 4)), dtypes.imageh((256, 512, 4)))
  b = CLBuffer(1024, dtypes.float)

  old = CLProgram("re_S256_16_8", old)
  new = CLProgram("r_256_16_4_8_4", new)

  old_tms = []
  new_tms = []

  for i in range(5):
    old_tms.append(old([1,1,256], [4,16,1], out, x, w, b, wait=True))
    new_tms.append(new([256,1,1], [4,16,1], out, x, w, b, wait=True))

  print(f"old: {min(old_tms)*1e6:.2f} us new: {min(new_tms)*1e6:.2f} us")
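Both kernels do the same GEMV (this is the "show gemv regression" script from the commit message); the old one derives everything from get_global_id, while the NOLOCALS-style new one takes explicit group and local ids. A worked sketch of the launch math, assuming CLProgram multiplies the grid by the local size when a local size is passed (as in the ops_gpu hunk further down):

# a worked sketch of the launch math; assumes CLProgram multiplies the grid by
# the local size when a local size is given (see the ops_gpu hunk below)
from math import prod
grid_old, local_old = [1, 1, 256], [4, 16, 1]
grid_new, local_new = [256, 1, 1], [4, 16, 1]
ndrange_old = [g*l for g, l in zip(grid_old, local_old)]  # [4, 16, 256]
ndrange_new = [g*l for g, l in zip(grid_new, local_new)]  # [1024, 16, 1]
assert prod(ndrange_old) == prod(ndrange_new) == 16384    # same total work-items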
@@ -281,12 +281,12 @@ class Thneed:
    if DEBUGCL >= 2:
      for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
-       print(f"{i:3d} {prg.name:20s} " + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms" % tuple((x*OSX_TIMING_RATIO - st*1e9)/1e6 for x in [e.profile.queued, e.profile.submit, e.profile.start, e.profile.end]))
+       print(f"{i:3d} {prg.name:25s} " + "queued @ %5.2f ms, submit @ %5.2fms, start @ %5.2f ms, end @ %5.2f ms" % tuple((x*OSX_TIMING_RATIO - st*1e9)/1e6 for x in [e.profile.queued, e.profile.submit, e.profile.start, e.profile.end]))
    if DEBUGCL >= 1:
      total_runtime = 0
      for i, ((prg, args), e) in enumerate(zip(self.cl_cache, events)):
        runtime = (e.profile.end - e.profile.start) * OSX_TIMING_RATIO
-       print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
+       print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:25s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {(getattr(prg, 'op_estimate', float('nan')))/runtime:9.2f} GFLOPS -> {args[2].shape if hasattr(args[2], 'shape') else args[2].size}")
        if hasattr(prg, 'prg') and ((DEBUGCL >= 2 and getenv("PRINT_KERNEL", -1) == i) or DEBUGCL >= 3):
          print(prg.prg)
        total_runtime += runtime
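Only the name field width changes here (:20s to :25s, for the longer generated kernel names), but the timestamp math is worth unpacking: OpenCL profile fields are nanoseconds and st is seconds, so the expression yields milliseconds since the host start. A toy of the conversion, assuming OSX_TIMING_RATIO is 1 (the non-OSX case):

# toy of the profile-timestamp conversion above; OSX_TIMING_RATIO assumed 1
OSX_TIMING_RATIO = 1
st = 2.0                       # host start time in seconds
queued = 2.0e9 + 3_500_000     # device timestamp in nanoseconds
print("queued @ %5.2f ms" % ((queued*OSX_TIMING_RATIO - st*1e9)/1e6))  # queued @  3.50 ms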
@@ -79,7 +79,7 @@ def compile(dat, output_fn):
    global_size = prg.global_size + [1]*(3-len(prg.global_size))
    local_size = prg.local_size + [1]*(3-len(prg.local_size))
-   cl_cache.append((prg.clprg, [[g*l for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in args]]))
+   cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in args]]))
    used_ops += prg.op_estimate

  from extra.thneed import Thneed
@@ -0,0 +1,71 @@
import os
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
if "OPT" not in os.environ: os.environ["OPT"] = "99"
os.environ["PREREALIZE"] = "0"

OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"

import sys
import onnx
import io
from typing import Tuple, List
from extra.utils import fetch
from extra.onnx import get_run_onnx
from tinygrad.graph import print_tree
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, DEBUG, getenv
from tinygrad.realize import run_schedule
from tinygrad.ops import LoadOps, Device, ScheduleItem
Device.DEFAULT = "GPU"

def get_schedule(fn:str) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
  Tensor.no_grad = True
  Tensor.training = False

  # load the model
  dat = fetch(fn)
  onnx_model = onnx.load(io.BytesIO(dat))
  run_onnx = get_run_onnx(onnx_model)
  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}

  # run the model
  inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
  ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
  schedule = ret.lazydata.schedule()

  # filter out the schedule items that don't depend on the inputs
  input_lb = [x.lazydata.base for x in inputs.values()]
  depends = set(input_lb)
  for si in schedule:
    if any(b in depends for b in si.inputs):
      depends.add(si.out)

  # run all kernels that don't depend on the inputs
  # NOTE: there are two extra kernels due to fusions that now happen since the weights aren't realized
  schedule, schedule_independent = partition(schedule, lambda si: si.out in depends)
  print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")

  # confirm no loadops in the (non-independent) schedule except the ones that load the input buffers
  assert all(si.ast.op not in LoadOps or si.out in input_lb for si in schedule), "has loadops, can't compile to Thneed"
  return schedule, schedule_independent

def lb_to_numbers(schedule):
  nschedule = []
  nlb = {}
  for op,out,buffers in schedule:
    for lb in (out,)+buffers:
      if lb not in nlb:
        nlb[lb] = len(nlb)
    nschedule.append((op, nlb[out], tuple(nlb[x] for x in buffers)))
  return nschedule

if __name__ == "__main__":
  schedule, schedule_independent = get_schedule(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL)
  run_schedule(schedule_independent)

  print("**** running real kernels ****")
  with Context(DEBUG=2):
    GlobalCounters.reset()
    run_schedule(schedule)
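The split above relies on tinygrad.helpers.partition, which returns the items matching the predicate first and the rest second; the input-dependent kernels are kept for compilation while the rest are pre-run. A minimal sketch of that split on plain ints instead of ScheduleItems:

# minimal sketch of the partition call in get_schedule, using ints as stand-ins
from tinygrad.helpers import partition
depends = {2, 4}
schedule, schedule_independent = partition([1, 2, 3, 4, 5], lambda si: si in depends)
print(schedule, schedule_independent)  # [2, 4] [1, 3, 5]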
@@ -1,2 +1,2 @@
#!/bin/bash
-FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 openpilot/compile.py
+NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 openpilot/compile.py
@@ -79,6 +79,7 @@ class Kernel:
    self.use_tensor_cores: bool = False
    self.exclude_local_upcast: int = 0
    self.reverse_upcast_dir: bool = False
+   self.dont_use_locals: bool = False

    self.global_size: Optional[List[int]] = None
    self.local_size: Optional[List[int]] = None
@@ -206,7 +206,10 @@ class Linearizer(OptimizedKernel):
        loop_uop = self.loop_uops[x.expr]
        if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))

-     if self.opts.has_local:
+     if self.dont_use_locals:
+       self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
+       self.loop_uops.update({x.expr:self.uop(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)})
+     elif self.opts.has_local:
        self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
        self.global_size += [1]*(3-len(self.global_size))
        self.local_size += [1]*(3-len(self.local_size))

@@ -317,7 +320,9 @@ class Linearizer(OptimizedKernel):
        self.uop(UOps.BARRIER, None, (), cachable=False)
        end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if?
      if self.opts.has_local:
-       if_cond: UOp = Variable.ands([x<1 for x in local_idxs[self.local_dims:]]).render(self.render_ops, self)
+       fake_idxs = [Variable.num(0)]*len(self.sts[-1].shape)
+       fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
+       if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0]<1).render(self.render_ops, self)
        if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False)

      # create new late reduce local loops and replace local_idxs that have been used
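In the dont_use_locals branch, the global size is just each global loop's extent (max+1), reversed so the innermost loop lands on index 0. A worked toy with made-up loop extents:

# worked toy of the dont_use_locals global_size computation; the idx maxes
# (255, 15, 3) are made up for illustration
class Idx:
  def __init__(self, max_): self.max = max_
loop_global_idxs = [Idx(255), Idx(15), Idx(3)]
global_size = [x.max+1 for x in loop_global_idxs][::-1]
print(global_size)  # [4, 16, 256]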
@@ -426,6 +426,9 @@ class OptimizedKernel(Kernel):
    # **** local groups ****

    if self.opts.has_local:
+     if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduce:
+       self.dont_use_locals = True
+     else:
        # prioritize making expand axes local
        local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))]
        to_local: List[Tuple[int, int]] = []
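The NOLOCALS path only engages when the kernel has no local dims and no grouped reduce; otherwise the usual local-axis ranking runs. A toy of the gate, assuming tinygrad.helpers.getenv parses the env var as an int (0 when unset):

# toy of the NOLOCALS gate; getenv returns 0 for unset vars
import os
os.environ["NOLOCALS"] = "1"
from tinygrad.helpers import getenv
local_dims, group_for_reduce = 0, []
print(bool(getenv("NOLOCALS") and local_dims == 0 and not group_for_reduce))  # True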
@@ -2,7 +2,7 @@ from typing import Dict, List, cast, DefaultDict, Optional
from copy import deepcopy
from tinygrad.lazy import vars_from_ast
from tinygrad.ops import Device, Compiled, MemBuffer
-from tinygrad.helpers import prod, getenv, flatten
+from tinygrad.helpers import prod, getenv, ImageDType, flatten
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.runtime.lib import RawBuffer
from collections import defaultdict

@@ -57,7 +57,7 @@ def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
  for x in lin.membufs: bufsts[x.idx].append(x)
  rawbufs:List[Optional[RawBuffer]] = [None]*len(bufsts)
  for k,lx in bufsts.items():
-   rawbufs[k] = device.buffer(max(y.st.size() for y in lx), lx[0].dtype)
+   rawbufs[k] = device.buffer(prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
  assert all(r is not None for r in rawbufs)
  return cast(List[RawBuffer], rawbufs)
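An ImageDType buffer is backed by fixed (H, W, 4) storage, so it is now sized from the dtype's own shape rather than from the largest ShapeTracker that views it. A worked toy of the new sizing expression:

# worked toy of the ImageDType sizing above; (256, 512, 4) matches the weight
# image in the benchmark script at the top of this commit
from tinygrad.helpers import prod
image_shape = (256, 512, 4)
print(prod(image_shape))  # 524288 elements back the image2d_t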
@@ -60,7 +60,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
  # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
  if IMAGE >= 2: x,w = x.cast(base_image_type(x.shape)), w.cast(base_image_type(w.shape))
  x, w = x.contiguous(), w.contiguous()
- if get_single_root(w.lazydata).realized: w.realize()
+ if getenv("PREREALIZE", 1) and get_single_root(w.lazydata).realized: w.realize()

  # expand out
  rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
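PREREALIZE defaults to on, keeping the old early-realize behavior; compile2.py sets PREREALIZE=0 so static weights stay lazy and can fuse (the "two more kernels" from the commit message). A toy of the default, assuming getenv returns the supplied default when the var is unset:

# toy of the PREREALIZE gate; getenv falls back to its default when unset
import os
os.environ.pop("PREREALIZE", None)
from tinygrad.helpers import getenv
print(getenv("PREREALIZE", 1))  # 1 -> weights get prerealized as before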
@@ -1,5 +1,5 @@
from __future__ import annotations
-import time, importlib, inspect, functools, pathlib
+import time, importlib, inspect, functools, pathlib, itertools, random
import numpy as np
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
@@ -184,6 +184,18 @@ class ASTRunner:
                     op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
                     display_name=k.display_name, runtime_args={"binary": False})

+ def optimize_local_size(self, global_size, rawbufs) -> List[int]:
+   assert self.global_size is not None, "needs a global size to optimize local size"
+   MAX_WORKGROUP = self.clprg.max_work_group_size() if hasattr(self.clprg, 'max_work_group_size') else 1024
+   local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
+   local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
+   def try_exec(local_size):
+     try:
+       return self.clprg([g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)], local_size, *rawbufs, wait=True)
+     except Exception:
+       return float('inf')
+   return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
+
  def build(self, runtime, batch_exec=BasicBatchExecutor):
    self.clprg, self.batch_exec = runtime(self.name, self.prg, **self.runtime_args), batch_exec
    return self
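optimize_local_size brute-forces the launch: per axis it collects power-of-two candidates (plus the full axis size) capped at MAX_WORKGROUP, times every combination that fits in a workgroup twice in random order, and keeps the fastest. A sketch of just the candidate enumeration, with MAX_WORKGROUP assumed to be 256 for illustration:

# sketch of the candidate enumeration in optimize_local_size; MAX_WORKGROUP=256
# is an assumption for illustration
import itertools
from tinygrad.helpers import prod
global_size, MAX_WORKGROUP = [1024, 16, 1], 256
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x <= sz] for sz in global_size]
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP]
print(len(local_sizes))  # every (x, y, z) combination that fits in one workgroup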
@@ -201,6 +213,9 @@ class ASTRunner:
  def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
    if var_vals is None: var_vals = {}
    global_size, local_size = self.launch_dims(var_vals)
+   if global_size is not None and local_size is None:
+     local_size = self.local_size = self.optimize_local_size(global_size, rawbufs)
+     global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
    if et := self.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
    op_estimate = sym_infer(self.op_estimate, var_vals)
    if DEBUG >= 2:
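This is where "support fractional globals" lands: if the chosen local size doesn't divide the global size, the stored grid becomes a float (g/l), and the int(g*l) in the compile.py and ops_gpu hunks recovers the exact ND-range on replay. A worked toy of the round trip:

# worked toy of the fractional-globals round trip
global_size, local_size = [6, 16, 1], [4, 16, 1]
grid = [g//l if g % l == 0 else g/l for g, l in zip(global_size, local_size)]
print(grid)                                          # [1.5, 1, 1]
print([int(g*l) for g, l in zip(grid, local_size)])  # [6, 16, 1] again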
@@ -16,6 +16,7 @@ class CStyleLanguage(NamedTuple):
  smem_prefix_for_cast: bool = True
  arg_int_prefix: str = ""
  barrier: str = ""
+ xid: List[str] = []
  gid: List[str] = []
  lid: List[str] = []
  global_max: List[int] = []

@@ -173,7 +174,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
      assert dtype is not None
      kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
    elif uop == UOps.SPECIAL:
-     xid = lang.gid if args[1].startswith("g") else lang.lid
+     xid = lang.gid if args[1].startswith("g") else (lang.xid if args[1].startswith("i") else lang.lid)
      kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
      if args[1].startswith("l"): local_size.append(args[2])
      r[u] = args[1]
@@ -14,6 +14,7 @@ class OpenCLLanguage(CStyleLanguage):
  float4 = "(float4)"
  gid = [f'get_group_id({i})' for i in range(3)]
  lid = [f'get_local_id({i})' for i in range(3)]
+ xid = [f'get_global_id({i})' for i in range(3)]
  uses_vload=True

OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
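With NOLOCALS the Linearizer renames its SPECIAL uops from gidx*/lidx* to idx* ("use idx for the cond"), and the renderer now routes those to the new xid sources (get_global_id for OpenCL). A toy of the dispatch using the strings above:

# toy of the SPECIAL-uop dispatch: "g..." -> group id, "i..." -> global id
# (the NOLOCALS path), anything else ("l...") -> local id
gid = [f'get_group_id({i})' for i in range(3)]
lid = [f'get_local_id({i})' for i in range(3)]
xid = [f'get_global_id({i})' for i in range(3)]
def render_special(axis: int, name: str) -> str:
  src = gid if name.startswith("g") else (xid if name.startswith("i") else lid)
  return f"int {name} = {src[axis]};"
print(render_special(2, "idx0"))   # int idx0 = get_global_id(2);
print(render_special(0, "lidx2"))  # int lidx2 = get_local_id(0);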
@@ -89,6 +89,7 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
  kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", smem_prefix_for_cast=False, arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
  gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
  lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
+ xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)],
  half_prekernel = """
    #include <cuda_fp16.h>
    struct __align__(8) half4 {
@@ -94,7 +94,7 @@ class CLProgram:
        cl_bufs.append(x._buf)
        if hasattr(x, "event"): wait_for.append(x.event)
      else: cl_bufs.append(x)
-   e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [g*l for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for)
+   e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [int(g*l) for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for)
    if wait:
      e.wait()
      try:
@@ -327,7 +327,7 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
-  if isinstance(a, int): return a
+  if isinstance(a, (int, float)): return a
  ret = a.substitute({k:Variable.num(v) for k, v in var_vals.items()})
  assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
  return ret.b
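sym_infer now passes plain floats through unchanged, which is what lets the fractional grids above flow through launch_dims. A small usage sketch, assuming Variable and sym_infer are importable from tinygrad.shape.symbolic:

# small usage sketch of sym_infer; the Variable bounds are made up
from tinygrad.shape.symbolic import Variable, sym_infer
i = Variable("i", 1, 10)
print(sym_infer(i*2+1, {i: 3}))  # 7
print(sym_infer(1.5, {}))        # 1.5 now passes straight through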