mirror of https://github.com/commaai/tinygrad.git
clean up to_shape_strides (#2402)
This commit is contained in:
parent
e4026dc197
commit
64aa2f4156
|
@ -12,13 +12,10 @@ from tinygrad.shape.view import View
|
|||
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
|
||||
assert len(shape) == len(strides)
|
||||
ret = [(shape[0], strides[0])] if shape else []
|
||||
for i in range(1, len(shape)):
|
||||
if ret[-1][1] == shape[i]*strides[i] or ret[-1][0] == 1:
|
||||
ret[-1] = (ret[-1][0] * shape[i], strides[i])
|
||||
elif shape[i] == 1:
|
||||
continue
|
||||
else:
|
||||
ret.append((shape[i], strides[i]))
|
||||
for s,st in zip(shape[1:], strides[1:]):
|
||||
ps,pst = ret[-1]
|
||||
if pst == s*st or ps == 1: ret[-1] = (ps*s, st)
|
||||
elif s != 1: ret.append((s, st))
|
||||
return tuple(ret)
|
||||
|
||||
def expr_node_mask(view:View, idx:Node, valid:Optional[Node]=None) -> Node:
|
||||
|
|
Loading…
Reference in New Issue