tinygrad/test/external/external_test_optim.py

76 lines
2.4 KiB
Python
Raw Normal View History

#!/usr/bin/env python
import unittest
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import LAMB
# Seed NumPy's global RNG so the initial values below are the same on every run.
np.random.seed(1337)
# Initial float32 arrays, drawn in this order from the seeded RNG; reordering
# these lines would change which values each name receives.
x_init = np.random.randn(1,4).astype(np.float32)
W_init = np.random.randn(4,4).astype(np.float32)
m_init = np.random.randn(1,4).astype(np.float32)
class TinyNet:
  """Small tinygrad model built from copies of x_init, W_init and m_init."""

  def __init__(self):
    # x and W are created with requires_grad=True; m is created without it.
    self.x = Tensor(x_init.copy(), requires_grad=True)
    self.W = Tensor(W_init.copy(), requires_grad=True)
    self.m = Tensor(m_init.copy())

  def forward(self):
    """Return sum(log_softmax(relu(x @ W), axis 1) * m + m)."""
    hidden = self.x.matmul(self.W)
    activated = hidden.relu()
    log_probs = activated.log_softmax(1)
    scaled = log_probs.mul(self.m)
    return scaled.add(self.m).sum()
class TinyNetTF:
  """TensorFlow model built from copies of x_init, W_init and m_init."""

  def __init__(self):
    # x and W are trainable variables; m is a constant.
    self.x = tf.Variable(x_init.copy(), trainable=True)
    self.W = tf.Variable(W_init.copy(), trainable=True)
    self.m = tf.constant(m_init.copy())

  def forward(self):
    """Return reduce_sum(log_softmax(relu(x @ W), axis=1) * m + m)."""
    activated = tf.nn.relu(tf.matmul(self.x, self.W))
    log_probs = tf.nn.log_softmax(activated, axis=1)
    shifted = tf.multiply(log_probs, self.m) + self.m
    return tf.reduce_sum(shifted)
def step(optim, steps=1, kwargs=None):
  """Run optimizer updates on a fresh TinyNet and return its parameters.

  Args:
    optim: optimizer class; instantiated as optim([net.x, net.W], **kwargs).
    steps: number of forward / zero_grad / backward / step iterations.
    kwargs: keyword arguments for the optimizer constructor, or None for none.

  Returns:
    Tuple (x, W) holding net.x.detach().numpy() and net.W.detach().numpy()
    after the final update.
  """
  # None sentinel instead of a `{}` default: a dict default is created once at
  # definition time and shared by every call that omits the argument.
  opt_kwargs = {} if kwargs is None else kwargs
  net = TinyNet()
  # Keep the constructed optimizer in its own name rather than rebinding `optim`.
  optimizer = optim([net.x, net.W], **opt_kwargs)
  for _ in range(steps):
    out = net.forward()
    optimizer.zero_grad()
    out.backward()
    optimizer.step()
  return net.x.detach().numpy(), net.W.detach().numpy()
def step_tf(optim, steps=1, kwargs=None):
  """Run optimizer updates on a fresh TinyNetTF and return its parameters.

  Args:
    optim: optimizer class; instantiated as optim(**kwargs).
    steps: number of gradient computations and apply_gradients calls.
    kwargs: keyword arguments for the optimizer constructor, or None for none.

  Returns:
    Tuple (x, W) holding net.x.numpy() and net.W.numpy() after the final update.
  """
  # None sentinel instead of a `{}` default: a dict default is created once at
  # definition time and shared by every call that omits the argument.
  opt_kwargs = {} if kwargs is None else kwargs
  net = TinyNetTF()
  optimizer = optim(**opt_kwargs)
  variables = [net.x, net.W]
  for _ in range(steps):
    # Only the forward pass is recorded; the gradient is taken after the tape closes.
    with tf.GradientTape() as tape:
      out = net.forward()
    grads = tape.gradient(out, variables)
    optimizer.apply_gradients(zip(grads, variables))
  return net.x.numpy(), net.W.numpy()
class ExternalTestOptim(unittest.TestCase):
  """Compares tinygrad's LAMB with tfa.optimizers.LAMB on the same tiny model."""

  def _test_optim(self, tinygrad_optim, tensorflow_optim, steps, opts, atol, rtol):
    """Run both optimizers with the same options and assert the parameters agree."""
    tiny_params = step(tinygrad_optim, steps, kwargs=opts)
    tf_params = step_tf(tensorflow_optim, steps, kwargs=opts)
    for actual, desired in zip(tiny_params, tf_params):
      np.testing.assert_allclose(actual, desired, atol=atol, rtol=rtol)

  def _test_lamb(self, steps, opts, atol, rtol):
    """Run _test_optim with the LAMB optimizer pair."""
    self._test_optim(LAMB, tfa.optimizers.LAMB, steps, opts, atol, rtol)

  # One update, lr=0.001, absolute tolerance only.
  def test_lamb(self):
    self._test_lamb(1, {'lr': 0.001}, 1e-5, 0)

  # One update, lr=10.
  def test_lamb_high_lr(self):
    self._test_lamb(1, {'lr': 10}, 1e-5, 1e-5)

  # Ten updates, lr=0.001, absolute tolerance only.
  def test_multistep_lamb(self):
    self._test_lamb(10, {'lr': 0.001}, 1e-5, 0)

  # Ten updates, lr=10, with a looser relative tolerance.
  def test_multistep_lamb_high_lr(self):
    self._test_lamb(10, {'lr': 10}, 1e-5, 3e-4)
# Run the unittest command-line runner when this file is executed as a script.
if __name__ == '__main__':
  unittest.main()