mirror of https://github.com/commaai/tinygrad.git
31 lines
993 B
Python
31 lines
993 B
Python
import os
|
|
import json
|
|
import numpy as np
|
|
from PIL import Image
|
|
import torchvision.transforms as transforms
|
|
|
|
# Root directory of the local ImageNet copy; every path below is resolved
# relative to it.
BASEDIR = "/Users/kafka/fun/imagenet"

# File lists, one entry per line. Each handle is closed as soon as the list
# has been read instead of being left for the garbage collector.
with open(os.path.join(BASEDIR, "train_files")) as _fh:
    train_files = _fh.read().strip().split("\n")
with open(os.path.join(BASEDIR, "val_files")) as _fh:
    val_files = _fh.read().strip().split("\n")

# Class index loaded from JSON; its keys are converted with int() below.
with open(os.path.join(BASEDIR, "imagenet_class_index.json")) as _fh:
    ci = json.load(_fh)

# Reverse lookup: first element of each class-index entry -> integer key.
# NOTE(review): presumably v[0] is the class directory name (WordNet id),
# since fetch_batch looks up the first path component here -- confirm.
cir = {v[0]: int(k) for k, v in ci.items()}

# Shared random-resized-crop transform (output size 224) used by image_load.
rrc = transforms.RandomResizedCrop(224)
|
|
def image_load(fn):
    """Load the image at path ``fn`` and return a randomly cropped array.

    The image is opened with PIL, converted to RGB, passed through the
    module-level ``rrc`` transform, and converted to a numpy array.
    """
    rgb = Image.open(fn).convert('RGB')
    return np.array(rrc(rgb))
|
|
|
|
def fetch_batch(bs, val=False):
    """Sample ``bs`` random images (with replacement) and their labels.

    Args:
        bs: number of samples to draw.
        val: draw from ``val_files`` / the "val" directory when true,
            otherwise from ``train_files`` / the "train" directory.

    Returns:
        A tuple ``(X, Y)``: ``X`` is the stacked image array with axes
        reordered by ``(0, 3, 1, 2)`` (channels moved ahead of the two
        spatial axes), and ``Y`` is the array of integer labels looked up
        in ``cir`` from the first path component of each file entry.
    """
    if val:
        pool, split_dir = val_files, "val"
    else:
        pool, split_dir = train_files, "train"

    # Indices are drawn uniformly from [0, len(pool)), so repeats are possible.
    picks = np.random.randint(0, len(pool), size=(bs))
    chosen = [pool[idx] for idx in picks]

    # Load every image first, then resolve labels, in that order.
    images = [image_load(os.path.join(BASEDIR, split_dir, rel)) for rel in chosen]
    labels = [cir[rel.split("/")[0]] for rel in chosen]

    batch = np.array(images)
    return np.transpose(batch, (0, 3, 1, 2)), np.array(labels)
|
|
|
|
if __name__ == "__main__":
    # Smoke test: fetch one training batch of 64 and show its shape and labels.
    batch_images, batch_labels = fetch_batch(64)
    print(batch_images.shape, batch_labels)
|
|
|