mirror of https://github.com/commaai/tinygrad.git
commit
65f3a9d499
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
import torch
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor, Conv2D
|
||||
from tinygrad.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
|
||||
x_init = np.random.randn(1,3).astype(np.float32)
|
||||
W_init = np.random.randn(3,3).astype(np.float32)
|
||||
|
@ -32,6 +33,37 @@ class TestTinygrad(unittest.TestCase):
|
|||
for x,y in zip(test_tinygrad(), test_pytorch()):
|
||||
np.testing.assert_allclose(x, y, atol=1e-5)
|
||||
|
||||
def test_jacobian(self):
|
||||
W = np.random.RandomState(1337).random((10, 5))
|
||||
x = np.random.RandomState(7331).random((1, 10)) - 0.5
|
||||
|
||||
torch_x = torch.tensor(x, requires_grad=True)
|
||||
torch_W = torch.tensor(W, requires_grad=True)
|
||||
torch_func = lambda x: torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
|
||||
PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy()
|
||||
|
||||
tiny_x = Tensor(x)
|
||||
tiny_W = Tensor(W)
|
||||
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
|
||||
J = jacobian(tiny_func, tiny_x)
|
||||
NJ = numerical_jacobian(tiny_func, tiny_x)
|
||||
|
||||
np.testing.assert_allclose(PJ, J, atol = 1e-5)
|
||||
np.testing.assert_allclose(PJ, NJ, atol = 1e-5)
|
||||
|
||||
def test_gradcheck(self):
|
||||
W = np.random.RandomState(1337).random((10, 5))
|
||||
x = np.random.RandomState(7331).random((1, 10)) - 0.5
|
||||
|
||||
tiny_x = Tensor(x)
|
||||
tiny_W = Tensor(W)
|
||||
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
|
||||
|
||||
self.assertTrue(gradcheck(tiny_func, tiny_x))
|
||||
|
||||
# coarse approx. since a "big" eps and the non-linearities of the model
|
||||
self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 0.1))
|
||||
|
||||
def test_conv2d(self):
|
||||
x = torch.randn((5,2,10,7), requires_grad=True)
|
||||
w = torch.randn((4,2,3,3), requires_grad=True)
|
||||
|
@ -48,7 +80,7 @@ class TestTinygrad(unittest.TestCase):
|
|||
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-5)
|
||||
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
import numpy as np
|
||||
|
||||
from tinygrad.utils import mask_like
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def jacobian(func, input):
|
||||
output = func(input)
|
||||
|
||||
ji = input.data.reshape(-1).shape[-1]
|
||||
jo = output.data.reshape(-1).shape[-1]
|
||||
J = np.zeros((jo,ji))
|
||||
|
||||
for o in range(jo):
|
||||
# tinygrad doesn't support slicing, tiny-hack to select
|
||||
# the needed scalar an backpropagate only through it
|
||||
o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum()
|
||||
o_scalar.backward()
|
||||
|
||||
for i, grad in enumerate(input.grad.reshape(-1)):
|
||||
J[o,i] = grad
|
||||
return J
|
||||
|
||||
def numerical_jacobian(func, input, eps = 1e-6):
|
||||
output = func(input)
|
||||
|
||||
ji = input.data.reshape(-1).shape[-1]
|
||||
jo = output.data.reshape(-1).shape[-1]
|
||||
NJ = np.zeros((jo, ji))
|
||||
|
||||
for o in range(jo):
|
||||
for i in range(ji):
|
||||
|
||||
eps_perturb = mask_like(input.data, i, mask_value = eps)
|
||||
output_perturb_add = func(Tensor(input.data + eps_perturb)).data.reshape(-1)[o]
|
||||
output_perturb_sub = func(Tensor(input.data - eps_perturb)).data.reshape(-1)[o]
|
||||
|
||||
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2*eps)
|
||||
|
||||
NJ[o,i] = grad_approx
|
||||
return NJ
|
||||
|
||||
def gradcheck(func, input, eps = 1e-06, atol = 1e-5, rtol = 0.001):
|
||||
NJ = numerical_jacobian(func, input, eps)
|
||||
J = jacobian(func, input)
|
||||
return np.allclose(J, NJ, atol=atol, rtol=rtol)
|
|
@ -1,5 +1,10 @@
|
|||
import numpy as np
|
||||
|
||||
def mask_like(like, mask_inx, mask_value = 1.0):
|
||||
mask = np.zeros_like(like).reshape(-1)
|
||||
mask[mask_inx] = mask_value
|
||||
return mask.reshape(like.shape)
|
||||
|
||||
def layer_init_uniform(*x):
|
||||
ret = np.random.uniform(-1., 1., size=x)/np.sqrt(np.prod(x))
|
||||
return ret.astype(np.float32)
|
||||
|
|
Loading…
Reference in New Issue