support zero in shape (#2303)

* zero in shape start

* no assert for that

* if output size is 0, return without exec

* tweak

* strides

* reduce over non-zero

* shrink and expand

* fix import

* test_elementwise where

* cannot reshape from size 0 to size 1

* compiled backend reduce over 0

* zeros for numpy

* reduce over 0 and keepdim resulted in 1

* reduce empty set default values

* compare with same input

* pad test case

* cat test case

* torch does not support that?
chenyu 2023-11-15 11:57:48 -05:00 committed by GitHub
parent f113a0b83b
commit 123a0b86b2
7 changed files with 142 additions and 29 deletions

View File

@@ -110,8 +110,8 @@ backend_test.exclude('test_bernoulli_*')
 backend_test.exclude('test_cumsum_*')
 backend_test.exclude('test_det_*')
-backend_test.exclude('test_tril_zero_cpu') # TODO: zero array support
-backend_test.exclude('test_triu_zero_cpu') # TODO: zero array support
+backend_test.exclude('test_tril_zero_cpu') # TODO: zero array tril support
+backend_test.exclude('test_triu_zero_cpu') # TODO: zero array triu support
 backend_test.exclude('test_col2im_*')
 backend_test.exclude('test_hammingwindow_*')
@@ -147,9 +147,7 @@ backend_test.exclude('test_resize_upsample_sizes_cubic_*') # unsure how to imple
 # rest of the failing tests
 backend_test.exclude('test_regex_*') # does not support string Tensors
 backend_test.exclude('test_optional_has_element_empty_optional_input_cpu') # Attempts to create Tensor from None
-backend_test.exclude('test_reshape_allowzero_reordered_cpu') # reshaping to shape with 0
-backend_test.exclude('test_reduce_min_empty_set_cpu') # max a tensor with 0 in shape
-backend_test.exclude('test_reduce_sum_empty_set_non_reduced_axis_zero_cpu') # reducing a tensor with 0 in shape
+backend_test.exclude('test_reshape_allowzero_reordered_cpu') # reshaping to shape with 0, also allowzero
 backend_test.exclude('test_resize_downsample_scales_linear_antialias_cpu') # antialias not implemented
 backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # antialias not implemented
 backend_test.exclude('test_resize_tf_crop_and_resize_cpu') # unsure about fill value after clip
@@ -174,20 +172,6 @@ if getenv('METAL'):
   backend_test.exclude('test_maxpool_2d_pads_cpu')
   backend_test.exclude('test_maxpool_2d_same_lower_cpu')
 
-# compiled backends cannot reshape to or from 0
-if getenv('LLVM') or getenv('GPU') or getenv('CLANG') or getenv('METAL') or getenv('CUDA'):
-  backend_test.exclude('test_slice_start_out_of_bounds_cpu')
-  backend_test.exclude('test_constantofshape_int_shape_zero_cpu')
-  backend_test.exclude('test_reduce_l1_empty_set_cpu')
-  backend_test.exclude('test_reduce_sum_empty_set_cpu')
-  backend_test.exclude('test_reduce_l1_empty_set_expanded_cpu')
-  backend_test.exclude('test_reduce_sum_square_empty_set_cpu')
-  backend_test.exclude('test_reduce_l2_empty_set_cpu')
-  backend_test.exclude('test_reduce_sum_square_empty_set_expanded_cpu')
-  backend_test.exclude('test_reduce_l2_empty_set_expanded_cpu')
-  backend_test.exclude('test_reduce_log_sum_empty_set_cpu')
-  backend_test.exclude('test_reduce_log_sum_empty_set_expanded_cpu')
-
 if getenv('GPU') or getenv('METAL'):
   backend_test.exclude('test_mish_cpu') # weird inaccuracy
   backend_test.exclude('test_mish_expanded_cpu') # weird inaccuracy
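For context on the removed empty-set excludes: ONNX defines reductions over an empty set in terms of identity elements, which numpy reproduces via its initial argument; a minimal sketch of the semantics these re-enabled tests expect:

import numpy as np

# reductions over an empty axis fall back to the op's identity element:
# sum -> 0, max -> -inf, min -> +inf
empty = np.zeros((2, 0, 4), dtype=np.float32)
print(np.sum(empty, axis=1))                   # shape (2, 4), all 0.0
print(np.max(empty, axis=1, initial=-np.inf))  # shape (2, 4), all -inf
print(np.min(empty, axis=1, initial=np.inf))   # shape (2, 4), all +inf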

View File

@@ -1157,8 +1157,7 @@ class TestOps(unittest.TestCase):
     with self.assertRaises(AssertionError):
       x.repeat((2, 4))
-    with self.assertRaises(AssertionError):
-      x.repeat((2, 0, 4))
+    np.testing.assert_allclose(x.repeat((2, 0, 4)).numpy(), Tensor.zeros(8, 0, 12).numpy())
 
   def test_clip(self):
     helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2))
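The new repeat behavior mirrors numpy's tile, where a zero repeat count yields an empty result instead of an error; a small sketch (the (4, 5, 3) input shape is illustrative, inferred from the expected (8, 0, 12) output):

import numpy as np

x = np.ones((4, 5, 3))
# each dim is multiplied by its repeat count; a 0 count gives a 0-size dim
print(np.tile(x, (2, 0, 4)).shape)  # (8, 0, 12)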

View File

@@ -262,5 +262,127 @@ class TestTinygrad(unittest.TestCase):
     # force device copy - to() is opt'd away - Tensor(dev)/1 is ignored
     np.testing.assert_allclose(ua_arr, (Tensor(ua_arr)/Tensor(1)).numpy())
 
+class TestZeroShapeTensor(unittest.TestCase):
+  def test_shape_stride(self):
+    t = Tensor.rand(3, 2, 0)
+    assert t.shape == (3, 2, 0)
+    # numpy has stride 0, 0, 0; torch has stride 2, 1, 1
+    assert t.lazydata.st.real_strides() == (0, 0, 1)
+
+    t = Tensor.rand(3, 0, 2)
+    assert t.shape == (3, 0, 2)
+    # numpy has stride 0, 0, 0; torch has stride 2, 2, 1
+    assert t.lazydata.st.real_strides() == (0, 2, 1)
+
+    t = Tensor.rand(0, 0, 0)
+    assert t.shape == (0, 0, 0)
+    # numpy has stride 0, 0, 0; torch has stride 1, 1, 1
+    assert t.lazydata.st.real_strides() == (0, 0, 1)
+
+  def test_rand(self):
+    t = Tensor.rand(3, 2, 0)
+    assert t.shape == (3, 2, 0)
+    np.testing.assert_equal(t.numpy(), np.zeros((3, 2, 0)))
+    t = Tensor.rand(0)
+    assert t.shape == (0,)
+    np.testing.assert_equal(t.numpy(), np.zeros((0,)))
+    t = Tensor.rand(0, 0, 0)
+    assert t.shape == (0, 0, 0)
+    np.testing.assert_equal(t.numpy(), np.zeros((0, 0, 0)))
+
+  def test_full(self):
+    t = Tensor.zeros(3, 2, 0)
+    assert t.shape == (3, 2, 0)
+    np.testing.assert_equal(t.numpy(), np.zeros((3, 2, 0)))
+    t = Tensor.full((3, 2, 0), 12)
+    assert t.shape == (3, 2, 0)
+    np.testing.assert_equal(t.numpy(), np.full((3, 2, 0), 12))
+
+  def test_reshape(self):
+    t = Tensor.zeros(3, 2, 0)
+    a = t.reshape(7, 0)
+    assert a.shape == (7, 0)
+    np.testing.assert_equal(a.numpy(), np.zeros((7, 0)))
+    with self.assertRaises(AssertionError):
+      # cannot reshape from size 0 to size 1
+      a = t.reshape(())
+
+  def test_expand(self):
+    t = Tensor.full((3, 2, 0), 12).expand((6, 2, 0))
+    assert t.shape == (6, 2, 0)
+    np.testing.assert_equal(t.numpy(), np.full((6, 2, 0), 12))
+
+  def test_pad(self):
+    t = Tensor.rand(3, 2, 0).pad((None, None, (1, 1)), 1)
+    assert t.shape == (3, 2, 2)
+    np.testing.assert_equal(t.numpy(), np.ones((3, 2, 2)))
+
+    if Device.DEFAULT != "TORCH":
+      # torch does not support padding non-zero dim with 0-size. torch.nn.functional.pad(torch.zeros(3,2,0), [0,0,0,4,0,0])
+      t = Tensor.rand(3, 2, 0).pad((None, (1, 1), None), 1)
+      assert t.shape == (3, 4, 0)
+      np.testing.assert_equal(t.numpy(), np.ones((3, 4, 0)))
+
+      t = Tensor.rand(3, 2, 0).pad(((1, 1), None, None), 1)
+      assert t.shape == (5, 2, 0)
+      np.testing.assert_equal(t.numpy(), np.ones((5, 2, 0)))
+
+  def test_shrink_into_zero(self):
+    t = Tensor.rand(3, 4).realize()
+    assert t.shrink((None, (2, 2))).realize().shape == (3, 0)
+    assert t.shrink(((2, 2), None)).realize().shape == (0, 4)
+    assert t.shrink(((2, 2), (2, 2))).realize().shape == (0, 0)
+
+  def test_cat(self):
+    s = Tensor.rand(3, 2, 2)
+    t = Tensor.rand(3, 2, 0).cat(s, dim=2)
+    assert t.shape == (3, 2, 2)
+    np.testing.assert_equal(t.numpy(), s.numpy())
+
+    if Device.DEFAULT != "TORCH":
+      # torch does not support padding non-zero dim with 0-size. torch.nn.functional.pad(torch.zeros(3,2,0), [0,0,0,4,0,0])
+      s = Tensor.rand(3, 4, 0)
+      t = Tensor.rand(3, 2, 0).cat(s, dim=1)
+      assert t.shape == (3, 6, 0)
+      np.testing.assert_equal(t.numpy(), np.zeros((3, 6, 0)))
+
+  def test_elementwise(self):
+    a = Tensor.rand(3, 2, 0)
+    a_exp = a.exp()
+    assert a_exp.shape == (3, 2, 0)
+    np.testing.assert_equal(a_exp.numpy(), np.exp(a.numpy()))
+
+    b = Tensor.rand(3, 2, 0)
+    assert b.shape == (3, 2, 0)
+    ab = a * b
+    assert ab.shape == (3, 2, 0)
+    np.testing.assert_equal(ab.numpy(), a.numpy() * b.numpy())
+
+    # NOTE: cannot compare with a constant to construct the mask because 0-dim tensor is not broadcastable
+    mask = (Tensor.rand(3, 2, 0) > Tensor.rand(3, 2, 0))
+    assert mask.shape == (3, 2, 0)
+    c = mask.where(a, b)
+    assert c.shape == (3, 2, 0)
+    np.testing.assert_equal(c.numpy(), np.where(mask.numpy(), a.numpy(), b.numpy()))
+
+  def test_reduce_over_non_zero(self):
+    a = Tensor.ones(3, 2, 0).sum(axis=1)
+    assert a.shape == (3, 0)
+    np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=1))
+
+  def test_reduce_over_zero(self):
+    a = Tensor.ones(3, 2, 0).sum(axis=2)
+    assert a.shape == (3, 2)
+    np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=2))
+
+    a = Tensor.ones(3, 2, 0).sum(axis=2, keepdim=True)
+    assert a.shape == (3, 2, 1)
+    np.testing.assert_equal(a.numpy(), np.sum(np.zeros((3, 2, 0)), axis=2, keepdims=True))
+
+  def test_reduce_default(self):
+    np.testing.assert_equal(Tensor([]).max().numpy(), -float("inf"))
+    np.testing.assert_equal(Tensor([]).min().numpy(), float("inf"))
+    np.testing.assert_equal(Tensor([]).sum().numpy(), 0)
+
 if __name__ == '__main__':
   unittest.main()
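The test_reshape case above follows the usual element-count rule: a tensor with a 0 in its shape holds zero elements, so it can only be reshaped to another zero-element shape. numpy enforces the same invariant:

import numpy as np

t = np.zeros((3, 2, 0))
print(t.reshape(7, 0).shape)  # (7, 0): still 0 elements, allowed
try:
  t.reshape(())               # () has 1 element; 0 -> 1 is impossible
except ValueError as e:
  print("cannot reshape:", e)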

View File

@@ -244,7 +244,7 @@ class LazyBuffer:
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype)
 
   def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
-    if not all_int(self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
+    if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
     heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
     if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
     def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
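A plausible reading of the added (0 in self.shape) guard: when the reduce output also contains a 0 (i.e. reducing over a non-zero axis of an empty tensor), prod(new_shape) is 0 and the threshold division would crash, so such reduces are routed straight to _reduce_op:

from math import prod

# e.g. reducing (3, 2, 0) over axis 1 gives new_shape (3, 0)
try:
  prod((3, 2, 0)) // prod((3, 0))  # 0 // 0
except ZeroDivisionError:
  print("guard needed: prod(new_shape) can be 0")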

View File

@@ -329,7 +329,7 @@ class Compiled:
       # if it's aliased, don't use it
       # NOTE: this is pretty wrong actually, who knows where else this buffer is used?
       output.realized = output.output_buffer
-      if output.realized:
+      if output.realized is not None:
        for i,a in enumerate(inputs):
          # TODO: if this is contiguous it's fine
          if a.realized == output.realized:
@@ -338,8 +338,9 @@ class Compiled:
            break
 
     # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
-    if not output.realized:
+    if output.realized is None:
      output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
+    if output.realized.size == 0: return output.realized
 
     # all the rawbuffers
     rawbuffers = [output.realized] + [x.realized for x in inputs]
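The switch from truthiness to is-None checks matters once 0-size buffers exist: an allocated empty buffer can be falsy. A minimal illustration with a hypothetical stand-in class (FakeBuf is not tinygrad's buffer API):

class FakeBuf:
  def __init__(self, size): self.size = size
  def __len__(self): return self.size  # truthiness falls back to len()

buf = FakeBuf(0)
assert buf is not None and not buf  # allocated, yet falsy: `if buf:` would misfire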

View File

@@ -3,7 +3,7 @@ import functools, operator
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, cast
 from tinygrad.helpers import prod, all_int, dedup
-from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, is_sym_int, sint
+from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, sint
 
 @functools.lru_cache(maxsize=None)
 def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
@@ -71,7 +71,10 @@ class View:
   @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
   def expand(self, new_shape: Tuple[sint, ...]) -> View:
     assert len(new_shape) == len(self.shape)
-    assert all(is_sym_int(x) and (s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
+    if 0 in self.shape:
+      assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
+      return View.create(new_shape)
+    assert all((s == x or (s == 1 and st == 0)) for s,x,st in zip(self.shape, new_shape, self.strides)), f"can't expand {self.shape} into {new_shape}"
     # NOTE: can the mask ever be (0,0)?
     mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if s != ns else m) for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
     return View.create(new_shape, self.strides, self.offset, mask)
@@ -96,7 +99,10 @@ class View:
   def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
     if self.shape == new_shape: return self
-    assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
+    assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
+    if 0 in self.shape:
+      assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
+      return View.create(new_shape)
 
     # check for the same size
     if all_int(self.shape):
       if all_int(new_shape):
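Restated outside tinygrad, the new expand rule for views containing a 0: a size-0 axis must stay 0, and any other axis must expand to an exact multiple of its source size (the helper name here is illustrative):

def can_expand_zero(shape, new_shape):
  return all((s == x == 0) or (s > 0 and x % s == 0) for s, x in zip(shape, new_shape))

assert can_expand_zero((3, 2, 0), (6, 2, 0))      # 3 -> 6 is a clean multiple
assert not can_expand_zero((3, 2, 0), (7, 2, 0))  # 7 is not a multiple of 3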

View File

@@ -248,7 +248,6 @@ class Tensor:
   # ***** movement mlops *****
   def reshape(self, shape, *args) -> Tensor:
     new_shape = argfix(shape, *args)
-    assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
     return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
@@ -430,7 +429,8 @@ class Tensor:
   def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
     axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
     axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
-    shape = [s for i,s in enumerate(self.shape) if i not in axis_]
+    shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
+    if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0, mlops.Max: -float("inf")}[fxn])
     ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
     return ret if keepdim else ret.reshape(shape=shape)
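A hedged trace of the new early return in _reduce, for Tensor.ones(3, 2, 0).sum(axis=2):

shape, axis_ = (3, 2, 0), [2]
out = tuple(s for i, s in enumerate(shape) if i not in axis_)  # (3, 2)
keep = tuple(1 if s == 0 else s for s in shape)                # (3, 2, 1) for keepdim
# 0 in shape and 0 not in out, so the result is a (3, 2) tensor filled with
# Sum's identity 0 (Max would use -inf), and no reduce kernel is executed
print(out, keep)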
@@ -684,6 +684,7 @@ class Tensor:
   def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
 
   def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
+    if 0 in self.shape: return self
     x_,y = self._broadcasted(input_)
     x,z = x_._broadcasted(other)
     return mlops.Where.apply(x, *y._broadcasted(z))
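The where() early return relies on every operand being empty whenever the condition is; a numpy analogue of the same shape logic:

import numpy as np

cond = np.zeros((3, 2, 0), dtype=bool)
out = np.where(cond, 1.0, 0.0)  # scalars broadcast against the empty mask
assert out.shape == (3, 2, 0)   # the result is exactly as empty as the condition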