don't crash the dataloader for imagenet

This commit is contained in:
George Hotz 2022-01-16 08:41:26 -08:00
parent 907ff7dbb6
commit 2cae2dfa07
2 changed files with 7 additions and 3 deletions

View File

@ -5,8 +5,8 @@ from PIL import Image
import torchvision.transforms as transforms
BASEDIR = "/home/batman/imagenet"
train_files = open(os.path.join(BASEDIR, "train_files")).read().split("\n")
val_files = open(os.path.join(BASEDIR, "val_files")).read().split("\n")
train_files = open(os.path.join(BASEDIR, "train_files")).read().strip().split("\n")
val_files = open(os.path.join(BASEDIR, "val_files")).read().strip().split("\n")
ci = json.load(open(os.path.join(BASEDIR, "imagenet_class_index.json")))
cir = {v[0]: int(k) for k,v in ci.items()}

View File

@ -1,4 +1,5 @@
import os
import traceback
import time
import numpy as np
from models.efficientnet import EfficientNet
@ -53,7 +54,10 @@ if __name__ == "__main__":
from multiprocessing import Process, Queue
def loader(q):
while 1:
try:
q.put(fetch_batch(BS))
except Exception:
traceback.print_exc()
q = Queue(16)
for i in range(2):
p = Process(target=loader, args=(q,))