mirror of https://github.com/commaai/tinygrad.git
do test_conv_with_bn test
This commit is contained in:
parent
5495c7d64e
commit
623fb1ef28
|
@ -19,7 +19,6 @@ def fetch_cifar(train=True):
|
|||
cifar10_std = np.array([0.24703225141799082, 0.24348516474564, 0.26158783926049628], dtype=np.float32).reshape(1,3,1,1)
|
||||
tt = tarfile.open(fileobj=io.BytesIO(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')), mode='r:gz')
|
||||
if train:
|
||||
# TODO: data_batch 2-5
|
||||
db = [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)]
|
||||
else:
|
||||
db = [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")]
|
||||
|
|
|
@ -94,13 +94,12 @@ class TestMNIST(unittest.TestCase):
|
|||
train(model, X_train, Y_train, optimizer, steps=100)
|
||||
assert evaluate(model, X_test, Y_test) > 0.94 # torch gets 0.9415 sometimes
|
||||
|
||||
@unittest.skip("slow and training batchnorm is broken")
|
||||
def test_conv_with_bn(self):
|
||||
np.random.seed(1337)
|
||||
model = TinyConvNet(has_batchnorm=True)
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
train(model, X_train, Y_train, optimizer, steps=100)
|
||||
assert evaluate(model, X_test, Y_test) > 0.7 # TODO: batchnorm doesn't work!!!
|
||||
optimizer = optim.AdamW(model.parameters(), lr=0.003)
|
||||
train(model, X_train, Y_train, optimizer, steps=200)
|
||||
assert evaluate(model, X_test, Y_test) > 0.94
|
||||
|
||||
def test_sgd(self):
|
||||
np.random.seed(1337)
|
||||
|
|
Loading…
Reference in New Issue