mirror of https://github.com/commaai/tinygrad.git
don't crash the dataloader for imagenet
This commit is contained in:
parent
907ff7dbb6
commit
2cae2dfa07
|
@ -5,8 +5,8 @@ from PIL import Image
|
|||
import torchvision.transforms as transforms
|
||||
|
||||
BASEDIR = "/home/batman/imagenet"
|
||||
train_files = open(os.path.join(BASEDIR, "train_files")).read().split("\n")
|
||||
val_files = open(os.path.join(BASEDIR, "val_files")).read().split("\n")
|
||||
train_files = open(os.path.join(BASEDIR, "train_files")).read().strip().split("\n")
|
||||
val_files = open(os.path.join(BASEDIR, "val_files")).read().strip().split("\n")
|
||||
ci = json.load(open(os.path.join(BASEDIR, "imagenet_class_index.json")))
|
||||
cir = {v[0]: int(k) for k,v in ci.items()}
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import traceback
|
||||
import time
|
||||
import numpy as np
|
||||
from models.efficientnet import EfficientNet
|
||||
|
@ -53,7 +54,10 @@ if __name__ == "__main__":
|
|||
from multiprocessing import Process, Queue
|
||||
def loader(q):
|
||||
while 1:
|
||||
try:
|
||||
q.put(fetch_batch(BS))
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
q = Queue(16)
|
||||
for i in range(2):
|
||||
p = Process(target=loader, args=(q,))
|
||||
|
|
Loading…
Reference in New Issue