set training in functions

This commit is contained in:
George Hotz 2020-12-28 22:45:46 -05:00
parent 51bf164b72
commit bcb3ceeca3
2 changed files with 6 additions and 4 deletions

View File

@ -105,9 +105,9 @@ if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = make_dataset()
optim = Adam(get_parameters(model), lr=0.001)
train(model, X_train, Y_train, optim, 500, BS=16)
Tensor.training = False
evaluate(model, X_test, Y_test, num_classes=10)
for i in range(5):
train(model, X_train, Y_train, optim, 500, BS=32)
evaluate(model, X_test, Y_test, num_classes=10)

View File

@ -15,6 +15,7 @@ def sparse_categorical_crossentropy(out, Y):
return out.mul(y).mean()
def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, lossfn=sparse_categorical_crossentropy):
Tensor.training = True
if device == Device.GPU: [x.gpu_() for x in get_parameters([model, optim])]
elif device == Device.ANE: [x.ane_() for x in get_parameters([model, optim])]
losses, accuracies = [], []
@ -27,7 +28,6 @@ def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, loss
# network
out = model.forward(x)
# NLL loss function
loss = lossfn(out, y)
optim.zero_grad()
loss.backward()
@ -43,6 +43,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, loss
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128):
Tensor.training = False
def numpy_eval(num_classes):
Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
@ -54,3 +55,4 @@ def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128)
accuracy = numpy_eval(num_classes)
print("test set accuracy is %f" % accuracy)
return accuracy