mirror of https://github.com/commaai/tinygrad.git
Zero dim Tensor support (#777)
* add and reorganize test_slice_* tests
* refactor Tensor.__getitem__()
* preliminary tests for 1) 0-D tensors and 2) varargs for Tensor.zeros and Tensor.ones
* always compare shapes of the numpy arrays obtained from tinygrad and torch tensors
* add more tests for 0-D support
* remove test_tensor.test_slicing(); all slicing tests now live in test/test_ops.py
* add zero-dim support
* make test_end2end.py consistent with 0-dim support
* add test for tensor with zero in shape
* don't simplify ones if shape is ()
* skip tests that need zero-size tensor support (zero-size tensors are a separate feature from 0-dim tensors)
* add tests for __getitem__() supporting strides >= 1
* refactor __getitem__: support for strides >= 1
* minor refactors and add comments to __getitem__
* add tests for slices with negative steps
* add support for slices with negative strides
parent ae83e9844c
commit ef129bcb85
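Before the diff, a short usage sketch of what this change enables at the Python level; it is based on the tests added below and is illustrative only, not part of the commit:

import numpy as np
from tinygrad.tensor import Tensor

# 0-dim (scalar) tensors can now be built directly from Python numbers
a = Tensor(3.14)
assert a.shape == ()

# Tensor.zeros / Tensor.ones accept both varargs and a sequence, including an empty shape
assert Tensor.zeros(45, 65).shape == (45, 65)
assert Tensor.zeros([45, 65]).shape == (45, 65)
assert Tensor.ones([]).shape == ()

# backward() now requires (and works on) a true scalar of shape (), not a (1,) tensor
x = Tensor(np.random.randn(3, 3).astype(np.float32), requires_grad=True)
loss = (x * x).mean()    # full reduction -> shape ()
loss.backward()

# __getitem__ now supports strides > 1 and negative strides
t = Tensor(np.arange(10).astype(np.float32))
print(t[::2].numpy())    # every second element
print(t[::-1].numpy())   # reversed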
@@ -22,14 +22,14 @@ def compare_tiny_torch(model, model_torch, X, Y):
  out = model(X)
  loss = (out * Y).mean()
-  print(loss.realize().numpy()[0])
+  print(loss.realize().numpy())

  out_torch = model_torch(torch.Tensor(X.numpy()))
  loss_torch = (out_torch * torch.Tensor(Y.numpy())).mean()
  print(loss_torch.detach().numpy())

  # assert losses match
-  np.testing.assert_allclose(loss.realize().numpy()[0], loss_torch.detach().numpy(), atol=1e-4)
+  np.testing.assert_allclose(loss.realize().numpy(), loss_torch.detach().numpy(), atol=1e-4)

  # zero and backward
  optimizer.zero_grad()
test/test_ops.py
@@ -15,7 +15,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
  if shps is None:
    ts = [torch.tensor(x, requires_grad=True) for x in vals]
  else:
-    ts = [torch.tensor((np.random.random(size=x).astype(np.float32)+a)*b, requires_grad=True) for x in shps]
+    ts = [torch.tensor((np.random.random(size=x)+a)*b, requires_grad=True, dtype=torch.float32) for x in shps]

  tst = [Tensor(x.detach().numpy(), requires_grad=not FORWARD_ONLY) for x in ts]

@@ -29,7 +29,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
  def compare(s, x,y,atol,rtol):
    if PRINT_TENSORS: print(s, x, y)
-    if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch (tinygrad){x.shape} != (torch){y.shape}"
+    assert x.shape == y.shape, f"shape mismatch: tinygrad={x.shape} | torch={y.shape}"
    try:
      np.testing.assert_allclose(x,y, atol=atol, rtol=rtol)
    except Exception:

@@ -62,6 +62,8 @@ class TestOps(unittest.TestCase):
    helper_test_op([], lambda: torch.full((45,65), 4), lambda: Tensor.full((45,65), 4), forward_only=True)
  def test_zeros(self):
    helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True)
+    helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True)
+    helper_test_op([], lambda: torch.zeros([]), lambda: Tensor.zeros([]), forward_only=True)
  def test_zeros_like(self):
    a = Tensor([[1,2,3],[4,5,6]])
    b = torch.tensor([[1,2,3],[4,5,6]])

@@ -70,12 +72,15 @@ class TestOps(unittest.TestCase):
    helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True)
  def test_ones(self):
    helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True)
+    helper_test_op([], lambda: torch.ones([45,65]), lambda: Tensor.ones([45,65]), forward_only=True)
+    helper_test_op([], lambda: torch.ones([]), lambda: Tensor.ones([]), forward_only=True)
  def test_ones_like(self):
    a = Tensor([[1,2,3],[4,5,6]])
    b = torch.tensor([[1,2,3],[4,5,6]])
    helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True)
  def test_eye(self):
    helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True)

  def test_arange(self):
    helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)
  def test_where(self):
@@ -121,43 +126,58 @@ class TestOps(unittest.TestCase):
  def test_maximum(self):
    helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
+    helper_test_op([(), ()], torch.maximum, Tensor.maximum)
    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., 4.], [1., 2., 3., 0.]])
  def test_minimum(self):
    helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
+    helper_test_op([(), ()], torch.minimum, Tensor.minimum)
  def test_add(self):
    helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
+    helper_test_op([(), ()], lambda x,y: x+y, Tensor.add)
  def test_add_simple(self):
    helper_test_op([(256), (256)], lambda x,y: x+y, Tensor.add, forward_only=True)
  def test_broadcasted_add(self):
    helper_test_op([(45,65), (45,1)], lambda x,y: x+y, lambda x,y: x+y)
+    helper_test_op([(45,65), ()], lambda x,y: x+y, lambda x,y: x+y)
  def test_broadcasted_add_2(self):
    helper_test_op([(45,65), (65,)], lambda x,y: x+y, lambda x,y: x+y)
  def test_sub(self):
    helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
+    helper_test_op([(), ()], lambda x,y: x-y, Tensor.sub)
  def test_neg(self):
    helper_test_op([(45,65)], lambda x: -x)
+    helper_test_op([()], lambda x: -x)
  def test_mul(self):
    helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul)
+    helper_test_op([(), ()], lambda x,y: x*y, Tensor.mul)
  def test_div(self):
    helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
+    helper_test_op([(), ()], lambda x,y: x/y, Tensor.div)
  def test_div_const(self):
    helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
    helper_test_op([(45,65)], lambda x: x/1, lambda x: x/1)
    helper_test_op([(45,65)], lambda x: 1/x, lambda x: 1/x)
    helper_test_op([(45,65)], lambda x: x/2, lambda x: x/2)
    helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x)
+    helper_test_op([()], lambda x: x/2, lambda x: x/2)
+    helper_test_op([()], lambda x: 2/x, lambda x: 2/x)
  def test_pow(self):
    helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
    helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0)
    helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
    helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0)
+    helper_test_op([()], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
+    helper_test_op([()], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
  def test_pow_const(self):
    helper_test_op([(45,65)], lambda x: x**1.0, lambda x: x**1.0)
    helper_test_op([(45,65)], lambda x: 1.0**x, lambda x: 1.0**x)
    helper_test_op([(45,65)], lambda x: x**2.0, lambda x: x**2.0)
    helper_test_op([(45,65)], lambda x: 2.0**x, lambda x: 2.0**x)
+    helper_test_op([()], lambda x: x**2.0, lambda x: x**2.0)
+    helper_test_op([()], lambda x: 2.0**x, lambda x: 2.0**x)
  def test_sqrt(self):
    helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0)
+    helper_test_op([()], lambda x: x.sqrt(), Tensor.sqrt, a=0)

  def test_sin(self):
    helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0)
@@ -168,47 +188,65 @@ class TestOps(unittest.TestCase):
  def test_relu(self):
    helper_test_op([(64,64)], lambda x: x.relu(), Tensor.relu)
+    helper_test_op([()], lambda x: x.relu(), Tensor.relu)
  def test_relu_exact(self):
    helper_test_op(None, lambda x: x.relu(), Tensor.relu, vals=[[-1.,0,1]])
  def test_relu_maximum_exact(self):
    helper_test_op(None, lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), lambda x: Tensor.maximum(x, 0), vals=[[-1.,0,1]])
  def test_leakyrelu(self):
    helper_test_op([(45,65)], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
+    helper_test_op([()], lambda x: torch.nn.functional.leaky_relu(x,0.01), Tensor.leakyrelu)
  def test_celu(self):
    for val in range(1, 5):
      helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
+      helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
  def test_abs(self):
    helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs)
+    helper_test_op([()], lambda x: torch.abs(x), Tensor.abs)
  def test_log(self):
    helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
+    helper_test_op([()], lambda x: torch.log(x), Tensor.log)
  def test_exp(self):
    helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp)
+    helper_test_op([()], lambda x: torch.exp(x), Tensor.exp)
  def test_sign(self):
    helper_test_op([(45,65)], lambda x: torch.sign(x), Tensor.sign)
+    helper_test_op([()], lambda x: torch.sign(x), Tensor.sign)
  def test_softsign(self):
    helper_test_op([(45,65)], lambda x: torch.nn.functional.softsign(x), Tensor.softsign)
+    helper_test_op([()], lambda x: torch.nn.functional.softsign(x), Tensor.softsign)
  def test_sigmoid(self):
    helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
+    helper_test_op([()], lambda x: x.sigmoid(), Tensor.sigmoid, forward_only=True)
  def test_softplus(self):
    helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
+    helper_test_op([()], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
  @unittest.skip("not supported in older pytorch")
  def test_gelu(self):
    helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
  def test_quick_gelu(self):
    helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
+    helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
  def test_elu(self):
    helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x), Tensor.elu)
    helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x, alpha=0.1), lambda x: Tensor.elu(x, alpha=0.1))
+    helper_test_op([()], lambda x: torch.nn.functional.elu(x), Tensor.elu)
  def test_relu6(self):
    helper_test_op([(45,65)], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
+    helper_test_op([()], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
  def test_hardswish(self):
    helper_test_op([(45,65)], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
+    helper_test_op([()], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
  def test_mish(self):
    def _mish_pytorch(x):
      return x*torch.tanh(torch.nn.functional.softplus(x))
    helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4)
+    helper_test_op([()], _mish_pytorch, Tensor.mish, atol=1e-4)
  def test_dot(self):
    helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
+    with self.assertRaises(RuntimeError):
+      a = Tensor(3.14)
+      a.matmul(a)
  def test_matmul_simple(self):
    helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
  def test_matmul(self):
@@ -225,6 +263,10 @@ class TestOps(unittest.TestCase):
    helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
  def test_broadcastdot(self):
    helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
+    with self.assertRaises(RuntimeError):
+      a = Tensor(3.14)
+      b = Tensor.ones(3,3)
+      a @ b
  def test_multidot(self):
    helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
    helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)

@@ -241,10 +283,12 @@ class TestOps(unittest.TestCase):
    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2)))
    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
    helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1))
+    helper_test_op([()], lambda x: x.sum(), Tensor.sum)
  def test_min(self):
    helper_test_op([(3,3)], lambda x: x.min(), Tensor.min)
    helper_test_op([(45,3)], lambda x: x.min(), Tensor.min)
    helper_test_op([(45,3)], lambda x: x.min().mul(0.5), lambda x: Tensor.min(x).mul(0.5))
+    helper_test_op([()], lambda x: x.min(), Tensor.min)
  def test_max(self):
    helper_test_op([(45,3)], lambda x: x.max(), Tensor.max)
    helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5))

@@ -253,8 +297,10 @@ class TestOps(unittest.TestCase):
      [[1.0,1.0,0.0,1.0]],
      ])
    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1))
+    helper_test_op([()], lambda x: x.max(), Tensor.max)
  def test_mean(self):
    helper_test_op([(3,4,5,6)], lambda x: x.mean())
+    helper_test_op([()], lambda x: x.mean())
  def test_mean_axis(self):
    helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)))
  def test_std(self):

@@ -275,28 +321,34 @@ class TestOps(unittest.TestCase):
    helper_test_op([(45, 65, 85)], lambda x: torch.std(x, keepdim=True, correction=0, dim=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0))
  def test_log_softmax(self):
    helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], lambda x: torch.nn.LogSoftmax(dim=0)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
  def test_log_softmax_other_axis(self):
    helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7)
    helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
    helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
  def test_tanh(self):
    helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
+    helper_test_op([()], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6)
  def test_hardtanh(self):
    for val in range(10, 30, 5):
      helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
+      helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
  def test_topo_sort(self):
    helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)
+    helper_test_op([()], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6)

  def test_scalar_mul(self):
    helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2)
+    helper_test_op([()], lambda x: x*2, lambda x: x*2)
  def test_scalar_rmul(self):
    helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x)
+    helper_test_op([()], lambda x: 2*x, lambda x: 2*x)
  def test_scalar_sub(self):
    helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2)
+    helper_test_op([()], lambda x: x-2, lambda x: x-2)
  def test_scalar_rsub(self):
    helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x)
+    helper_test_op([()], lambda x: 2-x, lambda x: 2-x)
  def test_flip_eye_crash(self):
    helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)),
                       lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True)

@@ -310,6 +362,7 @@ class TestOps(unittest.TestCase):
  def test_broadcast_simple(self):
    helper_test_op([(45,65), (45,1)], lambda x,y: x/y, lambda x,y: x/y)
+    helper_test_op([(45,65), ()], lambda x,y: x/y, lambda x,y: x/y)

  def test_broadcast_partial(self):
    for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
@@ -320,19 +373,83 @@ class TestOps(unittest.TestCase):
      # NOTE: ANE backwards?
      helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0)

  def test_slice_simple(self):
    helper_test_op([(3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
  def test_slice_in_bounds_1dim(self):
    helper_test_op([(3)], lambda x: x[1:3], lambda x: x[1:3])
    helper_test_op([(3)], lambda x: x[0:2], lambda x: x[0:2])
    helper_test_op([(3)], lambda x: x[-2:2], lambda x: x[-2:2])

  def test_slice(self):
    helper_test_op([(3,3,3,3)], lambda x: x[1:2], lambda x: x[1:2])
    helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
    helper_test_op([(3,3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1])
  def test_slice_on_0dim_tensor(self):
    helper_test_op([()], lambda x: x[None], lambda x: x[None])

  def test_slice_one(self):
    with self.assertRaises(IndexError):
      a = Tensor(3.14)
      a[0]

  def test_slice_int_indexing(self):
    helper_test_op([(3)], lambda x: x[1], lambda x: x[1])

  def test_slice_one_multi(self):
    helper_test_op([(3)], lambda x: x[-2], lambda x: x[-2])
    helper_test_op([(10,10)], lambda x: x[1], lambda x: x[1])
    helper_test_op([(3,3,3)], lambda x: x[1,1,1], lambda x: x[1,1,1])

  def test_slice_in_bounds_multidim(self):
    helper_test_op([(3,3,3)], lambda x: x[1:2], lambda x: x[1:2])
    helper_test_op([(3,3,3)], lambda x: x[1:2, 2], lambda x: x[1:2, 2])
    helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2])
    helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1])

  def test_slice_with_none(self):
    helper_test_op([(3,3,3)], lambda x: x[None], lambda x: x[None])
    helper_test_op([(3,3,3)], lambda x: x[1:2, None], lambda x: x[1:2, None])
    helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2], lambda x: x[1:2, None, 1:2])
    helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1], lambda x: x[1:2, 1:2, None, -1])

  def test_slice_one_endpoint_out_of_bounds(self):
    helper_test_op([(3,3,3)], lambda x: x[0:4], lambda x: x[0:4])
    helper_test_op([(3,3,3)], lambda x: x[-6:4], lambda x: x[-6:4])
    helper_test_op([(3,3,3)], lambda x: x[1:50], lambda x: x[1:50])
    helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1], lambda x: x[1:50, 1:2, -1])

  def test_slice_stride_gt_one(self):
    helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4], lambda x: x[::2, ::3, ::4])
    helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, ::4], lambda x: x[1:5:2, ::3, ::4])
    helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4], lambda x: x[1:5:2, 3, ::4])
    helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4], lambda x: x[1:5:2, None, None, 3, None, ::4])

  def test_slice_negative_strides(self):
    # Torch doesn't support slicing with negative steps
    a = np.random.randn(10, 10, 10).astype(np.float32)
    t = Tensor(a)
    np.testing.assert_allclose(a[::-1], t[::-1].numpy())
    np.testing.assert_allclose(a[::-2], t[::-2].numpy())
    np.testing.assert_allclose(a[:, 2:0:-1], t[:, 2:0:-1].numpy())
    np.testing.assert_allclose(a[:, 2:0:-1, 3:1:-2], t[:, 2:0:-1, 3:1:-2].numpy())
    np.testing.assert_allclose(a[4:0:-3, 2:0:-1, -1:-5:-2], t[4:0:-3, 2:0:-1, -1:-5:-2].numpy())

  @unittest.skip("No support for tensors with 0s in shape")
  def test_slice_both_endpoints_out_of_bounds(self):
    helper_test_op([(3,3,3)], lambda x: x[5:10], lambda x: x[5:10], forward_only=True)
    helper_test_op([(3,3,3)], lambda x: x[-15:-7], lambda x: x[-15:-7], forward_only=True)

  @unittest.skip("No support for tensors with 0s in shape")
  def test_slice_start_gt_end(self):
    helper_test_op([(3,3,3)], lambda x: x[-2:2], lambda x: x[-2:2], forward_only=True)
    helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True)

  @unittest.skip("No support for tensors with 0s in shape")
  def test_slice_empty(self):
    helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True)

  @unittest.skip("No support for tensors with 0s in shape")
  def test_slice_zero_in_shape(self):
    helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1])  # x.shape = (0, 10)
    helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5])  # x.shape = (0, 3, 3)

  def test_slice_errors(self):
    a = Tensor.ones(4, 3)
    with self.assertRaises(IndexError):
      a[1, 77, 77, 77]  # IndexError: too many indices (detected before the out-of-bounds access)
      a[1, 77]  # IndexError: out of bounds
      a[0, -77]

  def test_pad2d(self):
    helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))
@@ -342,10 +459,18 @@ class TestOps(unittest.TestCase):
    helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2))
    helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1)))
    helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0)))
+    helper_test_op([()], lambda x: x.permute(()), lambda x: x.permute(()))

  def test_reshape(self):
    helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
    helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)))
+    helper_test_op([()], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
+    helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([]))
+    helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1]))
+
+    with self.assertRaises(AssertionError):
+      x = Tensor.ones((4,3,6,6))
+      x.reshape([])

  def test_flip(self):
    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,)))

@@ -354,23 +479,31 @@ class TestOps(unittest.TestCase):
    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)))
    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)).flip((0,)), lambda x: x.flip(axis=(0,1,3)).flip(0))
    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(-1,)))
+    helper_test_op([()], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
+    helper_test_op([(1,)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))
+    helper_test_op([(4, 3, 6, 6)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=()))

  def test_unsqueeze(self):
    helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))
    helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 4), lambda x: x.unsqueeze(dim=4))
    helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -1), lambda x: x.unsqueeze(dim=-1))
    helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -3), lambda x: x.unsqueeze(dim=-3))
+    helper_test_op([()], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0))

  def test_flatten(self):
    for axis in range(3):
      helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(axis))
+    helper_test_op([()], lambda x: x.flatten(), lambda x: x.flatten())
+    helper_test_op([(1,)], lambda x: x.flatten(), lambda x: x.flatten())

  def test_detach(self):
    helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
+    helper_test_op([()], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)

  def test_expand(self):
    arg = (4,3,2,6)
    helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
+    helper_test_op([()], lambda x: x.expand([]), lambda x: x.expand(shape=[]))

  @unittest.skip("very slow")
  def test_sd_big_conv(self):

@@ -695,6 +828,10 @@ class TestOps(unittest.TestCase):
    for dim in range(-1, 2):
      helper_test_op([(45,65), (45,65)], lambda x,y: torch.cat((x,y), dim), lambda x,y: x.cat(y, dim=dim))

+    with self.assertRaises(AssertionError):
+      a = Tensor(3.14)
+      a.cat(a)

  def test_multicat(self):
    for dim in range(-1, 2):
      helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim))

@@ -707,6 +844,9 @@ class TestOps(unittest.TestCase):
    with self.assertRaises(IndexError):
      Tensor.stack([x], dim=77)

+    a = Tensor(3.14)
+    np.testing.assert_allclose(Tensor.stack([a, a]).numpy(), Tensor([3.14, 3.14]).numpy())

  def test_repeat(self):
    x = Tensor.randn(45, 65, 3)

@@ -715,6 +855,7 @@ class TestOps(unittest.TestCase):
    for reps in [[], [4], [2, 1], [3, 2, 2]]:
      repeats = base_repeats + reps
      helper_test_op([(45, 65, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))
+      helper_test_op([()], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats))

    with self.assertRaises(AssertionError):
      x.repeat((2, 4))

@@ -722,7 +863,6 @@ class TestOps(unittest.TestCase):
    with self.assertRaises(AssertionError):
      x.repeat((2, 0, 4))

  def test_clip(self):
    helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2))
@@ -14,6 +14,13 @@ W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)

class TestTinygrad(unittest.TestCase):
+  def test_zerodim_initialization(self):
+    a = Tensor(55)
+    b = Tensor(3.14)
+
+    self.assertEqual(a.shape, ())
+    self.assertEqual(b.shape, ())
+
  def test_plus_equals(self):
    a = Tensor.randn(10,10)
    b = Tensor.randn(10,10)

@@ -23,20 +30,6 @@ class TestTinygrad(unittest.TestCase):
    val2 = a.numpy()
    np.testing.assert_allclose(val1, val2)

-  def test_slicing(self):
-    x = Tensor.randn(10,10)
-    slices = [0,1,9,-1,-10,None] + [slice(s,e) for s,e in itertools.combinations([0,1,-1,None], r=2)] + [slice(9,11), slice(-11,-9)]
-    fmt = lambda s: f'{s.start}:{s.stop}' if isinstance(s, slice) else str(s)
-    for s in list(itertools.product(slices, slices)) + [(None,0,None,0,None), (slice(0,2),None,None,slice(2,4),None,None)]:
-      np.testing.assert_equal(x.numpy()[s], x[s].numpy(), f'Test failed for slice x[{",".join(fmt(x) for x in s)}]')
-    for s in [-11,10]:
-      with self.assertRaises(IndexError):
-        x[s]
-    with self.assertRaises(AssertionError):
-      x[::2]
-    with self.assertRaises(AssertionError):
-      x[0,0,0]
-
  def test_backward_pass(self):
    def test_tinygrad():
      x = Tensor(x_init, requires_grad=True)
@@ -414,6 +414,7 @@ class Linearizer:
  def simplify_ones(self):
    # remove places where the shape is all ones
    # TODO: this should be factored in to multi shape stride
+    if self.shape_len == 0: return
    all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
    # keep at least 1 one
    if all(all_ones): all_ones[-1] = False

@@ -7,6 +7,7 @@ base_image_type = (100, 2, "imageh", np.float16) if FLOAT16 else (100, 4, "image

def image_dot(self, w):
  # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
+  if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D")
  bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
  cin, cout = w.shape[-2], w.shape[-1]
  out_shape_t = self.shape[0:-2] + (cout,-1)
@@ -74,13 +74,13 @@ class View:

@functools.lru_cache(maxsize=None)
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
-  strides = [1]
+  strides = [1] if shape else []
  for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
  return tuple(st if s != 1 else 0 for st, s in zip(strides, shape))

@functools.lru_cache(maxsize=None)
def view_from_shape(shape:Tuple[int, ...]) -> View:
-  assert all(isinstance(x, int) for x in shape) and len(shape) != 0
+  assert all(isinstance(x, int) for x in shape)
  return View(tuple(shape), strides_for_shape(shape))

@functools.lru_cache(maxsize=None)
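As a sanity check on the `strides_for_shape` change, here is a standalone copy of the patched function with a few assertions showing what it computes for the empty shape and for size-1 dimensions (illustrative, not part of the commit):

from typing import Tuple
import functools

@functools.lru_cache(maxsize=None)
def strides_for_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
  # start with [1] only when the shape is non-empty, as in the patched version
  strides = [1] if shape else []
  for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
  # dimensions of size 1 get stride 0 (they are broadcastable)
  return tuple(st if s != 1 else 0 for st, s in zip(strides, shape))

assert strides_for_shape(()) == ()          # 0-dim tensor: no strides at all
assert strides_for_shape((4, 5)) == (5, 1)
assert strides_for_shape((4, 1, 5)) == (5, 0, 1)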
@@ -1,6 +1,6 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
- import math, functools, itertools
+ import math, functools, itertools, operator
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, LazyNumpyArray

@@ -33,9 +33,9 @@ class Tensor:
  no_grad: ClassVar[bool] = False
  default_type: ClassVar[DType] = dtypes.float32

-  def __init__(self, data:Union[list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
+  def __init__(self, data:Union[int, float, list, LazyBuffer, LazyNumpyArray, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
    device = (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") # canonicalize device
-    if isinstance(data, list):
+    if isinstance(data, (int, float, list)):
      data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
    elif isinstance(data, LazyBuffer) and data.device != device:
      # TODO: this has to realize, it shouldn't have to

@@ -47,7 +47,6 @@ class Tensor:
    # by here, it's either LazyNumpyArray or LazyBuffer
    # TODO: it should all be LazyBuffer I think
    if isinstance(data, LazyNumpyArray):
-      data = data if data.shape else data.reshape((1,))
      lazydata = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None else data, device)
    elif isinstance(data, LazyBuffer):
      assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
@@ -122,15 +121,13 @@ class Tensor:
  # ***** creation helper functions *****

  @staticmethod
-  def full(shape:Tuple[int, ...], fill_value, **kwargs):
-    new_shape = argfix(shape)
-    return Tensor([fill_value], **kwargs).reshape([1]*len(new_shape)).expand(new_shape).contiguous()
+  def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape).contiguous()

  @staticmethod
-  def zeros(*shape, **kwargs): return Tensor.full(shape, 0, **kwargs)
+  def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)

  @staticmethod
-  def ones(*shape, **kwargs): return Tensor.full(shape, 1, **kwargs)
+  def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)

  @staticmethod
  def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
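The `argfix(*shape)` calls are what let `Tensor.zeros`/`Tensor.ones` accept either varargs or a single list/tuple (including an empty shape). `argfix` itself lives in `tinygrad.helpers` and is not shown in this diff, so the sketch below uses a hypothetical stand-in purely to illustrate the calling conventions being normalized:

# hypothetical stand-in for tinygrad.helpers.argfix, for illustration only
def argfix_like(*x):
  # a single list/tuple argument is unpacked; otherwise the varargs themselves are the shape
  return tuple(x[0]) if x and isinstance(x[0], (tuple, list)) else x

# all of these normalize to the same shape tuple
assert argfix_like(45, 65) == (45, 65)
assert argfix_like((45, 65)) == (45, 65)
assert argfix_like([45, 65]) == (45, 65)
# and the 0-dim cases
assert argfix_like() == ()
assert argfix_like([]) == ()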
@@ -203,11 +200,11 @@ class Tensor:
    return _deepwalk(self, set(), [])

  def backward(self):
-    assert self.shape == (1,)
+    assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})"

    # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
    # this is "implicit gradient creation"
-    self.grad = Tensor([1], device=self.device, requires_grad=False)
+    self.grad = Tensor(1, device=self.device, requires_grad=False)

    for t0 in reversed(self.deepwalk()):
      if not any(x.requires_grad for x in t0._ctx.parents):
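With this change, `backward()` expects a shape-`()` loss rather than a shape-`(1,)` one, and the seed gradient is itself a 0-dim tensor. A small sketch of the new contract (illustrative, assuming the usual `tinygrad.tensor` import):

import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3), requires_grad=True)
loss = x.sum()              # full reduction now has shape (), not (1,)
assert loss.shape == ()
loss.backward()             # allowed: loss is a scalar
assert x.grad.shape == (2, 3)

y = x * 2
try:
  y.backward()              # not a scalar -> the new assert fires
except AssertionError as e:
  print("expected:", e)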
@@ -227,7 +224,7 @@ class Tensor:

  def reshape(self, shape, *args) -> Tensor:
    new_shape = argfix(shape, *args)
-    assert len(new_shape) > 0 and all(x != 0 for x in new_shape), f"zeros not allowed in shape {new_shape}"
+    assert all(x != 0 for x in new_shape), f"zeros not allowed in shape {new_shape}"
    return mlops.Reshape.apply(self, shape=tuple(-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape))
  def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
  def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
@@ -243,37 +240,71 @@ class Tensor:
    padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_))
    return self.pad(padding).shrink(tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)))

  # Tensors mostly follow the normal python indexing / slicing behavior for sequences
  # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
  # - A slice i:j returns the elements with indices in [i, j)
  #   - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence
  #   - Negative values for i and j are taken relative to the end of the sequence
  #   - Both i and j will be clamped to the range (-N, N], where N is the length of the sequence
  # - Indexing with np.newaxis or None on a given axis will add a new dimension of size one before that axis
-  # - Empty slices are not allowed
-  # - Strides other than 1 are not allowed
+  # - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends).
+  # - For a slice [i:j:k], finding the correct indices is delegated to slice.indices(len).
+  # - Strides > 1 and < 0 are now allowed:
+  #   - This works by applying Shrink -> [[Flip ->] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional)
+  #   - Idea of stride < 0 support:
+  #     - Do the slice first, flip the axes where slice.step is negative, do slice.step -> -slice.step. Go to the steps below.
+  #   - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink):
+  #     - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s].
+  #     - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that the reshape to [dim_sz_padded // s, s] is possible.
+  #     - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s].
  def __getitem__(self, val):
-    def slcfix(i, sz, default): return default if i is None else max(0, min(sz, sz+i if i < 0 else i))  # Fix negative idxs, clamp to [0,N]
-    new_slice, new_shape = [], []
-    val = [val] if not isinstance(val, (list, tuple)) else val
-    assert sum(s is not None for s in val) <= len(self.shape)
-    assert all(s.step is None or s.step == 1 for s in val if isinstance(s, slice))
-    for i,(sz,s) in enumerate(zip(self.shape, [v for v in val if v is not None])):  # Slicing only depends on ints + slices
-      if isinstance(s, int) and not (-sz <= s < sz):
-        raise IndexError(f"index {s} is out of bounds for dimension {i} with size {sz}")
-      new_slice.append((s%sz, s%sz+1) if isinstance(s, int) else (slcfix(s.start, sz, 0), slcfix(s.stop, sz, sz)))
-    for s,sz in zip(val, [self.shape[i-1] for i in itertools.accumulate([int(s is not None) for s in val])]):  # Shape depends on slices + positions of Nones
-      if not isinstance(s, int):
-        new_shape.append(1 if s is None else slcfix(s.stop, sz, sz) - slcfix(s.start, sz, 0))
-    new_shape += [self.shape[i] for i in range(len(new_slice), len(self.shape))]
-    new_slice += [(0,self.shape[i]) for i in range(len(new_slice), len(self.shape))]
-    return self.slice(new_slice).reshape(new_shape if len(new_shape) else (1,))
+    def normalize_int(e, i, dim_sz):
+      if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
+      raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
+    val = list(val) if isinstance(val, tuple) else [val]
+    if (num_slices := sum(isinstance(v, (slice, int)) for v in val)) > len(self.shape):
+      raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
+    orig_slices = list(val) + [slice(None)] * (len(self.shape) - num_slices)
+    valid_slices = list(itertools.filterfalse(lambda x: x is None, orig_slices))
+    valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
+    start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
+    new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
+    new_shape = tuple(e - s for s, e in new_slice)
+    # Shrink
+    sliced_tensor = self.shrink(new_slice)
+    # Flip
+    if (flip_axes := tuple(i for i, s in enumerate(strides) if s < 0)):
+      sliced_tensor = sliced_tensor.flip(axis=flip_axes)
+    if any(s > 1 or s < 0 for s in strides):
+      # normalize if negative strides
+      strides = tuple(abs(s) for s in strides)
+      def num_zeros(step, dim_sz): return 0 if step == 1 or (y := dim_sz % step) == 0 else (step - y)
+      # Pad: add pad at the end: [dim_sz] -> [dim_sz_padded]
+      paddings = tuple((0, num_zeros(s, dim_sz)) for s, dim_sz in zip(strides, sliced_tensor.shape))
+      padded_tensor = sliced_tensor.pad(paddings)
+      # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
+      new_shape = functools.reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], [])  # type: ignore
+      reshaped_tensor = padded_tensor.reshape(new_shape)
+      # Shrink: do [:, 0]
+      new_shape = new_shape[::2]
+      final_slice = functools.reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
+      sliced_tensor = reshaped_tensor.shrink(final_slice)
+    final_shape = []
+    it_shape = iter(new_shape)
+    for i in orig_slices:
+      if isinstance(i, (int, slice)):
+        dim_shape = next(it_shape)
+        if isinstance(i, slice): final_shape.append(dim_shape)
+      else:  # i is None
+        final_shape.append(1)
+    return sliced_tensor.reshape(tuple(final_shape))  # Reshape

  def cat(self, *args, dim=0):
    dim = (dim + len(self.shape)) if dim < 0 else dim
    for y in args:
      assert len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim)
    catargs = [self] + list(args)
+    assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
    shape_cumsum = [0, *itertools.accumulate([y.shape[dim] for y in catargs])]
    slc = [[(0, s) for s in self.shape] for _ in catargs]
    for s,k in zip(slc, shape_cumsum):
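The comment block in `__getitem__` above explains the new strided-slicing strategy (Shrink -> [Flip ->] Pad -> Reshape -> Shrink). Below is a small NumPy-only sketch of the same idea for a single axis, written for illustration; the helper and its names are mine, not from the diff:

import numpy as np

def strided_1d(a, start, stop, step):
  # emulate a[start:stop:step] (step may be negative) using only
  # shrink (basic slicing), flip, pad, and reshape, mirroring __getitem__
  start, stop, step = slice(start, stop, step).indices(len(a))
  # Shrink: cut the range covered by the slice; for negative steps use (stop+1, start+1)
  lo, hi = (start, stop) if step > 0 else (stop + 1, start + 1)
  out = a[lo:hi]
  # Flip if the step is negative, then continue with its absolute value
  if step < 0: out = out[::-1]
  s = abs(step)
  if s > 1:
    # Pad so the length divides s, reshape to (n // s, s), keep column 0
    pad = (-len(out)) % s
    out = np.pad(out, (0, pad))
    out = out.reshape(-1, s)[:, 0]
  return out

a = np.arange(10)
for sl in [slice(None, None, 2), slice(1, 8, 3), slice(None, None, -1), slice(8, 2, -2)]:
  assert np.array_equal(strided_1d(a, sl.start, sl.stop, sl.step), a[sl])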
@@ -327,7 +358,7 @@ class Tensor:
    axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
    shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis_]
    ret = fxn.apply(self, new_shape=tuple(1 if i in axis_ else self.shape[i] for i in range(len(self.shape))))
-    return ret if keepdim else ret.reshape(shape=[1] if shape == [] else shape)
+    return ret if keepdim else ret.reshape(shape=shape)

  def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
  def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
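With the `_reduce` change above, a full reduction keeps the empty shape instead of being padded back to `(1,)`, while axis/keepdim reductions are unchanged. For illustration:

import numpy as np
from tinygrad.tensor import Tensor

t = Tensor(np.ones((3, 4), dtype=np.float32))
assert t.sum().shape == ()                    # was (1,) before this change
assert t.sum(axis=1).shape == (3,)            # partial reductions are unchanged
assert t.max(axis=1, keepdim=True).shape == (3, 1)
print(t.sum().numpy())                        # prints 12.0 as a 0-d array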
@@ -425,6 +456,7 @@ class Tensor:
    return ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))

  def dot(self, w:Tensor) -> Tensor:
+    if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D")
    x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])
    w = w.reshape(*w.shape[0:-2], 1, w.shape[-2], w.shape[-1]).transpose(-1, -2)
    r = (x*w).sum(-1)
@@ -471,7 +503,7 @@ class Tensor:
  # ***** broadcasted binary mlops *****

  def _broadcasted(self, fxn:Type[Function], other:Union[Tensor, float], reverse:bool=False) -> Tensor:
-    x,y = [Tensor([t], device=self.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in ([other,self] if reverse else [self,other])]
+    x,y = [Tensor(t, device=self.device, requires_grad=False) if not isinstance(t, Tensor) else t for t in ([other,self] if reverse else [self,other])]
    x,y = [t.reshape([1]*(max(len(x.shape), len(y.shape))-len(t.shape)) + list(t.shape)) for t in [x,y]]
    shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
    return fxn.apply(x.expand(shape_ret), y.expand(shape_ret))
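The `_broadcasted` change wraps Python scalars as 0-dim tensors (previously `[t]`, a shape-`(1,)` tensor) before the usual reshape/expand broadcasting. Roughly (illustrative):

import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
# the scalar 10 becomes Tensor(10) with shape (), is reshaped to (1, 1),
# then expanded to (2, 3) before the elementwise op runs
y = x + 10
assert y.shape == (2, 3)
# a 0-dim tensor broadcasts against any shape the same way
z = x * Tensor(2.0)
assert z.shape == (2, 3)
print(y.numpy(), z.numpy())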