broadcast expand to match torch (#4085)

* initial version

* heh gimme grrrreen

* version 2

* clean ups

* some test confusion

* fix onnx

* rename to _broadcast_tensors

* improved errors and test

* fixed?

* some test fixup

* version 3 lol

* comments

* cleaner

* add failure test for expand to 0 test

* 1 more assertRaises test

* make err msg better

* also rewrite the expand onnx op? :s
geohotstan 2024-04-08 04:23:13 +08:00 committed by GitHub
parent 2b81d9b334
commit 183708b3fd
4 changed files with 26 additions and 20 deletions
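
What this changes for users: Tensor.expand now follows torch.Tensor.expand semantics. The target shape may have more dimensions than the tensor (new leading axes are prepended), -1 or None keeps an existing size, a size-1 axis may be expanded to 0, and a shape that is not broadcast-compatible raises ValueError. A minimal sketch of the new behavior, with shapes borrowed from the tests below:

from tinygrad import Tensor

t = Tensor.ones(4, 3, 1, 6)
print(t.expand(4, 3, 2, 6).shape)        # (4, 3, 2, 6): broadcast the size-1 axis
print(t.expand(6, 1, 4, 3, 2, 6).shape)  # (6, 1, 4, 3, 2, 6): new leading axes are prepended
print(t.expand(4, -1, 2, 6).shape)       # (4, 3, 2, 6): -1 keeps the existing size
print(t.expand(4, 3, 0, 6).shape)        # (4, 3, 0, 6): a size-1 axis may expand to 0
try: t.expand(4, 1, 1, 6)                # a non-1 axis cannot shrink, same as torch
except ValueError as e: print(e)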

extra/onnx_ops.py

@@ -1,7 +1,7 @@
 import functools, io, math
 from typing import Union, Tuple, Optional, List, Any
-from tinygrad import Tensor, dtypes
-from tinygrad.dtype import ImageDType
+from tinygrad.tensor import Tensor, broadcast_shape
+from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.helpers import prod, flatten
 from extra.onnx import safe_numpy, DTYPE_MAP
 import numpy as np
@@ -82,6 +82,7 @@ def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape)
 def Flatten(x: Tensor, axis=1): return x.reshape(prod(x.shape[0:axis]), -1)
 def Reshape(data: Tensor, shape: Tensor, allowzero=0):
   return data.reshape([int(x) if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(safe_numpy(shape))])
+def Expand(x: Tensor, shape:Tensor): return x.expand(broadcast_shape(x.shape, tuple(int(x) for x in safe_numpy(shape))))
 def Shrink(x: Tensor, bias=0.0, lambd=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
 def And(x:Tensor, y:Tensor): return (x==y).where(x, False)
 def Or(x:Tensor, y:Tensor): return (x==y).where(x, True)
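
ONNX's Expand uses bidirectional (numpy-style) broadcasting: the output shape is the broadcast of the input shape with the requested shape, so the input may keep dimensions larger than the corresponding requested ones. Roughly how the new one-liner resolves that, sketched with a plain tuple standing in for the ONNX shape tensor:

from tinygrad import Tensor
from tinygrad.tensor import broadcast_shape

x = Tensor.ones(3, 1)
req = (2, 1, 6)                                 # what the ONNX 'shape' input would carry
out = x.expand(broadcast_shape(x.shape, req))
print(out.shape)                                # (2, 3, 6): per-dim max after left-padding with 1s
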
@@ -135,14 +136,6 @@ def ConstantOfShape(x, value:Tensor=None):
   shape = [int(x) for x in safe_numpy(x)]
   return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)
-# TODO: abstract out the broadcast logic in tensor
-def Expand(x: Tensor, shape):
-  x_shape, y_shape = x.shape, [int(x) for x in safe_numpy(shape)]
-  # copied from _broadcasted
-  x_shape, y_shape = [([1]*(max(len(x_shape), len(y_shape))-len(t_shape)) + list(t_shape)) for t_shape in [x_shape, y_shape]]
-  shape_ret = tuple(max(sx, sy) for sx,sy in zip(x_shape, y_shape))
-  return x.reshape(x_shape).expand(shape_ret)
 # **************** Complex Ops ****************
 def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):

test/test_ops.py

@@ -1089,8 +1089,17 @@ class TestOps(unittest.TestCase):
   def test_expand(self):
     helper_test_op([(4,3,1,6)], lambda x: x.expand((4,3,2,6)))
     helper_test_op([(1,1,1,1)], lambda x: x.expand((4,3,2,6)))
+    helper_test_op([(4,3,1,6)], lambda x: x.expand((6,1,4,3,2,6)))
+    helper_test_op([(4,3,1,6)], lambda x: x.expand((0,1,4,3,2,6)))
+    helper_test_op([(4,3,1,6)], lambda x: x.expand((4,3,0,6)))
+    helper_test_op([()], lambda x: x.expand((4,3,2,6)))
+    helper_test_op([()], lambda x: x.expand([]))
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(4,1,1,6)
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(4,6,1,6)
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(3,1,6)
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,2,6).expand(4,3,0,6)
   @unittest.skip("very slow")
   def test_sd_big_conv(self):
     # internal shape (1, 1, 512, 62, 62, 512, 3, 3) overflows a int
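
The new cases mirror torch.Tensor.expand (helper_test_op compares against torch directly). A quick cross-check in torch for the same shapes, assuming torch is installed; the error text is paraphrased:

import torch

t = torch.ones(4, 3, 1, 6)
print(t.expand(6, 1, 4, 3, 2, 6).shape)  # torch.Size([6, 1, 4, 3, 2, 6])
print(t.expand(4, 3, 0, 6).shape)        # torch.Size([4, 3, 0, 6])
try: t.expand(4, 1, 1, 6)
except RuntimeError as e: print(e)       # expanded size must match the existing non-1 size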

test/test_tensor.py

@@ -422,7 +422,7 @@ class TestZeroShapeTensor(unittest.TestCase):
       a = t.reshape(())
   def test_expand(self):
-    t = Tensor.full((3, 2, 0), 12).expand((6, 2, 0))
+    t = Tensor.full((1, 2, 0), 12).expand((6, 2, 0))
     assert t.shape == (6, 2, 0)
     np.testing.assert_equal(t.numpy(), np.full((6, 2, 0), 12))
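
The old fixture (3, 2, 0) expanded a non-1 axis from 3 to 6, which the torch-matching rule no longer allows; only size-1 axes broadcast, so the fixture becomes (1, 2, 0). A small sketch of the distinction:

from tinygrad import Tensor

print(Tensor.full((1, 2, 0), 12).expand(6, 2, 0).shape)  # (6, 2, 0): the size-1 axis broadcasts
try: Tensor.full((3, 2, 0), 12).expand(6, 2, 0)          # the old fixture, now rejected
except ValueError as e: print(e)                         # 3 cannot be expanded to 6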

tinygrad/tensor.py

@@ -67,6 +67,9 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
   assert isinstance(ret, Tensor), "sum didn't return a Tensor"
   return ret
+def _pad_left(*shps:Tuple[sint, ...], v=1): return tuple((v,) * (max(len(i_) for i_ in shps) - len(i)) + i for i in shps)
+def broadcast_shape(*shps:Tuple[sint, ...]): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))
 class Tensor:
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
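
The two new module-level helpers centralize the shape arithmetic: _pad_left left-pads every shape to a common rank with 1s (or another fill value v), and broadcast_shape takes the per-dimension maximum, with any 0 forcing that output dimension to 0. A worked example:

from tinygrad.tensor import _pad_left, broadcast_shape

print(_pad_left((4, 3, 1, 6), (2, 6)))           # ((4, 3, 1, 6), (1, 1, 2, 6))
print(broadcast_shape((4, 3, 1, 6), (2, 6)))     # (4, 3, 2, 6)
print(broadcast_shape((4, 3, 1, 6), (1, 0, 1)))  # (4, 3, 0, 6): a 0 wins over the max
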
@@ -372,8 +375,7 @@ class Tensor:
     new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
     return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
   def expand(self, shape, *args) -> Tensor:
-    new_shape = tuple([x if x != -1 and x is not None else s for s,x in zip(self.shape, argfix(shape, *args))])
-    return F.Expand.apply(self, shape=new_shape) if new_shape != self.shape else self
+    return self._broadcast_to(tuple(sh if s==-1 or s is None else s for s, sh in zip(*(_pad_left(argfix(shape, *args), self.shape)))))
   def permute(self, order, *args) -> Tensor: return F.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return F.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
   def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
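
expand itself becomes pure argument handling: the requested shape and the tensor's own shape are left-padded to the same rank, -1 or None entries are replaced by the tensor's padded size at that position, and the result is handed to _broadcast_to. A small sketch:

from tinygrad import Tensor

t = Tensor.ones(3, 1, 6)
print(t.expand(2, -1, 5, 6).shape)  # (2, 3, 5, 6): -1 resolves against the padded own shape
print(t.expand(3, 4, -1).shape)     # (3, 4, 6)

Because the padding happens before the -1/None substitution, torch-style prepending of new leading dimensions falls out of the same code path.
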
@@ -500,7 +502,7 @@ class Tensor:
       # iteratively eq -> mul -> sum fancy index
       try:
         for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
-      except AssertionError as exc: raise IndexError("cannot broadcast indices") from exc
+      except ValueError as exc: raise IndexError("cannot broadcast indices") from exc
       # special permute case
       if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
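
Fancy indexing broadcasts all tensor indices against each other through the same machinery; since that path now fails with ValueError (from _broadcast_to) rather than an AssertionError from the old expand, the except clause changes so incompatible indices still surface as IndexError. A shape combination that should trigger it (values chosen only to show the error path):

from tinygrad import Tensor

a = Tensor.arange(12).reshape(3, 4)
try: a[Tensor([0, 1]), Tensor([0, 1, 2])]   # index shapes (2,) and (3,) cannot broadcast
except IndexError as e: print(e)            # "cannot broadcast indices"
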
@@ -860,6 +862,11 @@ class Tensor:
   def softsign(self): return self / (1 + self.abs())
   # ***** broadcasted elementwise mlops *****
+  def _broadcast_to(self, shape:Tuple[sint, ...]):
+    reshape_arg, _ = _pad_left(self.shape, shape)
+    if self.ndim > len(shape) or not all(sh in {s,1} or (s==0 and sh==1) for sh,s in zip(reshape_arg, shape)):
+      raise ValueError(f"cannot broadcast tensor with shape={self.shape} to {shape=}")
+    return F.Expand.apply(self.reshape(reshape_arg), shape=shape) if shape != self.shape else self
   def _broadcasted(self, y:Union[Tensor, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
     x: Tensor = self
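
_broadcast_to is the single place the validity check now lives: the target must have at least as many dimensions as the tensor, and after left-padding each source dimension must equal the target, be 1, or be 1 against a target of 0. It is internal (normally reached via expand or arithmetic), but calling it directly shows the rule:

from tinygrad import Tensor

print(Tensor.ones(3, 1)._broadcast_to((2, 3, 4)).shape)  # (2, 3, 4)
print(Tensor.ones(1, 4)._broadcast_to((0, 4)).shape)     # (0, 4): 1 may broadcast to 0
try: Tensor.ones(2, 4)._broadcast_to((3, 4))
except ValueError as e: print(e)                         # 2 is neither 3 nor 1
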
@@ -876,12 +883,9 @@ class Tensor:
     if reverse: x, y = y, x
-    # left pad shape with 1s
-    if len(y.shape) < len(x.shape): y = y.reshape((1,) * (len(x.shape) - len(y.shape)) + y.shape)
-    elif len(x.shape) < len(y.shape): x = x.reshape((1,) * (len(y.shape) - len(x.shape)) + x.shape)
-    broadcasted_shape = tuple(0 if xi==0 or yi==0 else max(xi, yi) for xi, yi in zip(x.shape, y.shape))
-    return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
+    # broadcast
+    out_shape = broadcast_shape(x.shape, y.shape)
+    return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
   def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
     # TODO: update with multi
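
With the helpers in place, binary ops compute one target via broadcast_shape and move both operands there with _broadcast_to, so an incompatible pair fails with the same ValueError as expand. For instance:

from tinygrad import Tensor

print((Tensor.ones(4, 1, 6) + Tensor.ones(3, 1)).shape)  # (4, 3, 6)
try: Tensor.ones(2, 3) * Tensor.ones(4, 3)
except ValueError as e: print(e)                         # 2 and 4 cannot broadcast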