diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py
index ddf2fae0..e3b2582f 100644
--- a/extra/onnx_ops.py
+++ b/extra/onnx_ops.py
@@ -1,7 +1,7 @@
 import functools, io, math
 from typing import Union, Tuple, Optional, List, Any
-from tinygrad import Tensor, dtypes
-from tinygrad.dtype import ImageDType
+from tinygrad.tensor import Tensor, broadcast_shape
+from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.helpers import prod, flatten
 from extra.onnx import safe_numpy, DTYPE_MAP
 import numpy as np
@@ -82,6 +82,7 @@ def Size(data: Tensor): return prod(data if isinstance(data, list) else data.sha
 def Flatten(x: Tensor, axis=1): return x.reshape(prod(x.shape[0:axis]), -1)
 def Reshape(data: Tensor, shape: Tensor, allowzero=0): return data.reshape([int(x) if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(safe_numpy(shape))])
+def Expand(x: Tensor, shape:Tensor): return x.expand(broadcast_shape(x.shape, tuple(int(x) for x in safe_numpy(shape))))
 def Shrink(x: Tensor, bias=0.0, lambd=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
 def And(x:Tensor, y:Tensor): return (x==y).where(x, False)
 def Or(x:Tensor, y:Tensor): return (x==y).where(x, True)
@@ -135,14 +136,6 @@ def ConstantOfShape(x, value:Tensor=None):
   shape = [int(x) for x in safe_numpy(x)]
   return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)
 
-# TODO: abstract out the broadcast logic in tensor
-def Expand(x: Tensor, shape):
-  x_shape, y_shape = x.shape, [int(x) for x in safe_numpy(shape)]
-  # copied from _broadcasted
-  x_shape, y_shape = [([1]*(max(len(x_shape), len(y_shape))-len(t_shape)) + list(t_shape)) for t_shape in [x_shape, y_shape]]
-  shape_ret = tuple(max(sx, sy) for sx,sy in zip(x_shape, y_shape))
-  return x.reshape(x_shape).expand(shape_ret)
-
 # **************** Complex Ops ****************
 
 def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
diff --git a/test/test_ops.py b/test/test_ops.py
index aab8d928..603cf640 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1089,8 +1089,17 @@ class TestOps(unittest.TestCase):
   def test_expand(self):
     helper_test_op([(4,3,1,6)], lambda x: x.expand((4,3,2,6)))
     helper_test_op([(1,1,1,1)], lambda x: x.expand((4,3,2,6)))
+    helper_test_op([(4,3,1,6)], lambda x: x.expand((6,1,4,3,2,6)))
+    helper_test_op([(4,3,1,6)], lambda x: x.expand((0,1,4,3,2,6)))
+    helper_test_op([(4,3,1,6)], lambda x: x.expand((4,3,0,6)))
+    helper_test_op([()], lambda x: x.expand((4,3,2,6)))
     helper_test_op([()], lambda x: x.expand([]))
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(4,1,1,6)
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(4,6,1,6)
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(3,1,6)
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,2,6).expand(4,3,0,6)
+
   @unittest.skip("very slow")
   def test_sd_big_conv(self):
     # internal shape (1, 1, 512, 62, 62, 512, 3, 3) overflows a int
diff --git a/test/test_tensor.py b/test/test_tensor.py
index 6da4d176..eb39c363 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -422,7 +422,7 @@ class TestZeroShapeTensor(unittest.TestCase):
       a = t.reshape(())
 
   def test_expand(self):
-    t = Tensor.full((3, 2, 0), 12).expand((6, 2, 0))
+    t = Tensor.full((1, 2, 0), 12).expand((6, 2, 0))
     assert t.shape == (6, 2, 0)
     np.testing.assert_equal(t.numpy(), np.full((6, 2, 0), 12))
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index f0e406aa..26eca8ed 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -67,6 +67,9 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
   assert isinstance(ret, Tensor), "sum didn't return a Tensor"
   return ret
 
+def _pad_left(*shps:Tuple[sint, ...], v=1): return tuple((v,) * (max(len(i_) for i_ in shps) - len(i)) + i for i in shps)
+def broadcast_shape(*shps:Tuple[sint, ...]): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))
+
 class Tensor:
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
@@ -372,8 +375,7 @@ class Tensor:
     new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
     return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
   def expand(self, shape, *args) -> Tensor:
-    new_shape = tuple([x if x != -1 and x is not None else s for s,x in zip(self.shape, argfix(shape, *args))])
-    return F.Expand.apply(self, shape=new_shape) if new_shape != self.shape else self
+    return self._broadcast_to(tuple(sh if s==-1 or s is None else s for s, sh in zip(*(_pad_left(argfix(shape, *args), self.shape)))))
   def permute(self, order, *args) -> Tensor: return F.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return F.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
   def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
@@ -500,7 +502,7 @@ class Tensor:
     # iteratively eq -> mul -> sum fancy index
     try:
       for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
-    except AssertionError as exc: raise IndexError("cannot broadcast indices") from exc
+    except ValueError as exc: raise IndexError("cannot broadcast indices") from exc
 
     # special permute case
     if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
@@ -860,6 +862,11 @@ class Tensor:
   def softsign(self): return self / (1 + self.abs())
 
   # ***** broadcasted elementwise mlops *****
+  def _broadcast_to(self, shape:Tuple[sint, ...]):
+    reshape_arg, _ = _pad_left(self.shape, shape)
+    if self.ndim > len(shape) or not all(sh in {s,1} or (s==0 and sh==1) for sh,s in zip(reshape_arg, shape)):
+      raise ValueError(f"cannot broadcast tensor with shape={self.shape} to {shape=}")
+    return F.Expand.apply(self.reshape(reshape_arg), shape=shape) if shape != self.shape else self
   def _broadcasted(self, y:Union[Tensor, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
     x: Tensor = self
@@ -876,12 +883,9 @@ class Tensor:
 
     if reverse: x, y = y, x
 
-    # left pad shape with 1s
-    if len(y.shape) < len(x.shape): y = y.reshape((1,) * (len(x.shape) - len(y.shape)) + y.shape)
-    elif len(x.shape) < len(y.shape): x = x.reshape((1,) * (len(y.shape) - len(x.shape)) + x.shape)
-
-    broadcasted_shape = tuple(0 if xi==0 or yi==0 else max(xi, yi) for xi, yi in zip(x.shape, y.shape))
-    return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
+    # broadcast
+    out_shape = broadcast_shape(x.shape, y.shape)
+    return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
 
   def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
     # TODO: update with multi
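
For context, a minimal sketch (not part of the patch) of the broadcasting behavior the new broadcast_shape / Tensor._broadcast_to helpers should give, inferred only from the hunks above; the shapes and the variable t are illustrative, and the ValueError matches the exception raised by _broadcast_to.

from tinygrad.tensor import Tensor, broadcast_shape

# shapes are left-padded with 1s, then broadcast dimension-wise; any 0-sized dim stays 0
assert broadcast_shape((4,3,1,6), (2,6)) == (4,3,2,6)
assert broadcast_shape((4,3,0,6), (1,6)) == (4,3,0,6)

# expand can now add leading dimensions, numpy/torch style
t = Tensor.ones(4,3,1,6)
assert t.expand(6,1,4,3,2,6).shape == (6,1,4,3,2,6)

# incompatible target shapes raise instead of silently mis-expanding
try: t.expand(4,6,1,6)
except ValueError: pass  # "cannot broadcast tensor with shape=(4, 3, 1, 6) ..."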