tinygrad/examples/train_efficientnet.py

104 lines
2.9 KiB
Python
Raw Normal View History

import traceback
2020-11-10 08:01:16 +08:00
import time
from multiprocessing import Process, Queue
2020-11-10 08:01:16 +08:00
import numpy as np
2020-12-07 04:20:14 +08:00
from tqdm import trange
from tinygrad.state import get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor
from extra.datasets import fetch_cifar
from models.efficientnet import EfficientNet
2020-12-07 04:20:14 +08:00
class TinyConvNet:
    """Small convolutional classifier used as a fast stand-in model.

    The network is two convolution stages, each followed by ReLU and
    max-pooling, and then a single linear projection to ``classes`` outputs.
    All weights are plain tensors stored as attributes (``c1``, ``c2``,
    ``l1``).
    """

    def __init__(self, classes=10):
        """Create the three weight tensors.

        Args:
            classes: number of output columns of the final linear layer.
        """
        kernel = 3
        hidden_channels, final_channels = 8, 16  # kept small for speed
        # Weight shapes are (out_channels, in_channels, k, k); the first
        # convolution takes 3 input channels.
        self.c1 = Tensor.uniform(hidden_channels, 3, kernel, kernel)
        self.c2 = Tensor.uniform(final_channels, hidden_channels, kernel, kernel)
        # NOTE(review): the 6*6 factor presumably matches 32x32 inputs after
        # two unpadded 3x3 convolutions and two 2x2 poolings
        # (32 -> 30 -> 15 -> 13 -> 6) -- confirm max_pool2d's default window.
        self.l1 = Tensor.uniform(final_channels * 6 * 6, classes)

    def forward(self, x):
        """Run the network on ``x`` and return the result of the final dot.

        Args:
            x: input tensor; its first dimension is preserved and everything
                else is flattened before the linear layer.
        """
        # Both convolution stages share the same conv -> relu -> pool shape.
        for conv_weights in (self.c1, self.c2):
            x = x.conv2d(conv_weights).relu().max_pool2d()
        # Keep dimension 0 and collapse the remaining dimensions.
        flat = x.reshape(shape=[x.shape[0], -1])
        return flat.dot(self.l1)
2020-12-07 04:20:14 +08:00
2020-11-10 08:01:16 +08:00
def _batch_loader(q, fetch, batch_size):
    """Worker loop: push freshly fetched batches onto ``q`` forever.

    Defined at module level (not under the ``__main__`` guard) and given its
    inputs explicitly, so that worker processes created with the ``spawn``
    start method can resolve it by name and do not rely on globals that only
    exist in the parent process.

    Args:
        q: queue that receives each fetched batch via ``q.put``.
        fetch: callable invoked as ``fetch(batch_size)`` to produce a batch.
        batch_size: value forwarded to ``fetch``.
    """
    while True:
        try:
            q.put(fetch(batch_size))
        except Exception:
            # Best effort: report the failure and try again with the next batch.
            traceback.print_exc()


if __name__ == "__main__":
    # --- configuration from environment variables ---
    IMAGENET = getenv("IMAGENET")
    classes = 1000 if IMAGENET else 10

    TINY = getenv("TINY")
    TRANSFER = getenv("TRANSFER")

    # --- model selection ---
    if TINY:
        model = TinyConvNet(classes)
    elif TRANSFER:
        model = EfficientNet(getenv("NUM", 0), classes, has_se=True)
        model.load_from_pretrained()
    else:
        model = EfficientNet(getenv("NUM", 0), classes, has_se=False)

    parameters = get_parameters(model)
    print("parameter count", len(parameters))
    optimizer = optim.Adam(parameters, lr=0.001)

    BS, steps = getenv("BS", 64 if TINY else 16), getenv("STEPS", 2048)
    print(f"training with batch size {BS} for {steps} steps")

    # --- data source ---
    if IMAGENET:
        from extra.datasets.imagenet import fetch_batch

        # Two daemon workers keep a bounded queue of ready batches filled.
        q = Queue(16)
        for _ in range(2):
            p = Process(target=_batch_loader, args=(q, fetch_batch, BS))
            p.daemon = True
            p.start()
    else:
        X_train, Y_train = fetch_cifar()

    # --- training loop ---
    Tensor.training = True
    for _ in (t := trange(steps)):
        if IMAGENET:
            X, Y = q.get(True)
        else:
            # Sample a random batch (with replacement) from the training set.
            samp = np.random.randint(0, X_train.shape[0], size=(BS))
            X, Y = X_train[samp], Y_train[samp]

        # perf_counter is used for all stage timings (reported in ms).
        st = time.perf_counter()
        out = model.forward(Tensor(X.astype(np.float32), requires_grad=False))
        fp_time = (time.perf_counter() - st) * 1000.0

        # Target matrix holds -classes at each sample's label and 0 elsewhere,
        # so the mean over all BS*classes entries of log_softmax * y equals
        # the average negative log-likelihood of the labels.
        y = np.zeros((BS, classes), np.float32)
        y[range(y.shape[0]), Y] = -classes
        y = Tensor(y, requires_grad=False)
        loss = out.log_softmax().mul(y).mean()

        optimizer.zero_grad()

        st = time.perf_counter()
        loss.backward()
        bp_time = (time.perf_counter() - st) * 1000.0

        st = time.perf_counter()
        optimizer.step()
        opt_time = (time.perf_counter() - st) * 1000.0

        # Pull results back to numpy and compute batch accuracy.
        st = time.perf_counter()
        loss = loss.cpu().numpy()
        cat = np.argmax(out.cpu().numpy(), axis=1)
        accuracy = (cat == Y).mean()
        finish_time = (time.perf_counter() - st) * 1000.0

        # printing: loss, accuracy, then forward + backward + optimizer +
        # finish times and their total
        t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f" %
                          (loss, accuracy,
                           fp_time, bp_time, opt_time, finish_time,
                           fp_time + bp_time + opt_time + finish_time))

        # Drop references to this step's tensors before the next iteration.
        del out, y, loss