mirror of https://github.com/commaai/tinygrad.git
54 lines
1.6 KiB
Python
54 lines
1.6 KiB
Python
# For ImageNet, download prepare.sh and run it.
|
|
import glob, random
|
|
import json
|
|
import numpy as np
|
|
from PIL import Image
|
|
import functools, pathlib
|
|
|
|
# Root directory of the ImageNet data: an "imagenet" folder next to this file.
BASEDIR = pathlib.Path(__file__).parent / "imagenet"

# Class index exactly as stored in imagenet_class_index.json. The file is opened
# in a with-block so the handle is closed as soon as parsing is done.
with open(BASEDIR / "imagenet_class_index.json") as _class_index_file:
  ci = json.load(_class_index_file)

# Reverse lookup: first element of each class-index entry -> integer class id.
# NOTE(review): that first element is presumably the per-class directory name;
# the loaders in this file index cir with a file's directory component.
cir = {v[0]: int(k) for k,v in ci.items()}
|
|
|
|
@functools.lru_cache(None)
def get_train_files():
  """Return the training image paths listed in ``BASEDIR / "train_files"``.

  The listing is read as newline-separated entries (surrounding whitespace
  stripped from the whole file) and each entry is joined onto
  ``BASEDIR / "train"``. The result is memoized, so the listing is read once
  per process and every call returns the same list object.

  Returns:
    A list of ``pathlib.Path`` objects, one per listed entry.
  """
  # Context manager so the listing file is closed deterministically instead of
  # whenever the garbage collector gets to the anonymous file object.
  with open(BASEDIR / "train_files") as listing:
    train_files = listing.read().strip().split("\n")
  return [(BASEDIR / "train" / x) for x in train_files]
|
|
|
|
@functools.lru_cache(None)
def get_val_files():
  """Return every path matching ``val/*/*`` under BASEDIR.

  Paths are the strings produced by ``glob.glob``, in whatever order it
  returns them. The result is memoized, so the directory is scanned once per
  process and every call returns the same list object.
  """
  pattern = str(BASEDIR / "val/*/*")
  return glob.glob(pattern)
|
|
|
|
#rrc = transforms.RandomResizedCrop(224)
|
|
import torchvision.transforms.functional as F
|
|
def image_load(fn):
  """Load one image file and return it as a numpy array.

  The image is opened with PIL, converted to RGB, resized with size 256 using
  bilinear interpolation, center-cropped to 224 and converted with
  ``np.array``.

  Args:
    fn: whatever ``PIL.Image.open`` accepts (the callers in this file pass
      file paths).

  Returns:
    The cropped image as a numpy array.
    NOTE(review): presumably (224, 224, 3) uint8 -- confirm against
    torchvision's resize/center_crop semantics.
  """
  rgb = Image.open(fn).convert('RGB')
  # Resize first, then crop: both steps go through torchvision's functional API.
  cropped = F.center_crop(F.resize(rgb, 256, Image.BILINEAR), 224)
  return np.array(cropped)
|
|
|
|
def iterate(bs=32, val=True, shuffle=True):
  """Yield ``(X, Y)`` batches that cover the chosen file list exactly once.

  Args:
    bs: batch size; the final batch is smaller when the number of files is
      not a multiple of ``bs``.
    val: when true use ``get_val_files()``, otherwise ``get_train_files()``.
    shuffle: when true visit the files in a random order
      (``random.shuffle``); otherwise in list order.

  Yields:
    A tuple ``(np.array(images), np.array(labels))`` where images come from
    ``image_load`` (run in a 16-process pool) and labels are the integer
    class ids looked up in ``cir`` by each file's parent directory name.
  """
  files = get_val_files() if val else get_train_files()
  order = list(range(len(files)))
  if shuffle: random.shuffle(order)
  from multiprocessing import Pool
  # The with-block terminates the worker processes when the generator is
  # exhausted, explicitly closed, or garbage-collected mid-iteration; without
  # it every call leaked 16 processes.
  with Pool(16) as p:
    for start in range(0, len(files), bs):
      batch = [files[j] for j in order[start:start+bs]]
      X = p.map(image_load, batch)
      # Validation entries are str (from glob) and training entries are
      # pathlib.Path, which has no .split(); going through pathlib handles
      # both and picks the same component as split("/")[-2] did for strings.
      Y = [cir[pathlib.Path(fn).parent.name] for fn in batch]
      yield (np.array(X), np.array(Y))
|
|
|
|
def fetch_batch(bs, val=False):
  """Return one random batch of ``bs`` images and their class ids.

  Files are sampled uniformly with replacement (``np.random.randint``), so a
  batch may contain the same file more than once.

  Args:
    bs: number of samples to draw.
    val: when true sample from ``get_val_files()``, otherwise from
      ``get_train_files()``.

  Returns:
    A tuple ``(np.array(images), np.array(labels))`` where images come from
    ``image_load`` and labels are the integer class ids looked up in ``cir``
    by each file's parent directory name.
  """
  files = get_val_files() if val else get_train_files()
  samp = np.random.randint(0, len(files), size=(bs))
  files = [files[i] for i in samp]
  X = [image_load(x) for x in files]
  # The previous x.split("/")[0] raised AttributeError on the pathlib.Path
  # entries returned by get_train_files(), and on full string paths it picked
  # the first path component rather than the class directory. Use the parent
  # directory name, the same component iterate() uses for its labels.
  # NOTE(review): assumes entries are laid out as <class_dir>/<file> -- confirm
  # against the contents of train_files.
  Y = [cir[pathlib.Path(x).parent.name] for x in files]
  return np.array(X), np.array(Y)
|
|
|
|
if __name__ == "__main__":
  # Smoke test: draw one 64-sample batch with fetch_batch's default split and
  # print the image array shape followed by the label array.
  images, labels = fetch_batch(64)
  print(images.shape, labels)
|
|
|