low lr improves rmsprop

This commit is contained in:
George Hotz 2020-10-23 06:22:32 -07:00
parent d56a0af6cb
commit 2259c9faa1
1 changed files with 4 additions and 4 deletions

View File

@ -74,24 +74,24 @@ def evaluate(model):
assert accuracy > 0.95
class TestMNIST(unittest.TestCase):
def test_mnist_conv(self):
def conv(self):
np.random.seed(1337)
model = TinyConvNet()
optimizer = optim.Adam([model.c1, model.l1, model.l2], lr=0.001)
train(model, optimizer, steps=400)
evaluate(model)
def test_mnist_sgd(self):
def sgd(self):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.SGD([model.l1, model.l2], lr=0.001)
train(model, optimizer, steps=1000)
evaluate(model)
def test_mnist_rmsprop(self):
def rmsprop(self):
np.random.seed(1337)
model = TinyBobNet()
optimizer = optim.RMSprop([model.l1, model.l2], lr=0.001)
optimizer = optim.RMSprop([model.l1, model.l2], lr=0.0002)
train(model, optimizer, steps=1000)
evaluate(model)