mirror of https://github.com/commaai/tinygrad.git
low lr improves rmsprop
This commit is contained in:
parent
d56a0af6cb
commit
2259c9faa1
|
@ -74,24 +74,24 @@ def evaluate(model):
|
|||
assert accuracy > 0.95
|
||||
|
||||
class TestMNIST(unittest.TestCase):
|
||||
def test_mnist_conv(self):
|
||||
def conv(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyConvNet()
|
||||
optimizer = optim.Adam([model.c1, model.l1, model.l2], lr=0.001)
|
||||
train(model, optimizer, steps=400)
|
||||
evaluate(model)
|
||||
|
||||
def test_mnist_sgd(self):
|
||||
def sgd(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyBobNet()
|
||||
optimizer = optim.SGD([model.l1, model.l2], lr=0.001)
|
||||
train(model, optimizer, steps=1000)
|
||||
evaluate(model)
|
||||
|
||||
def test_mnist_rmsprop(self):
|
||||
def rmsprop(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyBobNet()
|
||||
optimizer = optim.RMSprop([model.l1, model.l2], lr=0.001)
|
||||
optimizer = optim.RMSprop([model.l1, model.l2], lr=0.0002)
|
||||
train(model, optimizer, steps=1000)
|
||||
evaluate(model)
|
||||
|
||||
|
|
Loading…
Reference in New Issue