Merge pull request #11 from 0xNaN/master

tiny gradcheck
This commit is contained in:
George Hotz 2020-10-22 10:22:42 -07:00 committed by GitHub
commit 65f3a9d499
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 1 deletions

View File

@ -2,6 +2,7 @@ import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor, Conv2D
from tinygrad.gradcheck import numerical_jacobian, jacobian, gradcheck
x_init = np.random.randn(1,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)
@ -32,6 +33,37 @@ class TestTinygrad(unittest.TestCase):
for x,y in zip(test_tinygrad(), test_pytorch()):
np.testing.assert_allclose(x, y, atol=1e-5)
def test_jacobian(self):
W = np.random.RandomState(1337).random((10, 5))
x = np.random.RandomState(7331).random((1, 10)) - 0.5
torch_x = torch.tensor(x, requires_grad=True)
torch_W = torch.tensor(W, requires_grad=True)
torch_func = lambda x: torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy()
tiny_x = Tensor(x)
tiny_W = Tensor(W)
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
J = jacobian(tiny_func, tiny_x)
NJ = numerical_jacobian(tiny_func, tiny_x)
np.testing.assert_allclose(PJ, J, atol = 1e-5)
np.testing.assert_allclose(PJ, NJ, atol = 1e-5)
def test_gradcheck(self):
W = np.random.RandomState(1337).random((10, 5))
x = np.random.RandomState(7331).random((1, 10)) - 0.5
tiny_x = Tensor(x)
tiny_W = Tensor(W)
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
self.assertTrue(gradcheck(tiny_func, tiny_x))
# coarse approx. since a "big" eps and the non-linearities of the model
self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 0.1))
def test_conv2d(self):
x = torch.randn((5,2,10,7), requires_grad=True)
w = torch.randn((4,2,3,3), requires_grad=True)
@ -48,7 +80,7 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-5)
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
if __name__ == '__main__':
unittest.main()

45
tinygrad/gradcheck.py Normal file
View File

@ -0,0 +1,45 @@
import numpy as np
from tinygrad.utils import mask_like
from tinygrad.tensor import Tensor
def jacobian(func, input):
output = func(input)
ji = input.data.reshape(-1).shape[-1]
jo = output.data.reshape(-1).shape[-1]
J = np.zeros((jo,ji))
for o in range(jo):
# tinygrad doesn't support slicing, tiny-hack to select
# the needed scalar an backpropagate only through it
o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum()
o_scalar.backward()
for i, grad in enumerate(input.grad.reshape(-1)):
J[o,i] = grad
return J
def numerical_jacobian(func, input, eps = 1e-6):
output = func(input)
ji = input.data.reshape(-1).shape[-1]
jo = output.data.reshape(-1).shape[-1]
NJ = np.zeros((jo, ji))
for o in range(jo):
for i in range(ji):
eps_perturb = mask_like(input.data, i, mask_value = eps)
output_perturb_add = func(Tensor(input.data + eps_perturb)).data.reshape(-1)[o]
output_perturb_sub = func(Tensor(input.data - eps_perturb)).data.reshape(-1)[o]
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2*eps)
NJ[o,i] = grad_approx
return NJ
def gradcheck(func, input, eps = 1e-06, atol = 1e-5, rtol = 0.001):
NJ = numerical_jacobian(func, input, eps)
J = jacobian(func, input)
return np.allclose(J, NJ, atol=atol, rtol=rtol)

View File

@ -1,5 +1,10 @@
import numpy as np
def mask_like(like, mask_inx, mask_value = 1.0):
mask = np.zeros_like(like).reshape(-1)
mask[mask_inx] = mask_value
return mask.reshape(like.shape)
def layer_init_uniform(*x):
ret = np.random.uniform(-1., 1., size=x)/np.sqrt(np.prod(x))
return ret.astype(np.float32)