Fix examples/train_efficientnet (#1947)

* added missing colon

* bug fixes for cifar10 dataset loading
needed a reshape to work with conv layers and resolve fetched tensor to numpy since further code expects numpy array
This commit is contained in:
Daniel Riege 2023-10-02 11:23:38 +02:00 committed by GitHub
parent d4671cd8e3
commit 579cabf668
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 3 deletions

View File

@ -59,15 +59,17 @@ if __name__ == "__main__":
p.daemon = True
p.start()
else:
X_train, Y_train = fetch_cifar()
X_train, Y_train, _, _ = fetch_cifar()
X_train = X_train.reshape((-1, 3, 32, 32))
Y_train = Y_train.reshape((-1,))
with Tensor.train()
with Tensor.train():
for i in (t := trange(steps)):
if IMAGENET:
X, Y = q.get(True)
else:
samp = np.random.randint(0, X_train.shape[0], size=(BS))
X, Y = X_train[samp], Y_train[samp]
X, Y = X_train.numpy()[samp], Y_train.numpy()[samp]
st = time.time()
out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))