mirror of https://github.com/commaai/tinygrad.git
Consistent testing (#137)
* Consistent GPU classes Convert the existing GPU classes into one standard format. Remove duplicated functions in `test_mnist` and create a TestMNISTGPU class. This reduces line count and ensures consistency. Use `@unittest.skipUnless(GPU, "Requires GPU")` instead of `if GPU:` to skip GPU testing. This will ensure that skipped tests are displayed accordingly in the pytest output. * Optim Testing now supports GPU * Tensor testing now supports GPU jacobian and gradcheck auto skipped until GPU float64 support added. * GPU support for custom constructor methods * Remove GPU flag from Model constructors It was requested that the `gpu` kwarg be removed from the model constructor. GPU conversion is now handled in the train function. This also required the conversion of Optimizer parameters as they are constructed prior to execution of the `train` function and are dependant on the model GPU state. * Fix typo: float32->float64 * Clean `get_parameters` utility Just a quick refactor w/ the new support for optimizers. * Remove GPU kwarg from TinyNet Remove `gpu` kwarg from tiny net to match test_mnist `train` function.
This commit is contained in:
parent
34b38dd4d0
commit
89d0ff6989
|
@ -22,6 +22,7 @@ X_train, Y_train, X_test, Y_test = fetch_mnist()
|
|||
|
||||
# create a model
|
||||
class TinyBobNet:
|
||||
|
||||
def __init__(self):
|
||||
self.l1 = Tensor.uniform(784, 128)
|
||||
self.l2 = Tensor.uniform(128, 10)
|
||||
|
@ -54,6 +55,7 @@ class TinyConvNet:
|
|||
return x.dot(self.l1).logsoftmax()
|
||||
|
||||
def train(model, optim, steps, BS=128, gpu=False):
|
||||
if gpu is True: [x.cuda_() for x in get_parameters([model, optim])]
|
||||
losses, accuracies = [], []
|
||||
for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(BS))
|
||||
|
@ -94,53 +96,32 @@ def evaluate(model, gpu=False):
|
|||
assert accuracy > 0.95
|
||||
|
||||
class TestMNIST(unittest.TestCase):
|
||||
@unittest.skipUnless(GPU, "Requires GPU")
|
||||
def test_conv_gpu(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyConvNet()
|
||||
[x.cuda_() for x in model.parameters()]
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
train(model, optimizer, steps=200, gpu=True)
|
||||
evaluate(model, gpu=True)
|
||||
gpu=False
|
||||
|
||||
def test_conv(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyConvNet()
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
train(model, optimizer, steps=200)
|
||||
evaluate(model)
|
||||
|
||||
@unittest.skipUnless(GPU, "Requires GPU")
|
||||
def test_sgd_gpu(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyBobNet()
|
||||
[x.cuda_() for x in model.parameters()]
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.001)
|
||||
train(model, optimizer, steps=1000, gpu=True)
|
||||
evaluate(model, gpu=True)
|
||||
train(model, optimizer, steps=200, gpu=self.gpu)
|
||||
evaluate(model, gpu=self.gpu)
|
||||
|
||||
def test_sgd(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyBobNet()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.001)
|
||||
train(model, optimizer, steps=1000)
|
||||
evaluate(model)
|
||||
|
||||
@unittest.skipUnless(GPU, "Requires GPU")
|
||||
def test_rmsprop_gpu(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyBobNet()
|
||||
[x.cuda_() for x in model.parameters()]
|
||||
optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
|
||||
train(model, optimizer, steps=1000, gpu=True)
|
||||
evaluate(model, gpu=True)
|
||||
train(model, optimizer, steps=1000, gpu=self.gpu)
|
||||
evaluate(model, gpu=self.gpu)
|
||||
|
||||
def test_rmsprop(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyBobNet()
|
||||
optimizer = optim.RMSprop(model.parameters(), lr=0.0002)
|
||||
train(model, optimizer, steps=1000)
|
||||
evaluate(model)
|
||||
train(model, optimizer, steps=1000, gpu=self.gpu)
|
||||
evaluate(model, gpu=self.gpu)
|
||||
|
||||
@unittest.skipUnless(GPU, "Requires GPU")
|
||||
class TestMNISTGPU(TestMNIST):
|
||||
gpu = True
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -145,9 +145,9 @@ class TestOps(unittest.TestCase):
|
|||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), gpu=self.gpu)
|
||||
|
||||
if GPU:
|
||||
class TestOpsGPU(TestOps):
|
||||
gpu = True
|
||||
@unittest.skipUnless(GPU, "Requires GPU")
|
||||
class TestOpsGPU(TestOps):
|
||||
gpu = True
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.tensor import Tensor, GPU
|
||||
from tinygrad.optim import Adam, SGD, RMSprop
|
||||
from tinygrad.utils import get_parameters
|
||||
|
||||
x_init = np.random.randn(1,3).astype(np.float32)
|
||||
W_init = np.random.randn(3,3).astype(np.float32)
|
||||
m_init = np.random.randn(1,3).astype(np.float32)
|
||||
|
||||
def step_tinygrad(optim, kwargs={}):
|
||||
def step_tinygrad(optim, kwargs={}, gpu=False):
|
||||
net = TinyNet()
|
||||
optim = optim([net.x, net.W], **kwargs)
|
||||
if gpu is True: [x.cuda_() for x in get_parameters([net, optim])]
|
||||
out = net.forward()
|
||||
out.backward()
|
||||
optim.step()
|
||||
return net.x.data, net.W.data
|
||||
return net.x.cpu().data, net.W.cpu().data
|
||||
|
||||
def step_pytorch(optim, kwargs={}):
|
||||
net = TorchNet()
|
||||
|
@ -52,21 +54,29 @@ class TorchNet():
|
|||
|
||||
|
||||
class TestOptim(unittest.TestCase):
|
||||
gpu = False
|
||||
|
||||
def test_adam(self):
|
||||
for x,y in zip(step_tinygrad(Adam),
|
||||
for x,y in zip(step_tinygrad(Adam, gpu=self.gpu),
|
||||
step_pytorch(torch.optim.Adam)):
|
||||
np.testing.assert_allclose(x, y, atol=1e-4)
|
||||
|
||||
def test_sgd(self):
|
||||
for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}),
|
||||
for x,y in zip(step_tinygrad(SGD, kwargs={'lr': 0.001}, gpu=self.gpu),
|
||||
step_pytorch(torch.optim.SGD, kwargs={'lr': 0.001})):
|
||||
np.testing.assert_allclose(x, y, atol=1e-5)
|
||||
|
||||
def test_rmsprop(self):
|
||||
for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}),
|
||||
for x,y in zip(step_tinygrad(RMSprop, kwargs={'lr': 0.001, 'decay': 0.99}, gpu=self.gpu),
|
||||
step_pytorch(torch.optim.RMSprop,
|
||||
kwargs={'lr': 0.001, 'alpha': 0.99})):
|
||||
np.testing.assert_allclose(x, y, atol=1e-5)
|
||||
|
||||
|
||||
@unittest.skipUnless(GPU, "Requires GPU")
|
||||
class TestOptimGPU(TestOptim):
|
||||
gpu = True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.tensor import Tensor, GPU
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
|
||||
x_init = np.random.randn(1,3).astype(np.float32)
|
||||
|
@ -9,16 +9,18 @@ W_init = np.random.randn(3,3).astype(np.float32)
|
|||
m_init = np.random.randn(1,3).astype(np.float32)
|
||||
|
||||
class TestTinygrad(unittest.TestCase):
|
||||
gpu = False
|
||||
|
||||
def test_backward_pass(self):
|
||||
def test_tinygrad():
|
||||
x = Tensor(x_init)
|
||||
W = Tensor(W_init)
|
||||
m = Tensor(m_init)
|
||||
x = Tensor(x_init, gpu=self.gpu)
|
||||
W = Tensor(W_init, gpu=self.gpu)
|
||||
m = Tensor(m_init, gpu=self.gpu)
|
||||
out = x.dot(W).relu()
|
||||
out = out.logsoftmax()
|
||||
out = out.mul(m).add(m).sum()
|
||||
out.backward()
|
||||
return out.data, x.grad.data, W.grad.data
|
||||
return out.cpu().data, x.grad.cpu().data, W.grad.cpu().data
|
||||
|
||||
def test_pytorch():
|
||||
x = torch.tensor(x_init, requires_grad=True)
|
||||
|
@ -42,8 +44,8 @@ class TestTinygrad(unittest.TestCase):
|
|||
torch_func = lambda x: torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
|
||||
PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy()
|
||||
|
||||
tiny_x = Tensor(x)
|
||||
tiny_W = Tensor(W)
|
||||
tiny_x = Tensor(x, gpu=self.gpu)
|
||||
tiny_W = Tensor(W, gpu=self.gpu)
|
||||
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
|
||||
J = jacobian(tiny_func, tiny_x)
|
||||
NJ = numerical_jacobian(tiny_func, tiny_x)
|
||||
|
@ -55,8 +57,8 @@ class TestTinygrad(unittest.TestCase):
|
|||
W = np.random.RandomState(1337).random((10, 5))
|
||||
x = np.random.RandomState(7331).random((1, 10)) - 0.5
|
||||
|
||||
tiny_x = Tensor(x)
|
||||
tiny_W = Tensor(W)
|
||||
tiny_x = Tensor(x, gpu=self.gpu)
|
||||
tiny_W = Tensor(W, gpu=self.gpu)
|
||||
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
|
||||
|
||||
self.assertTrue(gradcheck(tiny_func, tiny_x))
|
||||
|
@ -64,5 +66,17 @@ class TestTinygrad(unittest.TestCase):
|
|||
# coarse approx. since a "big" eps and the non-linearities of the model
|
||||
self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 0.1))
|
||||
|
||||
|
||||
@unittest.skipUnless(GPU, "Requires GPU")
|
||||
class TestTinygradGPU(TestTinygrad):
|
||||
gpu = True
|
||||
|
||||
@unittest.skip("float64 not supported on GPU")
|
||||
def test_jacobian(self): pass
|
||||
|
||||
@unittest.skip("float64 not supported on GPU")
|
||||
def test_gradcheck(self): pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -15,19 +15,15 @@ def fetch(url):
|
|||
os.rename(fp+".tmp", fp)
|
||||
return dat
|
||||
|
||||
def get_parameters(model):
|
||||
if isinstance(model, Tensor):
|
||||
return [model]
|
||||
def get_parameters(obj):
|
||||
parameters = []
|
||||
if hasattr(model, '__dict__'):
|
||||
for k,v in model.__dict__.items():
|
||||
if isinstance(v, Tensor):
|
||||
parameters.append(v)
|
||||
elif isinstance(v, list):
|
||||
for x in v:
|
||||
parameters.extend(get_parameters(x))
|
||||
elif hasattr(v, '__dict__'):
|
||||
parameters.extend(get_parameters(v))
|
||||
#print(k, type(v))
|
||||
if isinstance(obj, Tensor):
|
||||
parameters.append(obj)
|
||||
elif isinstance(obj, list):
|
||||
for x in obj:
|
||||
parameters.extend(get_parameters(x))
|
||||
elif hasattr(obj, '__dict__'):
|
||||
for k,v in obj.__dict__.items():
|
||||
parameters.extend(get_parameters(v))
|
||||
return parameters
|
||||
|
||||
|
|
Loading…
Reference in New Issue