tinygrad/test/test_optim.py

98 lines
4.4 KiB
Python
Raw Normal View History

import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor
2023-03-21 03:31:34 +08:00
from tinygrad.nn.optim import Adam, SGD, AdamW
CI < 5 minutes (#1252) * models matrix * fix typo and install gpu deps * install llvm deps if needed * fix * testops with cuda * remove pip cache since not work * cuda env * install cuda deps * maybe it will work now * i can't read * all tests in matrix * trim down more * opencl stuff in matrix * opencl pip cache * test split * change cuda test exclusion * test * fix cuda maybe * add models * add more n=auto * third thing * fix bug * cache pip more * change name * update tests * try again cause why not * balance * try again... * try apt cache for cuda * try on gpu: * try cuda again * update packages step * replace libz-dev with zlib1g-dev * only cache cuda * why error * fix gpuocelot bug * apt cache err * apt cache to slow? * opt and image in single runner * add a couple n=autos * remove test matrix * try cuda apt cache again * libz-dev -> zlib1g-dev * remove -s since not supported by xdist * the cache takes too long and doesn't work * combine webgpu and metal tests * combine imagenet to c and cpu tests * torch tests with linters * torch back by itself * small windows clang test with torch tests * fix a goofy windows bug * im dumb * bro * clang with linters * fix pylint error * linter not work on windows * try with clang again * clang and imagenet? * install deps * fix * fix quote * clang by itself (windows too slow) * env vars for imagenet * cache pip for metal and webgpu tests * try torch with metal and webgpu * doesn't work, too long * remove -v * try -n=logical * don't use logical * revert accidental thing * remove some prints unless CI * fix print unless CI * ignore speed tests for slow tests * clang windows in matrix (ubuntu being tested in imagenet->c test) * try manual pip cache * fix windows pip cache path * all manual pip cache * fix pip cache dir for macos * print_ci function in helpers * CI as variable, no print_ci * missed one * cuda tests with docker image * remove setup-python action for cuda * python->python3? 
* remove -s -v * try fix pip cache * maybe fix * try to fix pip cache * is this the path? * maybe cache pip * try again * create wheels dir * ? * cuda pip deps in dockerfile * disable pip cache for clang * image from ghcr instead of docker hub * why is clang like this * fast deps * try use different caches * remove the fast thing * try with lighter image * remove setup python for cuda * small docker and cuda fast deps * ignore a few more tests * cool docker thing (maybe) * oops * quotes * fix docker command * fix bug * ignore train efficientnet test * remove dockerfile (docker stuff takes too long) * remove docker stuff and normal cuda * oops * ignore the tests for cuda * does this work * ignore test_train on slow backends * add space * llvm ignore same tests as cuda * nvm * ignore lr scheduler tests * get some stats * fix ignore bug * remove extra ' * remove and * ignore test for llvm * change ignored tests and durationon all backends * fix * and -> or * ignore some more cuda tests * finally? * does this fix it * remove durations=0 * add some more tests to llvm * make last pytest more readable * fix * don't train efficientnet on cpu * try w/out pip cache * pip cache seems to be generally better * pytest file markers * try apt fast for cuda * use quick install for apt-fast * apt-fast not worth * apt-get to apt * fix typo * suppress warnings * register markers * disable debug on fuzz tests * change marker names * apt update and apt install in one command * update marker names in test.yml * webgpu pytest marker
2023-07-24 04:00:56 +08:00
import pytest
# Module-level pytest marker: pytest applies `pytestmark` to every test collected from this file.
# NOTE(review): `exclude_cuda` is presumably a custom marker that CI uses to deselect these
# tests on the CUDA backend -- confirm against the project's pytest marker registration.
pytestmark = pytest.mark.exclude_cuda
2023-03-12 09:49:53 +08:00
# Seed NumPy's global RNG so the initial values below are reproducible across runs.
np.random.seed(1337)
# Drawn in a fixed order (x, then W, then m) from the seeded global RNG, as float32:
# x_init is (1, 4), W_init is (4, 4), m_init is (1, 4).
x_init, W_init, m_init = (
  np.random.randn(*shape).astype(np.float32) for shape in ((1, 4), (4, 4), (1, 4))
)
class TinyNet:
  """Tiny framework-agnostic network used to compare optimizer implementations.

  The tensors are created through the `tensor` callable handed to the
  constructor, so the same class can be driven by any tensor type exposing
  `matmul`, `relu`, `log_softmax`, `mul`, `add` and `sum`.
  """

  def __init__(self, tensor):
    """Build `x`, `W` and `m` from copies of the module-level initial arrays.

    Args:
      tensor: Callable invoked as ``tensor(ndarray, requires_grad=True)`` for
        `x` and `W`, and as ``tensor(ndarray)`` for `m`.
    """
    # x and W are created with requires_grad=True; m is created without it.
    self.x, self.W = (tensor(init.copy(), requires_grad=True) for init in (x_init, W_init))
    self.m = tensor(m_init.copy())

  def forward(self):
    """Return sum(log_softmax(relu(x @ W), axis 1) * m + m)."""
    hidden = self.x.matmul(self.W).relu()
    log_probs = hidden.log_softmax(1)
    return log_probs.mul(self.m).add(self.m).sum()
2023-03-12 09:49:53 +08:00
def step(tensor, optim, steps=1, kwargs=None):
  """Train a fresh TinyNet for `steps` iterations and return its parameters.

  Args:
    tensor: Callable used by TinyNet to build its tensors
      (e.g. ``Tensor`` or ``torch.tensor``).
    optim: Optimizer class or factory, called as ``optim([net.x, net.W], **kwargs)``.
    steps: Number of forward / zero_grad / backward / step iterations to run.
    kwargs: Optional dict of keyword arguments forwarded to the optimizer
      constructor. ``None`` (the default) forwards no extra arguments.

  Returns:
    Tuple ``(x, W)`` of NumPy arrays with the parameter values after training.
  """
  # A None sentinel avoids sharing one mutable dict default across calls.
  optim_kwargs = {} if kwargs is None else kwargs
  net = TinyNet(tensor)
  optimizer = optim([net.x, net.W], **optim_kwargs)
  for _ in range(steps):
    out = net.forward()
    optimizer.zero_grad()
    out.backward()
    optimizer.step()
  return net.x.detach().numpy(), net.W.detach().numpy()
class TestOptim(unittest.TestCase):
  """Compares tinygrad's SGD / Adam / AdamW against the torch.optim equivalents."""

  def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol):
    """Run `step` with both frameworks and require matching final parameters.

    Args:
      tinygrad_optim: tinygrad optimizer class.
      torch_optim: torch optimizer class used as the reference.
      steps: Number of optimization steps to run in each framework.
      opts: Keyword arguments passed to both optimizer constructors.
      atol: Absolute tolerance for ``np.testing.assert_allclose``.
      rtol: Relative tolerance for ``np.testing.assert_allclose``.
    """
    # tinygrad is run first, then torch; each call returns the (x, W) arrays.
    tiny_params = step(Tensor, tinygrad_optim, steps, kwargs=opts)
    torch_params = step(torch.tensor, torch_optim, steps, kwargs=opts)
    for ours, reference in zip(tiny_params, torch_params):
      np.testing.assert_allclose(ours, reference, atol=atol, rtol=rtol)

  def _test_sgd(self, steps, opts, atol, rtol):
    self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)

  def _test_adam(self, steps, opts, atol, rtol):
    self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)

  def _test_adamw(self, steps, opts, atol, rtol):
    self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)

  # --- SGD, single step ---
  def test_sgd(self):
    self._test_sgd(steps=1, opts={'lr': 0.001}, atol=1e-6, rtol=0)

  def test_sgd_high_lr(self):
    self._test_sgd(steps=1, opts={'lr': 10}, atol=1e-6, rtol=1e-5)

  def test_sgd_wd(self):
    self._test_sgd(steps=1, opts={'lr': 0.001, 'weight_decay': 0.1}, atol=1e-6, rtol=0)

  def test_sgd_high_lr_wd(self):
    self._test_sgd(steps=1, opts={'lr': 10, 'weight_decay': 0.1}, atol=1e-6, rtol=1e-5)

  # --- SGD, ten steps ---
  def test_multistep_sgd(self):
    self._test_sgd(steps=10, opts={'lr': 0.001}, atol=1e-6, rtol=0)

  def test_multistep_sgd_high_lr(self):
    self._test_sgd(steps=10, opts={'lr': 10}, atol=1e-6, rtol=3e-4)

  def test_multistep_sgd_wd(self):
    self._test_sgd(steps=10, opts={'lr': 0.001, 'weight_decay': 0.1}, atol=1e-6, rtol=0)

  def test_multistep_sgd_high_lr_wd(self):
    self._test_sgd(steps=10, opts={'lr': 9, 'weight_decay': 0.1}, atol=1e-6, rtol=3e-4)

  # --- SGD with momentum ---
  def test_multistep_sgd_momentum(self):
    self._test_sgd(steps=10, opts={'lr': 0.001, 'momentum': 0.9}, atol=1e-6, rtol=0)

  def test_multistep_sgd_high_lr_momentum(self):
    self._test_sgd(steps=10, opts={'lr': 10, 'momentum': 0.9}, atol=1e-5, rtol=3e-4)

  def test_multistep_sgd_momentum_wd(self):
    self._test_sgd(steps=10, opts={'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.1}, atol=1e-6, rtol=0)

  def test_multistep_sgd_high_lr_momentum_wd(self):
    self._test_sgd(steps=10, opts={'lr': 10, 'momentum': 0.9, 'weight_decay': 0.1}, atol=1e-5, rtol=3e-4)

  # --- SGD with Nesterov momentum ---
  def test_multistep_sgd_nesterov_momentum(self):
    self._test_sgd(steps=10, opts={'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, atol=1e-5, rtol=0)

  def test_multistep_sgd_high_lr_nesterov_momentum(self):
    self._test_sgd(steps=10, opts={'lr': 10, 'momentum': 0.9, 'nesterov': True}, atol=1e-5, rtol=3e-4)

  def test_multistep_sgd_nesterov_momentum_wd(self):
    self._test_sgd(
      steps=10,
      opts={'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1},
      atol=1e-5,
      rtol=0,
    )

  def test_multistep_sgd_high_lr_nesterov_momentum_wd(self):
    self._test_sgd(
      steps=10,
      opts={'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1},
      atol=1e-5,
      rtol=3e-4,
    )

  # --- Adam / AdamW, single step ---
  def test_adam(self):
    self._test_adam(steps=1, opts={'lr': 0.001}, atol=1e-5, rtol=0)

  def test_adam_high_lr(self):
    self._test_adam(steps=1, opts={'lr': 10}, atol=1e-4, rtol=1e-4)

  def test_adamw(self):
    self._test_adamw(steps=1, opts={'lr': 0.001}, atol=1e-5, rtol=0)

  def test_adamw_high_lr(self):
    self._test_adamw(steps=1, opts={'lr': 10}, atol=1e-4, rtol=1e-4)

  # --- Adam / AdamW, ten steps ---
  def test_multistep_adam(self):
    self._test_adam(steps=10, opts={'lr': 0.001}, atol=1e-5, rtol=0)

  def test_multistep_adam_high_lr(self):
    self._test_adam(steps=10, opts={'lr': 10}, atol=2e-4, rtol=5e-4)

  def test_multistep_adamw(self):
    self._test_adamw(steps=10, opts={'lr': 0.001}, atol=1e-5, rtol=0)

  def test_multistep_adamw_high_lr(self):
    self._test_adamw(steps=10, opts={'lr': 10}, atol=5e-4, rtol=2e-3)

  def test_duped_weights(self):
    """Listing the same tensor twice must give the same loss as listing it once."""
    for opt_cls in (Adam, AdamW, SGD):
      final_losses = []
      for duplicated in (False, True):
        weight = Tensor(x_init.copy())
        params = [weight, weight] if duplicated else [weight]
        optimizer = opt_cls(params, lr=0.1)
        last_loss = None
        for _ in range(3):
          last_loss = weight.sum()
          optimizer.zero_grad()
          last_loss.backward()
          optimizer.step()
        # The recorded loss is the one computed at the start of the third iteration.
        final_losses.append(last_loss.numpy())
      single, duped = final_losses
      np.testing.assert_allclose(single, duped, atol=1e-4, rtol=0)
# Allow running this file directly as a script through unittest's command-line runner.
if __name__ == '__main__':
  unittest.main()