more typing in linearizer uoping utils (#4929)

* type check everything

* idxs will be uops
This commit is contained in:
qazal 2024-06-12 23:00:02 +08:00 committed by GitHub
parent 828c98d5c4
commit 898430c004
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 3 deletions

View File

@ -218,7 +218,8 @@ class Linearizer(Kernel):
return alias_buf_idxs
def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs, alias_buf_idxs):
global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs,
alias_buf_idxs:List[Tuple[int, int, List]]) -> Tuple[List[NumNode|Variable], List[NumNode|Variable]]:
# reduce loop
loop_ctx = self.render_loop(reduce_idxs, 2)
@ -313,7 +314,7 @@ class Linearizer(Kernel):
return local_idxs[:self.local_dims] + [NumNode(0) for _ in range(self.group_for_reduces)], upcast_idxs
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
def linearize(self):
def linearize(self) -> Linearizer:
# no new opts and we already ran? skip relinearizing
if self.applied_opts == self.applied_opts_cache: return self
@ -419,7 +420,8 @@ class Linearizer(Kernel):
return self
def render_block(self, outputs:Tuple[LazyOp, ...], global_idxs, local_idxs, upcast_idxs, full_upcast_idxs,
alias_buf_idxs, loaded_buffers, accs) -> List[List[UOp]]:
alias_buf_idxs:DefaultDict[LazyOp,List[Tuple[int,int,List[NumNode|Variable]]]],
loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], accs:Dict[LazyOp,List[UOp]]) -> List[List[UOp]]:
reduceops = dedup(x for x in outputs if x.op in ReduceOps)
assert len(reduceops) <= 1, "max one reduceop per block"
reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+self.group_for_reduces, self.shape_len-self.upcasted)] # noqa: E501