tinygrad/datasets/imagenet.py

31 lines
989 B
Python
Raw Normal View History

2022-01-16 15:16:38 +08:00
import os
import json
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
# Root directory of the local ImageNet copy. Everything this module reads
# (file lists, class index, images) is resolved relative to it.
BASEDIR = "/home/batman/imagenet"

# File lists are plain text with one entry per line. Context managers make
# sure the handles are closed right after reading instead of being left to
# the garbage collector.
# NOTE: an empty list file yields [""] rather than [] (str.split semantics).
with open(os.path.join(BASEDIR, "train_files")) as _fh:
    train_files = _fh.read().strip().split("\n")
with open(os.path.join(BASEDIR, "val_files")) as _fh:
    val_files = _fh.read().strip().split("\n")
2022-01-16 15:16:38 +08:00
# Class index loaded from JSON. Keys are parsed with int() below and the first
# element of each value is used as the lookup key -- presumably the standard
# {"0": ["n01440764", "tench"], ...} layout; verify against the file on disk.
# The context manager closes the handle as soon as parsing is done.
with open(os.path.join(BASEDIR, "imagenet_class_index.json")) as _fh:
    ci = json.load(_fh)

# Reverse mapping: first element of each entry -> integer class index.
cir = {v[0]: int(k) for k, v in ci.items()}

# Shared augmentation: random resized crop to 224 pixels, applied per image.
rrc = transforms.RandomResizedCrop(224)
def image_load(fn):
    """Load one image, convert it to RGB and apply the random resized crop.

    Args:
        fn: Path of the image file to open.

    Returns:
        A numpy array built from the cropped RGB image (height, width,
        channel order as produced by ``np.array`` on a PIL image).
    """
    # Image.open is lazy and may keep the underlying file open; the context
    # manager releases it deterministically once the RGB copy exists.
    with Image.open(fn) as img:
        rgb = img.convert('RGB')
    return np.array(rrc(rgb))
def fetch_batch(bs, val=False):
    """Draw a random batch of images and their class indices.

    Indices are drawn independently with ``np.random.randint``, so the same
    file may appear more than once in a batch.

    Args:
        bs: Number of samples to draw.
        val: When true, sample from ``val_files`` and read from the "val"
            subdirectory; otherwise use ``train_files`` and "train".

    Returns:
        A tuple ``(X, Y)``: ``X`` is the stacked image array with axes
        reordered by ``(0, 3, 1, 2)`` (channels moved in front of the
        spatial axes), ``Y`` is an array of class indices looked up from
        the first path component of each sampled file.
    """
    split = "val" if val else "train"
    pool = val_files if val else train_files

    picks = np.random.randint(0, len(pool), size=(bs))
    chosen = [pool[idx] for idx in picks]

    # Load every image first, then resolve labels (same order as before).
    split_root = os.path.join(BASEDIR, split)
    images = [image_load(os.path.join(split_root, rel)) for rel in chosen]
    labels = [cir[rel.partition("/")[0]] for rel in chosen]

    return np.array(images).transpose(0, 3, 1, 2), np.array(labels)
if __name__ == "__main__":
    # Smoke test: pull one training batch and show its shape and labels.
    X, Y = fetch_batch(bs=64)
    print(X.shape, Y)