mirror of https://github.com/commaai/tinygrad.git
set training in functions
This commit is contained in:
parent 51bf164b72
commit bcb3ceeca3
@@ -105,9 +105,9 @@ if __name__ == "__main__":
  X_train, Y_train, X_test, Y_test = make_dataset()
  optim = Adam(get_parameters(model), lr=0.001)
  train(model, X_train, Y_train, optim, 500, BS=16)

  Tensor.training = False
  evaluate(model, X_test, Y_test, num_classes=10)
  for i in range(5):
    train(model, X_train, Y_train, optim, 500, BS=32)
    evaluate(model, X_test, Y_test, num_classes=10)
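Since the point of this commit is that train() and evaluate() now manage Tensor.training themselves, the example's __main__ block no longer has to flip the flag by hand. Below is a minimal sketch of the calling pattern after this change; the import paths, create_model, and make_dataset are assumptions based on the surrounding example file and this era of tinygrad, they are not shown in this diff:

# hypothetical imports: paths are guesses for this repo layout
from tinygrad.optim import Adam
from extra.utils import get_parameters
from extra.training import train, evaluate

model = create_model()                              # hypothetical: whatever model the example builds
X_train, Y_train, X_test, Y_test = make_dataset()   # helper from this example file
optim = Adam(get_parameters(model), lr=0.001)

# no Tensor.training bookkeeping here anymore:
# train() sets it to True on entry, evaluate() sets it to False
train(model, X_train, Y_train, optim, 500, BS=16)
evaluate(model, X_test, Y_test, num_classes=10)
for i in range(5):
  train(model, X_train, Y_train, optim, 500, BS=32)
  evaluate(model, X_test, Y_test, num_classes=10)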
@@ -15,6 +15,7 @@ def sparse_categorical_crossentropy(out, Y):
  return out.mul(y).mean()

def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, lossfn=sparse_categorical_crossentropy):
  Tensor.training = True
  if device == Device.GPU: [x.gpu_() for x in get_parameters([model, optim])]
  elif device == Device.ANE: [x.ane_() for x in get_parameters([model, optim])]
  losses, accuracies = [], []
@@ -27,7 +28,6 @@ def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, loss
    # network
    out = model.forward(x)

    # NLL loss function
    loss = lossfn(out, y)
    optim.zero_grad()
    loss.backward()
@@ -43,6 +43,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, loss
    t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))

def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128):
  Tensor.training = False
  def numpy_eval(num_classes):
    Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
    for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
@@ -54,3 +55,4 @@ def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128)
  accuracy = numpy_eval(num_classes)
  print("test set accuracy is %f" % accuracy)
  return accuracy
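For background on why the functions now set this flag: Tensor.training is a class-level switch that train-time-only layers (dropout-style layers) can consult at forward time. The snippet below is not tinygrad's implementation, just a self-contained illustration of the pattern with a hypothetical Dropout class:

import numpy as np

class Tensor:
  training = False   # class-level flag, mirroring tinygrad's Tensor.training

class Dropout:
  # hypothetical layer that only does work while Tensor.training is True
  def __init__(self, p=0.5):
    self.p = p
  def forward(self, x):
    if not Tensor.training:
      return x                                            # eval mode: identity
    mask = (np.random.rand(*x.shape) >= self.p) / (1.0 - self.p)
    return x * mask                                       # train mode: drop and rescale

layer = Dropout(0.5)
x = np.ones((2, 4))

Tensor.training = True      # what train() now does on entry
print(layer.forward(x))     # entries randomly zeroed, survivors rescaled

Tensor.training = False     # what evaluate() now does on entry
print(layer.forward(x))     # returned unchanged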