clean up to_shape_strides (#2402)

This commit is contained in:
chenyu 2023-11-23 13:04:00 -05:00 committed by GitHub
parent e4026dc197
commit 64aa2f4156
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 7 deletions

View File

@ -12,13 +12,10 @@ from tinygrad.shape.view import View
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
assert len(shape) == len(strides)
ret = [(shape[0], strides[0])] if shape else []
for i in range(1, len(shape)):
if ret[-1][1] == shape[i]*strides[i] or ret[-1][0] == 1:
ret[-1] = (ret[-1][0] * shape[i], strides[i])
elif shape[i] == 1:
continue
else:
ret.append((shape[i], strides[i]))
for s,st in zip(shape[1:], strides[1:]):
ps,pst = ret[-1]
if pst == s*st or ps == 1: ret[-1] = (ps*s, st)
elif s != 1: ret.append((s, st))
return tuple(ret)
def expr_node_mask(view:View, idx:Node, valid:Optional[Node]=None) -> Node: