mirror of https://github.com/commaai/tinygrad.git
linearizer: refactor to define accs with potentially TC-modified idxs (#4211)
parent 39b60a25f0
commit 126826afc8
@@ -173,36 +173,28 @@ class Linearizer(Kernel):
  def render_reduceop(self, reduceop: LazyOp, loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], \
      global_idxs, local_idxs, upcast_idxs):
    # define indecies
    # define indicies
    full_upcast_idxs = [Variable(f"_uidx{i}", 0, s-1) for i, s in enumerate(self.full_shape[self.shape_len-self.upcasted:])]
    reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501
    fake_reduce_idxs = [x*0 for x in reduce_idxs]
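An aside, not part of the diff: multiplying a symbolic index by 0 collapses it to a constant 0 node, which is why fake_reduce_idxs makes the accumulator load/store ignore the reduce dimensions. A minimal sketch, assuming the symbolic Variable from the tinygrad version this diff targets is importable as shown:

# illustrative sketch only; import path is an assumption about this tinygrad version
from tinygrad.shape.symbolic import Variable
ridx = Variable("ridx0", 0, 7)   # a reduce index that normally spans 0..7
fake = ridx * 0                  # collapses to a constant node
assert fake.min == 0 and fake.max == 0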

    # define accumulator
    out_buf = -1 if self.group_for_reduces else 0
    acc = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop))
    def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
      replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
      for s in local_sizes:
        thread_idxs.append(thread_idx % s)
        thread_idx //= s
      for alias in aliases:
        full_var, full_var_sz = NumNode(0), 1
        if alias[0] != 0:
          for i in alias:
            next_var = local_idxs[-i] if i > 0 else thread_idxs[-i-1]
            full_var += next_var * full_var_sz
            full_var_sz *= next_var.max+1
        replace_idxs.append(full_var)
      return replace_idxs
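An aside, not part of the diff: the first loop in calc_tc_idxs is a plain mixed-radix split of a flat thread id into one sub-index per local size. A minimal sketch of the same arithmetic with concrete integers (the real code does this on symbolic Variable nodes; the split_thread_idx helper below is hypothetical and only for illustration):

# illustrative sketch only; plain ints instead of the symbolic nodes used above
def split_thread_idx(thread_idx, local_sizes):
  thread_idxs = []
  for s in local_sizes:
    thread_idxs.append(thread_idx % s)  # index within this local dimension
    thread_idx //= s                    # carry what's left to the next dimension
  return thread_idxs

assert split_thread_idx(5, [2, 2, 2]) == [1, 0, 1]  # 5 = 1*1 + 0*2 + 1*4

The second loop does the inverse: it re-flattens a selection of local or thread sub-indices into a single index by accumulating full_var += next_var * full_var_sz.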

    # reduce loop
    loop_ctx = self.render_loop(reduce_idxs)

    if (tc:=self.tensor_core):
      def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
        replace_idxs, thread_idxs, thread_idx = [], [], Variable("_uidx_tc", 0, prod(local_sizes)-1)
        for s in local_sizes:
          thread_idxs.append(thread_idx % s)
          thread_idx //= s
        for alias in aliases:
          full_var, full_var_sz = NumNode(0), 1
          if alias[0] != 0:
            for i in alias:
              next_var = local_idxs[-i] if i > 0 else thread_idxs[-i-1]
              full_var += next_var * full_var_sz
              full_var_sz *= next_var.max+1
          replace_idxs.append(full_var)
        return replace_idxs

      # compute local aliases
      locals_to_store = []
      # compute local aliases - modify idxs if necessary for TC
      alias_buf_idxs = []
      for i in self.local_alias:
        localbuf_idx = self.bufs.index(self.local_alias[i])
        buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
@@ -214,10 +206,10 @@ class Linearizer(Kernel):
        for n in range(tc.num_upcasts()):
          buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
        if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
        ll = self.global_load(i, buf_idxs)
        locals_to_store.append((localbuf_idx, buf_idxs, ll))
        alias_buf_idxs.append((i, localbuf_idx, buf_idxs,))
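A side note, not part of the diff: the buf_idxs comprehension above zeroes any index whose real stride is 0, i.e. a dimension the aliased buffer is broadcast over. A minimal sketch with made-up integer indices and strides:

# illustrative sketch only; the real code zips symbolic idxs with self.sts[i].real_strides()
idxs    = [3, 7, 2]   # hypothetical loop indices
strides = [4, 0, 1]   # hypothetical strides; 0 means the buffer is broadcast along that dim
buf_idxs = [idx*0 if s == 0 else idx for idx, s in zip(idxs, strides)]
assert buf_idxs == [3, 0, 2]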

      # copy in any global buffers
    # define accumulator - modify idxs if necessary for TC
    out_buf = -1 if self.group_for_reduces else 0
    if (tc:=self.tensor_core):
      replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
      for n in range(len(tc.threads)):
@@ -225,7 +217,16 @@ class Linearizer(Kernel):
      for n in range(len(replace_acc_idxs)-len(tc.threads)):
        upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
      if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
    acc = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(reduceop))
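Another aside, not part of the diff: this is the heart of the refactor — the trailing local indices and the leading upcast indices are overwritten with the tensor-core indices from calc_tc_idxs before the accumulator is loaded, so the accs are defined with the TC-modified idxs. A rough sketch of that replacement pattern with hypothetical placeholder values (the exact offset into local_idxs is cut off by the hunk above, so the one used here is only illustrative):

# illustrative sketch only; strings stand in for symbolic index nodes
local_idxs, upcast_idxs = ["l0", "l1", "l2"], ["u0", "u1"]
n_threads = 2                                 # hypothetical len(tc.threads)
replace_acc_idxs = ["t0", "t1", "t2", "t3"]   # hypothetical calc_tc_idxs(...) output
for n in range(n_threads):
  local_idxs[len(local_idxs)-n_threads+n] = replace_acc_idxs[n]   # replace locals
for n in range(len(replace_acc_idxs)-n_threads):
  upcast_idxs[n] = replace_acc_idxs[n_threads+n]                  # replace upcasts
assert local_idxs == ["l0", "t0", "t1"] and upcast_idxs == ["t2", "t3"]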

    # reduce loop
    loop_ctx = self.render_loop(reduce_idxs)

    # store local aliases
    locals_to_store = [(localbuf_idx, buf_idxs, self.global_load(i, buf_idxs)) for i, localbuf_idx, buf_idxs in alias_buf_idxs]

    if (tc:=self.tensor_core):
      # run tensor cores AST
      wmma_sz = [prod(l) for l in tc.thread_local_sizes]
      def upcast_strides(buf:int):
        strides, next = [], 1
@@ -240,7 +241,7 @@ class Linearizer(Kernel):
               self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])),
               self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(op3:=acc[offs[2]:offs[2]+wmma_sz[2]])))
        ret = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(map(prod, tc.thread_local_sizes)), dev))
        for z in range(wmma_sz[2]):
        for z in range(wmma_sz[2]): # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid
          acc[offs[2]+z] = self.uops.add(UOps.PHI, tc.dtype_out, (op3[z], self.uops.add(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
    else:
      assert not locals_to_store, "storing locals isn't supported here"
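One last aside, not part of the diff: wmma_sz holds the number of elements each WMMA operand takes from the two inputs and the accumulator, and the offs[2] slices above carve the flat acc list into wmma_sz[2]-wide chunks, one chunk per WMMA call. A small sketch with hypothetical thread_local_sizes values:

# illustrative sketch only; thread_local_sizes values are made up
from math import prod
thread_local_sizes = [[2], [2], [2, 2]]            # hypothetical TC operand shapes
wmma_sz = [prod(l) for l in thread_local_sizes]    # -> [2, 2, 4]
acc = list(range(8))                               # stand-in for the accumulator UOps
chunks = [acc[off:off+wmma_sz[2]] for off in range(0, len(acc), wmma_sz[2])]
assert chunks == [[0, 1, 2, 3], [4, 5, 6, 7]]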