do test_conv_with_bn test

George Hotz 2023-03-19 23:53:56 -07:00
parent 5495c7d64e
commit 623fb1ef28
2 changed files with 3 additions and 5 deletions

@@ -19,7 +19,6 @@ def fetch_cifar(train=True):
   cifar10_std = np.array([0.24703225141799082, 0.24348516474564, 0.26158783926049628], dtype=np.float32).reshape(1,3,1,1)
   tt = tarfile.open(fileobj=io.BytesIO(fetch('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz')), mode='r:gz')
   if train:
-    # TODO: data_batch 2-5
     db = [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)]
   else:
     db = [pickle.load(tt.extractfile('cifar-10-batches-py/test_batch'), encoding="bytes")]
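For context, the hunk above sits inside a CIFAR-10 loader that unpickles the five training batches (or the single test batch) straight out of the downloaded tar archive; the stale TODO is dropped because all five batches are already loaded. A minimal standalone sketch of that pattern follows; the helper name load_cifar10 and the plain urllib download are assumptions for illustration, not the repository's fetch() helper.

import io, pickle, tarfile, urllib.request
import numpy as np

def load_cifar10(train=True):
  # download the python-pickle version of CIFAR-10 into memory
  url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
  tt = tarfile.open(fileobj=io.BytesIO(urllib.request.urlopen(url).read()), mode='r:gz')
  names = [f'cifar-10-batches-py/data_batch_{i}' for i in range(1, 6)] if train else ['cifar-10-batches-py/test_batch']
  db = [pickle.load(tt.extractfile(n), encoding="bytes") for n in names]
  # each batch dict holds 10000x3072 uint8 rows; reshape to NCHW and stack the batches
  X = np.concatenate([d[b'data'].reshape(-1, 3, 32, 32) for d in db]).astype(np.float32)
  Y = np.concatenate([np.array(d[b'labels'], dtype=np.int32) for d in db])
  return X, Y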

@@ -94,13 +94,12 @@ class TestMNIST(unittest.TestCase):
     train(model, X_train, Y_train, optimizer, steps=100)
     assert evaluate(model, X_test, Y_test) > 0.94   # torch gets 0.9415 sometimes
 
-  @unittest.skip("slow and training batchnorm is broken")
   def test_conv_with_bn(self):
     np.random.seed(1337)
     model = TinyConvNet(has_batchnorm=True)
-    optimizer = optim.Adam(model.parameters(), lr=0.001)
-    train(model, X_train, Y_train, optimizer, steps=100)
-    assert evaluate(model, X_test, Y_test) > 0.7  # TODO: batchnorm doesn't work!!!
+    optimizer = optim.AdamW(model.parameters(), lr=0.003)
+    train(model, X_train, Y_train, optimizer, steps=200)
+    assert evaluate(model, X_test, Y_test) > 0.94
 
   def test_sgd(self):
     np.random.seed(1337)
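The re-enabled test leans on train() and evaluate() helpers plus the Tensor.training flag that switches batchnorm between batch statistics and running statistics. A rough sketch of that loop is below; it is an assumption about how such helpers are typically written against the tinygrad API of this era (logsoftmax/NLL loss, zero_grad/backward/step), not necessarily the repository's own training code.

import numpy as np
from tinygrad.tensor import Tensor

def train(model, X_train, Y_train, optimizer, steps, batch_size=128, num_classes=10):
  Tensor.training = True                      # batchnorm uses per-batch statistics while training
  for _ in range(steps):
    idx = np.random.randint(0, X_train.shape[0], size=batch_size)
    out = model.forward(Tensor(X_train[idx])).logsoftmax()
    # NLL loss: a mask of -num_classes at the true class turns mean() into mean cross-entropy
    mask = np.zeros((batch_size, num_classes), dtype=np.float32)
    mask[range(batch_size), Y_train[idx]] = -1.0 * num_classes
    loss = out.mul(Tensor(mask)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

def evaluate(model, X_test, Y_test, batch_size=128):
  Tensor.training = False                     # batchnorm switches to running statistics
  preds = []
  for i in range(0, X_test.shape[0], batch_size):
    preds.append(model.forward(Tensor(X_test[i:i+batch_size])).numpy().argmax(axis=1))
  return (np.concatenate(preds) == Y_test).mean()  # accuracy compared against the asserted thresholds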