Zero dim Tensor support (#777)

* add and reorganize test_slice_* tests

* refactor Tensor.__getitem__()

* preliminary tests for 1) 0D tensors and 2) varargs for Tensor.zeros and Tensor.ones

* always compare shapes of the numpy arrays obtained from tinygrad and torch tensors

* add more tests for 0D support

* remove test_tensor.test_slicing(); all slicing tests now live in test/test_ops.py

* add zero-dim support

* make test_end2end.py consistent with 0dim support

* add test for tensor with zero in shape

* don't simplify ones if shape is ()

* skip tests that need zero-size tensor support.

- zero-size tensor support is not related to 0-dim tensors.

* add tests for __getitem__() supporting strides >= 1

* refactor __getitem__: support for strides >= 1

* minor refactors and add comments to __getitem__

* add tests for slices with negative steps

* add support for slices with negative strides
Joqsan, 2023-06-01 21:32:02 +03:00 (committed by GitHub)
parent ae83e9844c
commit ef129bcb85
7 changed files with 234 additions and 67 deletions
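
For orientation, a minimal usage sketch of what this commit enables (editor's illustration, assuming the tinygrad API as it appears in the diffs below; not part of the commit):

from tinygrad.tensor import Tensor
import numpy as np

a = Tensor(3.14)                                  # python scalars now construct 0-dim tensors
assert a.shape == ()
assert Tensor.zeros([]).shape == ()               # zeros/ones accept varargs, a list/tuple, or an empty shape
assert Tensor.ones(45, 65).shape == (45, 65)
assert Tensor.ones(3, 4).sum().shape == ()        # full reductions return 0-dim tensors instead of shape (1,)
x = Tensor(np.arange(10).astype(np.float32))
np.testing.assert_allclose(x[1:8:2].numpy(), np.arange(10, dtype=np.float32)[1:8:2])  # strides > 1
np.testing.assert_allclose(x[::-1].numpy(), np.arange(10, dtype=np.float32)[::-1])    # negative steps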

View File

@@ -22,14 +22,14 @@ def compare_tiny_torch(model, model_torch, X, Y):
out = model(X)
loss = (out * Y).mean()
print(loss.realize().numpy()[0])
print(loss.realize().numpy())
out_torch = model_torch(torch.Tensor(X.numpy()))
loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
print(loss_torch.detach().numpy())
# assert losses match
np.testing.assert_allclose(loss.realize().numpy()[0], loss_torch.detach().numpy(), atol=1e-4)
np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)
# zero and backward
optimizer.zero_grad()

View File

@@ -15,7 +15,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
if shps is None:
ts = [torch.tensor(x, requires_grad=True) for x in vals]
else:
ts = [torch.tensor((np.random.random(size=x).astype(np.float32)+a)*b, requires_grad=True) for x in shps]
ts = [torch.tensor((np.random.random(size=x)+a)*b, requires_grad=True, dtype=torch.float32) for x in shps]
tst = [Tensor(x.detach().numpy(), requires_grad=not FORWARD_ONLY) for x in ts]
@@ -29,7 +29,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
def compare(s, x,y,atol,rtol):
if PRINT_TENSORS: print(s, x, y)
if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch (tinygrad){x.shape} != (torch){y.shape}"
assert x.shape == y.shape, f"shape mismatch: tinygrad={x.shape} | torch={y.shape}"
try:
np.testing.assert_allclose(x,y, atol=atol, rtol=rtol)
except Exception:
@@ -62,6 +62,8 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: torch.full((45,65), 4), lambda: Tensor.full((45,65), 4), forward_only=True)
def test_zeros(self):
helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True)
helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True)
helper_test_op([], lambda: torch.zeros([]), lambda: Tensor.zeros([]), forward_only=True)
def test_zeros_like(self):
a = Tensor([[1,2,3],[4,5,6]])
b = torch.tensor([[1,2,3],[4,5,6]])
@@ -70,12 +72,15 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True)
def test_ones(self):
helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True)
helper_test_op([], lambda: torch.ones([45,65]), lambda: Tensor.ones([45,65]), forward_only=True)
helper_test_op([], lambda: torch.ones([]), lambda: Tensor.ones([]), forward_only=True)
def test_ones_like(self):
a = Tensor([[1,2,3],[4,5,6]])
b = torch.tensor([[1,2,3],[4,5,6]])
helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True)
def test_eye(self):
helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True)
def test_arange(self):
helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)
def test_where(self):
@@ -121,43 +126,58 @@ class TestOps(unittest.TestCase):
def test_maximum(self):
helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
helper_test_op([(), ()], torch.maximum, Tensor.maximum)
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., 4.], [1., 2., 3., 0.]])
def test_minimum(self):
helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
helper_test_op([(), ()], torch.minimum, Tensor.minimum)
def test_add(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
helper_test_op([(), ()], lambda x,y: x+y, Tensor.add)
def test_add_simple(self):
helper_test_op([(256), (256)], lambda x,y: x+y, Tensor.add, forward_only=True)
def test_broadcasted_add(self):
helper_test_op([(45,65), (45,1)], lambda x,y: x+y, lambda x,y: x+y)
helper_test_op([(45,65), ()], lambda x,y: x+y, lambda x,y: x+y)
def test_broadcasted_add_2(self):
helper_test_op([(45,65), (65,)], lambda x,y: x+y, lambda x,y: x+y)
def test_sub(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
helper_test_op([(), ()], lambda x,y: x-y, Tensor.sub)
def test_neg(self):
helper_test_op([(45,65)], lambda x: -x)
helper_test_op([()], lambda x: -x)
def test_mul(self):
helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul)
helper_test_op([(), ()], lambda x,y: x*y, Tensor.mul)
def test_div(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
helper_test_op([(), ()], lambda x,y: x/y, Tensor.div)
def test_div_const(self):
helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
helper_test_op([(45,65)], lambda x: x/1, lambda x: x/1)
helper_test_op([(45,65)], lambda x: 1/x, lambda x: 1/x)
helper_test_op([(45,65)], lambda x: x/2, lambda x: x/2)
helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x)
helper_test_op([()], lambda x: x/2, lambda x: x/2)
helper_test_op([()], lambda x: 2/x, lambda x: 2/x)
def test_pow(self):
helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0)
helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0)
helper_test_op([()], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
helper_test_op([()], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
def test_pow_const(self):
helper_test_op([(45,65)], lambda x: x**1.0, lambda x: x**1.0)
helper_test_op([(45,65)], lambda x: 1.0**x, lambda x: 1.0**x)
helper_test_op([(45,65)], lambda x: x**2.0, lambda x: x**2.0)
helper_test_op([(45,65)], lambda x: 2.0**x, lambda x: 2.0**x)
helper_test_op([()], lambda x: x**2.0, lambda x: x**2.0)
helper_test_op([()], lambda x: 2.0**x, lambda x: 2.0**x)
def test_sqrt(self):
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0)
helper_test_op([()], lambda x: x.sqrt(), Tensor.sqrt, a=0)
def test_sin(self):
helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0)
@@ -168,47 +188,65 @@ class TestOps(unittest.TestCase):
def test_relu(self):
helper_test_op([(64,64)], lambda x: x.relu(), Tensor.relu)
helper_test_op([()], lambda x: x.relu(), Tensor.relu)
def test_relu_exact(self):
helper_test_op(None, lambda x: x.relu(), Tensor.relu, vals=[[-1.,0,1]])
def test_relu_maximum_exact(self):
helper_test_op(None, lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), lambda x: Tensor.maximum(x, 0), vals=[[-1.,0,1]])
def test_leakyrelu(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
def test_celu(self):
for val in range(1, 5):
helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
def test_abs(self):
helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs)
helper_test_op([()], lambda x: torch.abs(x), Tensor.abs)
def test_log(self):
helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
helper_test_op([()], lambda x: torch.log(x), Tensor.log)
def test_exp(self):
helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp)
helper_test_op([()], lambda x: torch.exp(x), Tensor.exp)
def test_sign(self):
helper_test_op([(45,65)], lambda x: torch.sign(x), Tensor.sign)
helper_test_op([()], lambda x: torch.sign(x), Tensor.sign)
def test_softsign(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.softsign(x), Tensor.softsign)
helper_test_op([()], lambda x: torch.nn.functional.softsign(x), Tensor.softsign)
def test_sigmoid(self):
helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
helper_test_op([()], lambda x: x.sigmoid(), Tensor.sigmoid, forward_only=True)
def test_softplus(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
helper_test_op([()], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
@unittest.skip("not supported in older pytorch")
def test_gelu(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
def test_quick_gelu(self):
helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
def test_elu(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x), Tensor.elu)
helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x, alpha=0.1), lambda x: Tensor.elu(x, alpha=0.1))
helper_test_op([()], lambda x: torch.nn.functional.elu(x), Tensor.elu)
def test_relu6(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
helper_test_op([()], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
def test_hardswish(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
helper_test_op([()], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
def test_mish(self):
def _mish_pytorch(x):
return x*torch.tanh(torch.nn.functional.softplus(x))
helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4)
helper_test_op([()], _mish_pytorch, Tensor.mish, atol=1e-4)
def test_dot(self):
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
with self.assertRaises(RuntimeError):
a = Tensor(3.14)
a.matmul(a)
def test_matmul_simple(self):
helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
def test_matmul(self):
@@ -225,6 +263,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
def test_broadcastdot(self):
helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
with self.assertRaises(RuntimeError):
a = Tensor(3.14)
b = Tensor.ones(3,3)
a @ b
def test_multidot(self):
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
@@ -241,10 +283,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
helper_test_op([()], lambda x: x.sum(), Tensor.sum)
def test_min(self):
helper_test_op([(3,3)], lambda x: x.min(), Tensor.min)
helper_test_op([(45,3)], lambda x: x.min(), Tensor.min)
helper_test_op([(45,3)], lambda x: x.min().mul(0.5), lambda x: Tensor.min(x).mul(0.5))
helper_test_op([()], lambda x: x.min(), Tensor.min)
def test_max(self):
helper_test_op([(45,3)], lambda x: x.max(), Tensor.max)
helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5))
@@ -253,8 +297,10 @@ class TestOps(unittest.TestCase):
[[1.0,1.0,0.0,1.0]],
])
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1))
helper_test_op([()], lambda x: x.max(), Tensor.max)
def test_mean(self):
helper_test_op([(3,4,5,6)], lambda x: x.mean())
helper_test_op([()], lambda x: x.mean())
def test_mean_axis(self):
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)))
def test_std(self):
@@ -275,28 +321,34 @@ class TestOps(unittest.TestCase):
helper_test_op([(45, 65, 85)], lambda x: torch.std(x, keepdim=True, correction=0, dim=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0))
def test_log_softmax(self):
helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], lambda x: torch.nn.LogSoftmax(dim=0)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
def test_log_softmax_other_axis(self):
helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
def test_tanh(self):
helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
helper_test_op([()], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
def test_hardtanh(self):
for val in range(10, 30, 5):
helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
def test_topo_sort(self):
helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
helper_test_op([()], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
def test_scalar_mul(self):
helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2)
helper_test_op([()], lambda x: x*2, lambda x: x*2)
def test_scalar_rmul(self):
helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x)
helper_test_op([()], lambda x: 2*x, lambda x: 2*x)
def test_scalar_sub(self):
helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2)
helper_test_op([()], lambda x: x-2, lambda x: x-2)
def test_scalar_rsub(self):
helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x)
helper_test_op([()], lambda x: 2-x, lambda x: 2-x)
def test_flip_eye_crash(self):
helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)),
lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True)
@@ -310,6 +362,7 @@ class TestOps(unittest.TestCase):
def test_broadcast_simple(self):
helper_test_op([(45,65), (45,1)], lambda x,y: x/y, lambda x,y: x/y)
helper_test_op([(45,65), ()], lambda x,y: x/y, lambda x,y: x/y)
def test_broadcast_partial(self):
for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
@@ -320,19 +373,83 @@ class TestOps(unittest.TestCase):
# NOTE: ANE backwards?
helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0)
def test_slice_simple(self):
helper_test_op([(3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
def test_slice_in_bounds_1dim(self):
helper_test_op([(3)], lambda x: x[1:3], lambda x: x[1:3])
helper_test_op([(3)], lambda x: x[0:2], lambda x: x[0:2])
helper_test_op([(3)], lambda x: x[-2:2], lambda x: x[-2:2])
def test_slice(self):
helper_test_op([(3,3,3,3)], lambda x: x[1:2], lambda x: x[1:2])
helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1])
def test_slice_on_0dim_tensor(self):
helper_test_op([()], lambda x: x[None], lambda x: x[None])
def test_slice_one(self):
with self.assertRaises(IndexError):
a = Tensor(3.14)
a[0]
def test_slice_int_indexing(self):
helper_test_op([(3)], lambda x: x[1], lambda x: x[1])
def test_slice_one_multi(self):
helper_test_op([(3)], lambda x: x[-2], lambda x: x[-2])
helper_test_op([(10,10)], lambda x: x[1], lambda x: x[1])
helper_test_op([(3,3,3)], lambda x: x[1,1,1], lambda x: x[1,1,1])
def test_slice_in_bounds_multidim(self):
helper_test_op([(3,3,3)], lambda x: x[1:2], lambda x: x[1:2])
helper_test_op([(3,3,3)], lambda x: x[1:2, 2], lambda x: x[1:2, 2])
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1])
def test_slice_with_none(self):
helper_test_op([(3,3,3)], lambda x: x[None], lambda x: x[None])
helper_test_op([(3,3,3)], lambda x: x[1:2, None], lambda x: x[1:2, None])
helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2], lambda x: x[1:2, None, 1:2])
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1], lambda x: x[1:2, 1:2, None, -1])
def test_slice_one_endpoint_out_of_bounds(self):
helper_test_op([(3,3,3)], lambda x: x[0:4], lambda x: x[0:4])
helper_test_op([(3,3,3)], lambda x: x[-6:4], lambda x: x[-6:4])
helper_test_op([(3,3,3)], lambda x: x[1:50], lambda x: x[1:50])
helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1], lambda x: x[1:50, 1:2, -1])
def test_slice_stride_gt_one(self):
helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4], lambda x: x[::2, ::3, ::4])
helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, ::4], lambda x: x[1:5:2, ::3, ::4])
helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4], lambda x: x[1:5:2, 3, ::4])
helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4], lambda x: x[1:5:2, None, None, 3, None, ::4])
def test_slice_negative_strides(self):
# Torch doesn't support slicing with negative steps
a = np.random.randn(10, 10, 10).astype(np.float32)
t = Tensor(a)
np.testing.assert_allclose(a[::-1], t[::-1].numpy())
np.testing.assert_allclose(a[::-2], t[::-2].numpy())
np.testing.assert_allclose(a[:, 2:0:-1], t[:, 2:0:-1].numpy())
np.testing.assert_allclose(a[:, 2:0:-1, 3:1:-2], t[:, 2:0:-1, 3:1:-2].numpy())
np.testing.assert_allclose(a[4:0:-3, 2:0:-1, -1:-5:-2], t[4:0:-3, 2:0:-1, -1:-5:-2].numpy())
@unittest.skip("No suppport for tensors with 0s in shape")
def test_slice_both_endpoints_out_of_bounds(self):
helper_test_op([(3,3,3)], lambda x: x[5:10], lambda x: x[5:10], forward_only=True)
helper_test_op([(3,3,3)], lambda x: x[-15:-7], lambda x: x[-15:-7], forward_only=True)
@unittest.skip("No suppport for tensors with 0s in shape")
def test_slice_start_gt_end(self):
helper_test_op([(3,3,3)], lambda x: x[-2:2], lambda x: x[-2:2], forward_only=True)
helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True)
@unittest.skip("No suppport for tensors with 0s in shape")
def test_slice_empty(self):
helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True)
@unittest.skip("No suppport for tensors with 0s in shape")
def test_slice_zero_in_shape(self):
helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1]) # x.shape = (0, 10)
helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5]) # x.shape = (0, 3, 3)
def test_slice_errors(self):
a = Tensor.ones(4, 3)
with self.assertRaises(IndexError):
a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds)
a[1, 77] # IndexError: (out of bounds).
a[0, -77]
def test_pad2d(self):
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))
@@ -342,10 +459,18 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2))
helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1)))
helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0)))
helper_test_op([()], lambda x: x.permute(()), lambda x: x.permute(()))
def test_reshape(self):
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)))
helper_test_op([()], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1]))
with self.assertRaises(AssertionError):
x = Tensor.ones((4,3,6,6))
x.reshape([])
def test_flip(self):
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,)))
@@ -354,23 +479,31 @@ class TestOps(unittest.TestCase):
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)))
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)).flip((0,)), lambda x: x.flip(axis=(0,1,3)).flip(0))
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(-1,)))
helper_test_op([()], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
helper_test_op([(1,)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
helper_test_op([(4, 3, 6, 6)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
def test_unsqueeze(self):
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 4), lambda x: x.unsqueeze(dim=4))
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -1), lambda x: x.unsqueeze(dim=-1))
helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -3), lambda x: x.unsqueeze(dim=-3))
helper_test_op([()], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))
def test_flatten(self):
for axis in range(3):
helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(axis))
helper_test_op([()], lambda x: x.flatten(), lambda x: x.flatten())
helper_test_op([(1,)], lambda x: x.flatten(), lambda x: x.flatten())
def test_detach(self):
helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
helper_test_op([()], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
def test_expand(self):
arg = (4,3,2,6)
helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
helper_test_op([()], lambda x: x.expand([]), lambda x: x.expand(shape=[]))
@unittest.skip("very slow")
def test_sd_big_conv(self):
@@ -695,6 +828,10 @@ class TestOps(unittest.TestCase):
for dim in range(-1, 2):
helper_test_op([(45,65), (45,65)], lambda x,y: torch.cat((x,y), dim), lambda x,y: x.cat(y, dim=dim))
with self.assertRaises(AssertionError):
a = Tensor(3.14)
a.cat(a)
def test_multicat(self):
for dim in range(-1, 2):
helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))
@@ -707,6 +844,9 @@ class TestOps(unittest.TestCase):
with self.assertRaises(IndexError):
Tensor.stack([x], dim=77)
a = Tensor(3.14)
np.testing.assert_allclose(Tensor.stack([a, a]).numpy(), Tensor([3.14, 3.14]).numpy())
def test_repeat(self):
x = Tensor.randn(45, 65, 3)
@@ -715,6 +855,7 @@ class TestOps(unittest.TestCase):
for reps in [[], [4], [2, 1], [3, 2, 2]]:
repeats = base_repeats + reps
helper_test_op([(45, 65, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
helper_test_op([()], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
with self.assertRaises(AssertionError):
x.repeat((2, 4))
@@ -722,7 +863,6 @@ class TestOps(unittest.TestCase):
with self.assertRaises(AssertionError):
x.repeat((2, 0, 4))
def test_clip(self):
helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2))

View File

@@ -14,6 +14,13 @@ W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)
class TestTinygrad(unittest.TestCase):
def test_zerodim_initialization(self):
a = Tensor(55)
b = Tensor(3.14)
self.assertEqual(a.shape, ())
self.assertEqual(b.shape, ())
def test_plus_equals(self):
a = Tensor.randn(10,10)
b = Tensor.randn(10,10)
@@ -23,20 +30,6 @@ class TestTinygrad(unittest.TestCase):
val2 = a.numpy()
np.testing.assert_allclose(val1, val2)
def test_slicing(self):
x = Tensor.randn(10,10)
slices = [0,1,9,-1,-10,None] + [slice(s,e) for s,e in itertools.combinations([0,1,-1,None], r=2)] + [slice(9,11), slice(-11,-9)]
fmt = lambda s: f'{s.start}:{s.stop}' if isinstance(s, slice) else str(s)
for s in list(itertools.product(slices, slices)) + [(None,0,None,0,None), (slice(0,2),None,None,slice(2,4),None,None)]:
np.testing.assert_equal(x.numpy()[s], x[s].numpy(), f'Test failed for slice x[{",".join(fmt(x) for x in s)}]')
for s in [-11,10]:
with self.assertRaises(IndexError):
x[s]
with self.assertRaises(AssertionError):
x[::2]
with self.assertRaises(AssertionError):
x[0,0,0]
def test_backward_pass(self):
def test_tinygrad():
x = Tensor(x_init, requires_grad=True)

View File

@@ -414,6 +414,7 @@ class Linearizer:
def simplify_ones(self):
# remove places where the shape is all ones
# TODO: this should be factored in to multi shape stride
if self.shape_len == 0: return
all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
# keep at least 1 one
if all(all_ones): all_ones[-1] = False

View File

@@ -7,6 +7,7 @@ base_image_type = (100, 2, "imageh", np.float16) if FLOAT16 else (100, 4, "image
def image_dot(self, w):
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D")
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
out_shape_t = self.shape[0:-2] + (cout,-1)

View File

@@ -74,13 +74,13 @@ class View:
@functools.lru_cache(maxsize=None)
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
strides = [1]
strides = [1] if shape else []
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
return tuple(st if s != 1 else 0 for st, s in zip(strides, shape))
@functools.lru_cache(maxsize=None)
def view_from_shape(shape:Tuple[int, ...]) -> View:
assert all(isinstance(x, int) for x in shape) and len(shape) != 0
assert all(isinstance(x, int) for x in shape)
return View(tuple(shape), strides_for_shape(shape))
@functools.lru_cache(maxsize=None)
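
For reference, a standalone sketch of what strides_for_shape computes after this change (editor's illustration mirroring the function above; not part of the diff):

def strides_for_shape(shape):
    strides = [1] if shape else []                       # a () shape now yields no strides at all
    for d in shape[::-1][:-1]: strides = [d * strides[0]] + strides
    return tuple(st if s != 1 else 0 for st, s in zip(strides, shape))

assert strides_for_shape(()) == ()                       # 0-dim view
assert strides_for_shape((4, 5)) == (5, 1)               # row-major strides
assert strides_for_shape((4, 1, 5)) == (5, 0, 1)         # size-1 dims get stride 0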

View File

@@ -1,6 +1,6 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import math, functools, itertools
import math, functools, itertools, operator
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, LazyNumpyArray
@@ -33,9 +33,9 @@ class Tensor:
no_grad: ClassVar[bool] = False
default_type: ClassVar[DType] = dtypes.float32
def __init__(self, data:Union[list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
def __init__(self, data:Union[int, float, list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
device = (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # canonicalize device
if isinstance(data, list):
if isinstance(data, (int, float, list)):
data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
elif isinstance(data, LazyBuffer) and data.device != device:
# TODO: this has to realize, it shouldn't have to
@@ -47,7 +47,6 @@ class Tensor:
# by here, it's either LazyNumpyArray or LazyBuffer
# TODO: it should all be LazyBuffer I think
if isinstance(data, LazyNumpyArray):
data = data if data.shape else data.reshape((1,))
lazydata = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None else data, device)
elif isinstance(data, LazyBuffer):
assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
@@ -122,15 +121,13 @@ class Tensor:
# ***** creation helper functions *****
@staticmethod
def full(shape:Tuple[int, ...], fill_value, **kwargs):
new_shape = argfix(shape)
return Tensor([fill_value], **kwargs).reshape([1]*len(new_shape)).expand(new_shape).contiguous()
def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape).contiguous()
@staticmethod
def zeros(*shape, **kwargs): return Tensor.full(shape, 0, **kwargs)
def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
@staticmethod
def ones(*shape, **kwargs): return Tensor.full(shape, 1, **kwargs)
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)
@staticmethod
def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
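
A short sketch of what the argfix-based creation helpers accept after this change (editor's illustration; argfix is assumed to flatten a single tuple/list argument into the vararg form):

from tinygrad.tensor import Tensor
assert Tensor.zeros(45, 65).shape == (45, 65)     # varargs
assert Tensor.zeros([45, 65]).shape == (45, 65)   # a list or tuple works too
assert Tensor.ones([]).shape == ()                # empty shape -> 0-dim tensor
assert Tensor.full((), 4).shape == ()             # full() now builds on a 0-dim Tensor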
@@ -203,11 +200,11 @@ class Tensor:
return _deepwalk(self, set(), [])
def backward(self):
assert self.shape == (1,)
assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape}"
# fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
# this is "implicit gradient creation"
self.grad = Tensor([1], device=self.device, requires_grad=False)
self.grad = Tensor(1, device=self.device, requires_grad=False)
for t0 in reversed(self.deepwalk()):
if not any(x.requires_grad for x in t0._ctx.parents):
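
A brief illustration of the new backward() contract (editor's sketch, with Tensor imported from tinygrad.tensor): the loss must now be 0-dimensional, which a full mean() produces directly, so callers no longer index a length-1 result.

from tinygrad.tensor import Tensor
x = Tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
loss = (x * x).mean()        # shape () after this commit
loss.backward()              # previously asserted loss.shape == (1,)
assert x.grad.shape == (2, 2)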
@@ -227,7 +224,7 @@ class Tensor:
def reshape(self, shape, *args) -> Tensor:
new_shape = argfix(shape, *args)
assert len(new_shape) > 0 and all(x != 0 for x in new_shape), f"zeros not allowed in shape {new_shape}"
assert all(x != 0 for x in new_shape), f"zeros not allowed in shape {new_shape}"
return mlops.Reshape.apply(self, shape=tuple(-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape))
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
@@ -243,37 +240,71 @@ class Tensor:
padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_))
return self.pad(padding).shrink(tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)))
# Tensors mostly follow the normal python indexing / slicing behavior for sequences
# - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
# - A slice i:j returns the elements with indices in [i, j)
# - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
# - Negative values for i and j are taken relative to the end of the sequence
# - Both i and j will be clamped to the range (-N, N], where N is the length of the sequence
# - Indexing with np.newaxis or None on a given axis will add a new dimension of size one before that axis
# - Empty slices are not allowed
# - Strides other than 1 are not allowed
# - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends).
# - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len).
# - Strides > 1 and < 0 are now allowed!:
# - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
# - Idea of stride < 0 support:
# - Do the slice first, flip the axes where slice.step is negative, do slice.step -> -slice.step. Go to steps below.
# - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
# - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
# - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s]
# is possible.
# - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
def __getitem__(self, val):
def slcfix(i, sz, default): return default if i is None else max(0, min(sz, sz+i if i < 0 else i)) # Fix negative idxs, clamp to [0,N]
new_slice, new_shape = [], []
val = [val] if not isinstance(val, (list, tuple)) else val
assert sum(s is not None for s in val) <= len(self.shape)
assert all(s.step is None or s.step == 1 for s in val if isinstance(s, slice))
for i,(sz,s) in enumerate(zip(self.shape, [v for v in val if v is not None])): # Slicing only depends on ints + slices
if isinstance(s, int) and not (-sz <= s < sz):
raise IndexError(f"index {s} is out of bounds for dimension {i} with size {sz}")
new_slice.append((s%sz, s%sz+1) if isinstance(s, int) else (slcfix(s.start, sz, 0), slcfix(s.stop, sz, sz)))
for s,sz in zip(val, [self.shape[i-1] for i in itertools.accumulate([int(s is not None) for s in val])]): # Shape depends on slices + positions of Nones
if not isinstance(s, int):
new_shape.append(1 if s is None else slcfix(s.stop, sz, sz) - slcfix(s.start, sz, 0))
new_shape += [self.shape[i] for i in range(len(new_slice), len(self.shape))]
new_slice += [(0,self.shape[i]) for i in range(len(new_slice), len(self.shape))]
return self.slice(new_slice).reshape(new_shape if len(new_shape) else (1,))
def normalize_int(e, i, dim_sz):
if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
val = list(val) if isinstance(val, tuple) else [val]
if (num_slices := sum(isinstance(v, (slice, int)) for v in val)) > len(self.shape):
raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
orig_slices = list(val) + [slice(None)] * (len(self.shape) - num_slices)
valid_slices = list(itertools.filterfalse(lambda x: x is None, orig_slices))
valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
new_shape = tuple(e - s for s, e in new_slice)
# Shrink
sliced_tensor = self.shrink(new_slice)
# Flip
if (flip_axes := tuple(i for i, s in enumerate(strides) if s < 0)):
sliced_tensor = sliced_tensor.flip(axis=flip_axes)
if any(s > 1 or s < 0 for s in strides):
# normalize if negative strides
strides = tuple(abs(s) for s in strides)
def num_zeros(step, dim_sz): return 0 if step == 1 or (y := dim_sz % step) == 0 else (step - y)
# Pad: add pad at the end: [dim_sz] -> [dim_sz_padded]
paddings = tuple((0, num_zeros(s, dim_sz)) for s, dim_sz in zip(strides, sliced_tensor.shape))
padded_tensor = sliced_tensor.pad(paddings)
# Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
new_shape = functools.reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], []) # type: ignore
reshaped_tensor = padded_tensor.reshape(new_shape)
# Shrink: do [:, 0]
new_shape = new_shape[::2]
final_slice = functools.reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
sliced_tensor = reshaped_tensor.shrink(final_slice)
final_shape = []
it_shape = iter(new_shape)
for i in orig_slices:
if isinstance(i, (int, slice)):
dim_shape = next(it_shape)
if isinstance(i, slice): final_shape.append(dim_shape)
else: # i is None
final_shape.append(1)
return sliced_tensor.reshape(tuple(final_shape)) # Reshape
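# --- editor's sketch (not part of the diff): the stride trick checked with numpy ---
# Taking x[1:8:2] on a length-10 axis, __getitem__ above performs:
#   Shrink to [1:8] -> Pad the axis to a multiple of the step -> Reshape to
#   (padded // step, step) -> Shrink again to keep column 0.
# A negative step additionally flips the axis after the first shrink and then
# proceeds with abs(step).
import numpy as np
a = np.arange(10)
stepped = np.pad(a[1:8], (0, 1)).reshape(4, 2)[:, 0]        # shrink, pad 7 -> 8, reshape, take [:, 0]
assert (stepped == a[1:8:2]).all()                          # [1, 3, 5, 7]
flipped = np.pad(a[2:9][::-1], (0, 1)).reshape(4, 2)[:, 0]  # a[8:1:-2] becomes shrink [2:9], flip, step 2
assert (flipped == a[8:1:-2]).all()                         # [8, 6, 4, 2]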
def cat(self, *args, dim=0):
dim = (dim + len(self.shape)) if dim < 0 else dim
for y in args:
assert len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim)
catargs = [self] + list(args)
assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
shape_cumsum = [0, *itertools.accumulate([y.shape[dim] for y in catargs])]
slc = [[(0, s) for s in self.shape] for _ in catargs]
for s,k in zip(slc, shape_cumsum):
@@ -327,7 +358,7 @@ class Tensor:
axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis_]
ret = fxn.apply(self, new_shape=tuple(1 if i in axis_ else self.shape[i] for i in range(len(self.shape))))
return ret if keepdim else ret.reshape(shape=[1] if shape == [] else shape)
return ret if keepdim else ret.reshape(shape=shape)
def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
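
A small check of the new reduce behavior (editor's sketch): with reshape(shape=shape) above, a full reduction drops all dimensions instead of collapsing to shape (1,).

from tinygrad.tensor import Tensor
t = Tensor.ones(3, 4, 5)
assert t.sum().shape == ()                        # was (1,) before this commit
assert t.sum(axis=1).shape == (3, 5)              # axis reductions are unchanged
assert t.max(axis=(0, 2), keepdim=True).shape == (1, 4, 1)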
@@ -425,6 +456,7 @@ class Tensor:
return ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))
def dot(self, w:Tensor) -> Tensor:
if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D")
x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])
w = w.reshape(*w.shape[0:-2], 1, w.shape[-2], w.shape[-1]).transpose(-1, -2)
r = (x*w).sum(-1)
@@ -471,7 +503,7 @@ class Tensor:
# ***** broadcasted binary mlops *****
def _broadcasted(self, fxn:Type[Function], other:Union[Tensor, float], reverse:bool=False) -> Tensor:
x,y = [Tensor([t], device=self.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in ([other,self] if reverse else [self,other])]
x,y = [Tensor(t, device=self.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in ([other,self] if reverse else [self,other])]
x,y = [t.reshape([1]*(max(len(x.shape), len(y.shape))-len(t.shape)) + list(t.shape)) for t in [x,y]]
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
return fxn.apply(x.expand(shape_ret), y.expand(shape_ret))
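
Lastly, a sketch of how the broadcasting helper now treats python scalars (editor's illustration): a non-Tensor operand becomes a 0-dim tensor, is reshaped to match rank, and is expanded to the result shape, so constants no longer carry a spurious (1,) dimension.

from tinygrad.tensor import Tensor
x = Tensor.ones(3, 4)
assert (x + 2).shape == (3, 4)          # 2 -> Tensor(2) -> reshape to (1, 1) -> expand to (3, 4)
assert (2.0 / x).shape == (3, 4)        # reversed ops take the same path
assert (Tensor(5.0) * x).shape == (3, 4)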