diff --git a/datasets/imagenet.py b/datasets/imagenet.py index b5610533..2968737e 100644 --- a/datasets/imagenet.py +++ b/datasets/imagenet.py @@ -5,8 +5,8 @@ from PIL import Image import torchvision.transforms as transforms BASEDIR = "/home/batman/imagenet" -train_files = open(os.path.join(BASEDIR, "train_files")).read().split("\n") -val_files = open(os.path.join(BASEDIR, "val_files")).read().split("\n") +train_files = open(os.path.join(BASEDIR, "train_files")).read().strip().split("\n") +val_files = open(os.path.join(BASEDIR, "val_files")).read().strip().split("\n") ci = json.load(open(os.path.join(BASEDIR, "imagenet_class_index.json"))) cir = {v[0]: int(k) for k,v in ci.items()} diff --git a/examples/train_efficientnet.py b/examples/train_efficientnet.py index 0a348d80..a193fd77 100644 --- a/examples/train_efficientnet.py +++ b/examples/train_efficientnet.py @@ -1,4 +1,5 @@ import os +import traceback import time import numpy as np from models.efficientnet import EfficientNet @@ -53,7 +54,10 @@ if __name__ == "__main__": from multiprocessing import Process, Queue def loader(q): while 1: - q.put(fetch_batch(BS)) + try: + q.put(fetch_batch(BS)) + except Exception: + traceback.print_exc() q = Queue(16) for i in range(2): p = Process(target=loader, args=(q,))