mirror of https://github.com/commaai/tinygrad.git
Multireduce Kernels - prereq refactor (#4173)
* refector rendering a reduceop into it's own function (will help for kernels with multiple reduceops) * linters * addressing concerns
This commit is contained in:
parent
593c90d7d6
commit
4592fc8fe7
|
@ -163,6 +163,144 @@ class Linearizer(Kernel):
|
|||
else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
|
||||
return stores
|
||||
|
||||
# render loop
|
||||
def render_loop(self, xx:List[Variable]) -> Tuple[UOp, ...]:
|
||||
new_loops = {x.expr:self.uops.add(UOps.LOOP, dtypes.int32, (
|
||||
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
|
||||
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
|
||||
self.loop_uops.update(new_loops)
|
||||
return tuple(new_loops.values())
|
||||
|
||||
def render_reduceop(self, reduceop: LazyOp, loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], \
|
||||
global_idxs, local_idxs, upcast_idxs):
|
||||
# define indecies
|
||||
full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
|
||||
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
|
||||
fake_reduce_idxs = [x*0 for x in reduce_idxs]
|
||||
|
||||
# define accumulator
|
||||
out_buf = -1 if self.group_for_reduces else 0
|
||||
acc = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop))
|
||||
|
||||
# reduce loop
|
||||
loop_ctx = self.render_loop(reduce_idxs)
|
||||
|
||||
if (tc:=self.tensor_core):
|
||||
def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
|
||||
replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
|
||||
for s in local_sizes:
|
||||
thread_idxs.append(thread_idx % s)
|
||||
thread_idx //= s
|
||||
for alias in aliases:
|
||||
full_var, full_var_sz = NumNode(0), 1
|
||||
if alias[0] != 0:
|
||||
for i in alias:
|
||||
next_var = local_idxs[-i] if i > 0 else thread_idxs[-i-1]
|
||||
full_var += next_var * full_var_sz
|
||||
full_var_sz *= next_var.max+1
|
||||
replace_idxs.append(full_var)
|
||||
return replace_idxs
|
||||
|
||||
# compute local aliases
|
||||
locals_to_store = []
|
||||
for i in self.local_alias:
|
||||
localbuf_idx = self.bufs.index(self.local_alias[i])
|
||||
buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
|
||||
if (tc:=self.tensor_core):
|
||||
min_alias_idx = min(self.local_alias.keys())
|
||||
replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
|
||||
for n in range(len(tc.threads)):
|
||||
buf_idxs[self.first_reduce-len(tc.threads)+n] = replace_input_idxs[n] # replace locals
|
||||
for n in range(tc.num_upcasts()):
|
||||
buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
|
||||
if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
|
||||
ll = self.global_load(i, buf_idxs)
|
||||
locals_to_store.append((localbuf_idx, buf_idxs, ll))
|
||||
|
||||
# copy in any global buffers
|
||||
if (tc:=self.tensor_core):
|
||||
replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
|
||||
for n in range(len(tc.threads)):
|
||||
local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals
|
||||
for n in range(len(replace_acc_idxs)-len(tc.threads)):
|
||||
upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
|
||||
if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
|
||||
|
||||
wmma_sz = [prod(l) for l in tc.thread_local_sizes]
|
||||
def upcast_strides(buf:int):
|
||||
strides, next = [], 1
|
||||
for (sz, stride, reduce) in self.upcasted_axis(buf)[tc.num_upcasts():]:
|
||||
strides.append((0 if stride == 0 else next, sz))
|
||||
next *= 1 if stride == 0 else sz
|
||||
return strides
|
||||
upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
|
||||
for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
|
||||
offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
|
||||
ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
|
||||
self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
|
||||
self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
|
||||
ret = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(map(prod, tc.thread_local_sizes)), dev))
|
||||
for z in range(wmma_sz[2]):
|
||||
acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
|
||||
else:
|
||||
assert not locals_to_store, "storing locals isn't supported here"
|
||||
|
||||
# load earlybufs
|
||||
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i,
|
||||
global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
|
||||
|
||||
# run early AST (with reduce)
|
||||
self.ast_parse(reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
|
||||
|
||||
# end the reduce loop
|
||||
self.load_cache.clear()
|
||||
|
||||
# end the local loop, do the local reduce
|
||||
if self.group_for_reduces:
|
||||
fake_global_idxs = [x*0 for x in global_idxs]
|
||||
stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
|
||||
barrier = self.uops.add(UOps.BARRIER, None, tuple(stores), cachable=False)
|
||||
if self.opts.has_local:
|
||||
fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
|
||||
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
|
||||
if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self)
|
||||
barrier = self.uops.add(UOps.IF, None, (if_cond, barrier), cachable=False)
|
||||
|
||||
# create new late reduce local loops and replace local_idxs that have been used
|
||||
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
|
||||
local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
|
||||
|
||||
# if any group_for_reduce items aren't reduces, upcast them here
|
||||
for j in self.upcast_in_mid_reduce_axes:
|
||||
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
|
||||
self.upcast()
|
||||
self.group_for_reduces -= 1
|
||||
local_idxs = local_idxs[:-1]
|
||||
end_local_idxs = end_local_idxs[:-1]
|
||||
# regenerate upcast_idxs
|
||||
upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
|
||||
|
||||
# NOTE: this structure is the same as the reduce op above
|
||||
|
||||
# define late accumulator
|
||||
acc = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop))
|
||||
|
||||
# late reduce loop
|
||||
loop_ctx = self.render_loop(end_local_idxs)
|
||||
|
||||
# load localbufs
|
||||
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
|
||||
|
||||
# there's no AST here (and there's no shape for the reduce LazyOp)
|
||||
self.ast_parse(LazyOp(reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # noqa: E501
|
||||
|
||||
# end the late reduce loop
|
||||
self.load_cache.clear()
|
||||
|
||||
# all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
|
||||
# been rewritten with fake end_local_idxs.
|
||||
return (acc, loaded_buffers, fake_reduce_idxs, local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)], upcast_idxs)
|
||||
|
||||
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
|
||||
def linearize(self):
|
||||
# no new opts and we already ran? skip relinearizing
|
||||
|
@ -223,17 +361,8 @@ class Linearizer(Kernel):
|
|||
# define indexes
|
||||
global_idxs, loop_global_idxs = get_grouped_dims("gidx", 0, self.full_shape[:self.global_dims], 3 if self.opts.has_local else 0)
|
||||
local_idxs, loop_local_idxs = get_grouped_dims("lidx", self.global_dims, self.full_shape[self.global_dims:self.first_reduce+self.group_for_reduces], 3 if self.opts.has_local else 0) # noqa: E501
|
||||
full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
|
||||
upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
|
||||
|
||||
# global and local loops
|
||||
def render_loop(xx:List[Variable]) -> Tuple[UOp, ...]:
|
||||
new_loops = {x.expr:self.uops.add(UOps.LOOP, dtypes.int32, (
|
||||
self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
|
||||
self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), cachable=False) for x in xx if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
|
||||
self.loop_uops.update(new_loops)
|
||||
return tuple(new_loops.values())
|
||||
|
||||
# set global/local size
|
||||
self.global_size: Optional[List[int]] = None
|
||||
self.local_size: Optional[List[int]] = None
|
||||
|
@ -245,142 +374,17 @@ class Linearizer(Kernel):
|
|||
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
|
||||
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) # noqa: E501
|
||||
else:
|
||||
render_loop(loop_global_idxs+loop_local_idxs)
|
||||
self.render_loop(loop_global_idxs+loop_local_idxs)
|
||||
|
||||
# parse AST
|
||||
loaded_buffers = {}
|
||||
loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]] = {}
|
||||
acc: List[UOp] = []
|
||||
self.load_cache: Dict[str, UOp] = {}
|
||||
|
||||
# reduce op
|
||||
fake_reduce_idxs: List[Variable] = []
|
||||
if self.reduceop is not None:
|
||||
# define indexes
|
||||
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
|
||||
fake_reduce_idxs = [x*0 for x in reduce_idxs]
|
||||
|
||||
# define accumulator
|
||||
out_buf = -1 if self.group_for_reduces else 0
|
||||
acc = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
|
||||
|
||||
# reduce loop
|
||||
loop_ctx = render_loop(reduce_idxs)
|
||||
|
||||
if (tc:=self.tensor_core):
|
||||
def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
|
||||
replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
|
||||
for s in local_sizes:
|
||||
thread_idxs.append(thread_idx % s)
|
||||
thread_idx //= s
|
||||
for alias in aliases:
|
||||
full_var, full_var_sz = NumNode(0), 1
|
||||
if alias[0] != 0:
|
||||
for i in alias:
|
||||
next_var = local_idxs[-i] if i > 0 else thread_idxs[-i-1]
|
||||
full_var += next_var * full_var_sz
|
||||
full_var_sz *= next_var.max+1
|
||||
replace_idxs.append(full_var)
|
||||
return replace_idxs
|
||||
|
||||
# compute local aliases
|
||||
locals_to_store = []
|
||||
for i in self.local_alias:
|
||||
localbuf_idx = self.bufs.index(self.local_alias[i])
|
||||
buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
|
||||
if (tc:=self.tensor_core):
|
||||
min_alias_idx = min(self.local_alias.keys())
|
||||
replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
|
||||
for n in range(len(tc.threads)):
|
||||
buf_idxs[self.first_reduce-len(tc.threads)+n] = replace_input_idxs[n] # replace locals
|
||||
for n in range(tc.num_upcasts()):
|
||||
buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
|
||||
if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
|
||||
ll = self.global_load(i, buf_idxs)
|
||||
locals_to_store.append((localbuf_idx, buf_idxs, ll))
|
||||
|
||||
# copy in any global buffers
|
||||
if (tc:=self.tensor_core):
|
||||
replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
|
||||
for n in range(len(tc.threads)):
|
||||
local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals
|
||||
for n in range(len(replace_acc_idxs)-len(tc.threads)):
|
||||
upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
|
||||
if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
|
||||
|
||||
wmma_sz = [prod(l) for l in tc.thread_local_sizes]
|
||||
def upcast_strides(buf:int):
|
||||
strides, next = [], 1
|
||||
for (sz, stride, reduce) in self.upcasted_axis(buf)[tc.num_upcasts():]:
|
||||
strides.append((0 if stride == 0 else next, sz))
|
||||
next *= 1 if stride == 0 else sz
|
||||
return strides
|
||||
upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device
|
||||
for iter in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]:
|
||||
offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(iter, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)]
|
||||
ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])),
|
||||
self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
|
||||
self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
|
||||
ret = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(map(prod, tc.thread_local_sizes)), dev))
|
||||
for z in range(wmma_sz[2]):
|
||||
acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
|
||||
else:
|
||||
assert not locals_to_store, "storing locals isn't supported here"
|
||||
|
||||
# load earlybufs
|
||||
loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i,
|
||||
global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
|
||||
|
||||
# run early AST (with reduce)
|
||||
self.ast_parse(self.reduceop, acc, self.acc_offsets(self.full_buf_index), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
|
||||
|
||||
# end the reduce loop
|
||||
self.load_cache.clear()
|
||||
|
||||
# end the local loop, do the local reduce
|
||||
if self.group_for_reduces:
|
||||
fake_global_idxs = [x*0 for x in global_idxs]
|
||||
stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
|
||||
barrier = self.uops.add(UOps.BARRIER, None, tuple(stores), cachable=False)
|
||||
if self.opts.has_local:
|
||||
fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
|
||||
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
|
||||
if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self)
|
||||
barrier = self.uops.add(UOps.IF, None, (if_cond, barrier), cachable=False)
|
||||
|
||||
# create new late reduce local loops and replace local_idxs that have been used
|
||||
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501
|
||||
local_idxs = local_idxs[:self.local_dims] + end_local_idxs[self.global_dims + self.local_dims:]
|
||||
|
||||
# if any group_for_reduce items aren't reduces, upcast them here
|
||||
for j in self.upcast_in_mid_reduce_axes:
|
||||
self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
|
||||
self.upcast()
|
||||
self.group_for_reduces -= 1
|
||||
local_idxs = local_idxs[:-1]
|
||||
end_local_idxs = end_local_idxs[:-1]
|
||||
# regenerate upcast_idxs
|
||||
upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.output_shape[self.shape_len-self.upcasted:])]
|
||||
|
||||
# NOTE: this structure is the same as the reduce op above
|
||||
|
||||
# define late accumulator
|
||||
acc = self.global_load(0, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
|
||||
|
||||
# late reduce loop
|
||||
loop_ctx = render_loop(end_local_idxs)
|
||||
|
||||
# load localbufs
|
||||
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
|
||||
|
||||
# there's no AST here (and there's no shape for the reduce LazyOp)
|
||||
self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # noqa: E501
|
||||
|
||||
# end the late reduce loop
|
||||
self.load_cache.clear()
|
||||
|
||||
# all local indices which were used for group_for_reduce are not valid any more and should be replaced with fake NumNode(0), since they have
|
||||
# been rewritten with fake end_local_idxs.
|
||||
local_idxs = local_idxs[:self.local_dims] + [NumNode(0) for i in range(self.group_for_reduces)]
|
||||
for reduceop in [self.reduceop] if self.reduceop is not None else []:
|
||||
acc,loaded_buffers,fake_reduce_idxs,local_idxs,upcast_idxs = self.render_reduceop(reduceop,loaded_buffers,global_idxs,local_idxs,upcast_idxs)
|
||||
|
||||
# load latebufs
|
||||
loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \
|
||||
|
@ -405,7 +409,7 @@ class Linearizer(Kernel):
|
|||
self.applied_opts_cache = self.applied_opts[:]
|
||||
return self
|
||||
|
||||
def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple(), cache=None) -> List[UOp]: # noqa: E501
|
||||
def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple(), cache=None) -> List[UOp]: # noqa: E501
|
||||
if cache is None: cache = {}
|
||||
if x in cache: return cache[x]
|
||||
if x.op in BufferOps: return loaded_buffers[x.arg]
|
||||
|
|
Loading…
Reference in New Issue