mirror of https://github.com/commaai/tinygrad.git
broadcast expand to match torch (#4085)

* initial version
* heh gimme grrrreen
* version 2
* clean ups
* some test confusion
* fix onnx
* rename to _broadcast_tensors
* improved errors and test
* fixed?
* some test fixup
* version 3 lol
* comments
* cleaner
* add failure test for expand to 0 test
* 1 more assertRaises test
* make err msg better
* also rewrite the expand onnx op? :s

parent 2b81d9b334
commit 183708b3fd
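In short, Tensor.expand now resolves the requested shape against the current one the way torch does: the request may add leading dimensions, -1 (or None) keeps the existing size, a size-1 dimension may expand to 0, and anything else raises ValueError. A minimal sketch of the intended behavior, assuming tinygrad at this commit (shapes taken from the new tests):

from tinygrad import Tensor

x = Tensor.ones(4, 3, 1, 6)
print(x.expand(4, 3, 2, 6).shape)        # (4, 3, 2, 6): size-1 dims stretch as before
print(x.expand(6, 1, 4, 3, 2, 6).shape)  # (6, 1, 4, 3, 2, 6): extra leading dims are now allowed, like torch
print(x.expand(-1, -1, 2, 6).shape)      # (4, 3, 2, 6): -1 keeps the existing size
print(x.expand(4, 3, 0, 6).shape)        # (4, 3, 0, 6): a size-1 dim may expand to 0
try:
  x.expand(4, 6, 1, 6)                   # 3 -> 6 is not a valid broadcast
except ValueError as e:
  print(e)                               # cannot broadcast tensor with shape=(4, 3, 1, 6) to shape=(4, 6, 1, 6)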
extra/onnx_ops.py

@@ -1,7 +1,7 @@
 import functools, io, math
 from typing import Union, Tuple, Optional, List, Any
-from tinygrad import Tensor, dtypes
-from tinygrad.dtype import ImageDType
+from tinygrad.tensor import Tensor, broadcast_shape
+from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.helpers import prod, flatten
 from extra.onnx import safe_numpy, DTYPE_MAP
 import numpy as np
@@ -82,6 +82,7 @@ def Size(data: Tensor): return prod(data if isinstance(data, list) else data.sha
 def Flatten(x: Tensor, axis=1): return x.reshape(prod(x.shape[0:axis]), -1)
 def Reshape(data: Tensor, shape: Tensor, allowzero=0):
   return data.reshape([int(x) if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(safe_numpy(shape))])
+def Expand(x: Tensor, shape:Tensor): return x.expand(broadcast_shape(x.shape, tuple(int(x) for x in safe_numpy(shape))))
 def Shrink(x: Tensor, bias=0.0, lambd=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
 def And(x:Tensor, y:Tensor): return (x==y).where(x, False)
 def Or(x:Tensor, y:Tensor): return (x==y).where(x, True)
@@ -135,14 +136,6 @@ def ConstantOfShape(x, value:Tensor=None):
   shape = [int(x) for x in safe_numpy(x)]
   return Tensor.ones(*shape, dtype=value.dtype) * (value if shape[0]!=0 else 1)
 
-# TODO: abstract out the broadcast logic in tensor
-def Expand(x: Tensor, shape):
-  x_shape, y_shape = x.shape, [int(x) for x in safe_numpy(shape)]
-  # copied from _broadcasted
-  x_shape, y_shape = [([1]*(max(len(x_shape), len(y_shape))-len(t_shape)) + list(t_shape)) for t_shape in [x_shape, y_shape]]
-  shape_ret = tuple(max(sx, sy) for sx,sy in zip(x_shape, y_shape))
-  return x.reshape(x_shape).expand(shape_ret)
-
 # **************** Complex Ops ****************
 
 def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, transB=0, broadcast=0):
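The deleted onnx Expand computed its output shape with a per-dimension max after left-padding with 1s ("copied from _broadcasted"); the one-line Expand added above defers to the shared broadcast_shape, which additionally propagates 0-size dimensions. A standalone plain-Python sketch of the two rules (the helper names here are illustrative only):

def old_expand_shape(x_shape, y_shape):
  # the removed logic: left-pad the shorter shape with 1s, then take the per-dim max
  x_shape, y_shape = [[1]*(max(len(x_shape), len(y_shape))-len(t)) + list(t) for t in (x_shape, y_shape)]
  return tuple(max(sx, sy) for sx, sy in zip(x_shape, y_shape))

def new_expand_shape(x_shape, y_shape):
  # the rule broadcast_shape applies: any 0 in a dimension forces that output dimension to 0
  pad = max(len(x_shape), len(y_shape))
  x_shape = (1,) * (pad - len(x_shape)) + tuple(x_shape)
  y_shape = (1,) * (pad - len(y_shape)) + tuple(y_shape)
  return tuple(0 if 0 in (sx, sy) else max(sx, sy) for sx, sy in zip(x_shape, y_shape))

print(old_expand_shape((4,3,1,6), (2,6)))      # (4, 3, 2, 6)
print(new_expand_shape((4,3,1,6), (2,6)))      # (4, 3, 2, 6): same answer for ordinary shapes
print(old_expand_shape((4,3,1,6), (4,3,0,6)))  # (4, 3, 1, 6): plain max() loses the 0-size dim
print(new_expand_shape((4,3,1,6), (4,3,0,6)))  # (4, 3, 0, 6)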
test/test_ops.py

@@ -1089,8 +1089,17 @@ class TestOps(unittest.TestCase):
   def test_expand(self):
     helper_test_op([(4,3,1,6)], lambda x: x.expand((4,3,2,6)))
     helper_test_op([(1,1,1,1)], lambda x: x.expand((4,3,2,6)))
+    helper_test_op([(4,3,1,6)], lambda x: x.expand((6,1,4,3,2,6)))
+    helper_test_op([(4,3,1,6)], lambda x: x.expand((0,1,4,3,2,6)))
+    helper_test_op([(4,3,1,6)], lambda x: x.expand((4,3,0,6)))
+    helper_test_op([()], lambda x: x.expand((4,3,2,6)))
+    helper_test_op([()], lambda x: x.expand([]))
 
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(4,1,1,6)
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(4,6,1,6)
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,1,6).expand(3,1,6)
+    with self.assertRaises((ValueError, RuntimeError)): Tensor.ones(4,3,2,6).expand(4,3,0,6)
 
   @unittest.skip("very slow")
   def test_sd_big_conv(self):
     # internal shape (1, 1, 512, 62, 62, 512, 3, 3) overflows a int
@@ -422,7 +422,7 @@ class TestZeroShapeTensor(unittest.TestCase):
     a = t.reshape(())
 
   def test_expand(self):
-    t = Tensor.full((3, 2, 0), 12).expand((6, 2, 0))
+    t = Tensor.full((1, 2, 0), 12).expand((6, 2, 0))
     assert t.shape == (6, 2, 0)
     np.testing.assert_equal(t.numpy(), np.full((6, 2, 0), 12))
 
tinygrad/tensor.py

@@ -67,6 +67,9 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
   assert isinstance(ret, Tensor), "sum didn't return a Tensor"
   return ret
 
+def _pad_left(*shps:Tuple[sint, ...], v=1): return tuple((v,) * (max(len(i_) for i_ in shps) - len(i)) + i for i in shps)
+def broadcast_shape(*shps:Tuple[sint, ...]): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))
+
 class Tensor:
   __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
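_pad_left left-pads the shorter shapes with 1s so every argument has the same rank, and broadcast_shape then takes the per-dimension max, except that a 0 in any operand forces that output dimension to 0. Untyped copies of the two helpers for illustration (the real signatures use tinygrad's sint for symbolic shapes):

def _pad_left(*shps, v=1): return tuple((v,) * (max(len(i_) for i_ in shps) - len(i)) + i for i in shps)
def broadcast_shape(*shps): return tuple(0 if any(sh_ == 0 for sh_ in sh) else max(sh) for sh in zip(*_pad_left(*shps)))

print(_pad_left((4,3,1,6), (2,6)))                # ((4, 3, 1, 6), (1, 1, 2, 6))
print(broadcast_shape((4,3,1,6), (2,6)))          # (4, 3, 2, 6)
print(broadcast_shape((4,3,1,6), (0,1,4,3,2,6)))  # (0, 1, 4, 3, 2, 6), as in the new test case

Note that broadcast_shape only computes the result shape; rejecting incompatible shapes is left to _broadcast_to below.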
@@ -372,8 +375,7 @@ class Tensor:
     new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
     return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
   def expand(self, shape, *args) -> Tensor:
-    new_shape = tuple([x if x != -1 and x is not None else s for s,x in zip(self.shape, argfix(shape, *args))])
-    return F.Expand.apply(self, shape=new_shape) if new_shape != self.shape else self
+    return self._broadcast_to(tuple(sh if s==-1 or s is None else s for s, sh in zip(*(_pad_left(argfix(shape, *args), self.shape)))))
   def permute(self, order, *args) -> Tensor: return F.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return F.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
   def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
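The rewritten expand first pads the current shape against the requested one with _pad_left, keeps the existing size wherever the request says -1 or None, and hands the resolved shape to _broadcast_to. A standalone trace of just that shape resolution (the function name here is illustrative, not the real method):

def resolve_expand_shape(requested, current):
  pad = max(len(requested), len(current))
  requested = (1,) * (pad - len(requested)) + tuple(requested)  # what _pad_left does
  current   = (1,) * (pad - len(current)) + tuple(current)
  # -1 (or None) in the request means "keep whatever the input already has there"
  return tuple(cur if req == -1 or req is None else req for req, cur in zip(requested, current))

print(resolve_expand_shape((-1, -1, 2, 6), (4, 3, 1, 6)))      # (4, 3, 2, 6)
print(resolve_expand_shape((6, 1, 4, 3, 2, 6), (4, 3, 1, 6)))  # (6, 1, 4, 3, 2, 6)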
@@ -500,7 +502,7 @@ class Tensor:
     # iteratively eq -> mul -> sum fancy index
     try:
       for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
-    except AssertionError as exc: raise IndexError("cannot broadcast indices") from exc
+    except ValueError as exc: raise IndexError("cannot broadcast indices") from exc
 
     # special permute case
     if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)):
@@ -860,6 +862,11 @@ class Tensor:
   def softsign(self): return self / (1 + self.abs())
 
   # ***** broadcasted elementwise mlops *****
+  def _broadcast_to(self, shape:Tuple[sint, ...]):
+    reshape_arg, _ = _pad_left(self.shape, shape)
+    if self.ndim > len(shape) or not all(sh in {s,1} or (s==0 and sh==1) for sh,s in zip(reshape_arg, shape)):
+      raise ValueError(f"cannot broadcast tensor with shape={self.shape} to {shape=}")
+    return F.Expand.apply(self.reshape(reshape_arg), shape=shape) if shape != self.shape else self
 
   def _broadcasted(self, y:Union[Tensor, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
     x: Tensor = self
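_broadcast_to only reshapes and expands when the rank does not shrink and every source dimension already matches the target, is 1, or is a 1 collapsing into a 0-size target dimension; otherwise it raises ValueError. A standalone predicate sketch of that rule (illustrative name), checked against the new assertRaises tests:

def can_broadcast_to(src, dst):
  if len(src) > len(dst): return False             # the rank may never shrink
  src = (1,) * (len(dst) - len(src)) + tuple(src)  # left-pad, as _pad_left does
  return all(sh in (s, 1) or (s == 0 and sh == 1) for sh, s in zip(src, dst))

print(can_broadcast_to((4, 3, 1, 6), (4, 3, 2, 6)))  # True
print(can_broadcast_to((4, 3, 1, 6), (4, 3, 0, 6)))  # True: 1 -> 0 is allowed
print(can_broadcast_to((4, 3, 2, 6), (4, 3, 0, 6)))  # False: the real method raises ValueError here
print(can_broadcast_to((4, 3, 1, 6), (3, 1, 6)))     # False: fewer dims than the input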
@@ -876,12 +883,9 @@
 
     if reverse: x, y = y, x
 
-    # left pad shape with 1s
-    if len(y.shape) < len(x.shape): y = y.reshape((1,) * (len(x.shape) - len(y.shape)) + y.shape)
-    elif len(x.shape) < len(y.shape): x = x.reshape((1,) * (len(y.shape) - len(x.shape)) + x.shape)
-
-    broadcasted_shape = tuple(0 if xi==0 or yi==0 else max(xi, yi) for xi, yi in zip(x.shape, y.shape))
-    return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
+    # broadcast
+    out_shape = broadcast_shape(x.shape, y.shape)
+    return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
 
   def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
     # TODO: update with multi
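With _broadcasted routed through broadcast_shape and _broadcast_to, every binary op picks up the same rules. A small usage sketch, assuming tinygrad at this commit:

from tinygrad import Tensor

a = Tensor([[10.0], [20.0]])  # shape (2, 1)
b = Tensor([1.0, 2.0, 3.0])   # shape (3,)
c = a + b                     # both sides are broadcast to broadcast_shape((2,1), (3,)) == (2, 3)
print(c.shape)                # (2, 3)
print(c.numpy())              # [[11. 12. 13.] [21. 22. 23.]]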