diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 417f1441..0c39f22d 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -12,13 +12,10 @@ from tinygrad.shape.view import View def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]: assert len(shape) == len(strides) ret = [(shape[0], strides[0])] if shape else [] - for i in range(1, len(shape)): - if ret[-1][1] == shape[i]*strides[i] or ret[-1][0] == 1: - ret[-1] = (ret[-1][0] * shape[i], strides[i]) - elif shape[i] == 1: - continue - else: - ret.append((shape[i], strides[i])) + for s,st in zip(shape[1:], strides[1:]): + ps,pst = ret[-1] + if pst == s*st or ps == 1: ret[-1] = (ps*s, st) + elif s != 1: ret.append((s, st)) return tuple(ret) def expr_node_mask(view:View, idx:Node, valid:Optional[Node]=None) -> Node: