mirror of https://github.com/commaai/tinygrad.git
Fix examples/train_efficientnet (#1947)
* added missing colon * bug fixes for cifar10 dataset loading needed a reshape to work with conv layers and resolve fetched tensor to numpy since further code expects numpy array
This commit is contained in:
parent
d4671cd8e3
commit
579cabf668
|
@ -59,15 +59,17 @@ if __name__ == "__main__":
|
|||
p.daemon = True
|
||||
p.start()
|
||||
else:
|
||||
X_train, Y_train = fetch_cifar()
|
||||
X_train, Y_train, _, _ = fetch_cifar()
|
||||
X_train = X_train.reshape((-1, 3, 32, 32))
|
||||
Y_train = Y_train.reshape((-1,))
|
||||
|
||||
with Tensor.train()
|
||||
with Tensor.train():
|
||||
for i in (t := trange(steps)):
|
||||
if IMAGENET:
|
||||
X, Y = q.get(True)
|
||||
else:
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(BS))
|
||||
X, Y = X_train[samp], Y_train[samp]
|
||||
X, Y = X_train.numpy()[samp], Y_train.numpy()[samp]
|
||||
|
||||
st = time.time()
|
||||
out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
|
||||
|
|
Loading…
Reference in New Issue