mock mnist data for imagenet trainer (#4095)

* mock mnist data for imagenet

* move print and test

* needed to reshape
This commit is contained in:
George Hotz 2024-04-06 08:08:40 -07:00 committed by GitHub
parent 8739d33fe9
commit fffd9b05f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 63 additions and 17 deletions

View File

@ -56,11 +56,8 @@ def train_resnet():
eval_start_epoch = getenv("EVAL_START_EPOCH", 0)
eval_epochs = getenv("EVAL_EPOCHS", 1)
if getenv("MOCKDATA"):
steps_in_train_epoch, steps_in_val_epoch = 100, 0
else:
steps_in_train_epoch = config["steps_in_train_epoch"] = (len(get_train_files()) // BS)
steps_in_val_epoch = config["steps_in_val_epoch"] = (len(get_val_files()) // EVAL_BS)
steps_in_train_epoch = config["steps_in_train_epoch"] = (len(get_train_files()) // BS)
steps_in_val_epoch = config["steps_in_val_epoch"] = (len(get_val_files()) // EVAL_BS)
config["DEFAULT_FLOAT"] = dtypes.default_float.name
config["BEAM"] = BEAM.value
@ -133,12 +130,7 @@ def train_resnet():
for e in range(start_epoch, epochs):
# ** train loop **
Tensor.training = True
if getenv("MOCKDATA"):
def mockdata():
for _ in range(steps_in_train_epoch): yield Tensor.ones(BS,224,224,3,dtype=dtypes.uint8), [0]*BS, None
batch_loader = mockdata()
else:
batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e)
batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e)
it = iter(tqdm(batch_loader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
i, proc = 0, data_get(it)
st = time.perf_counter()

View File

@ -1,2 +1,3 @@
imagenet
imagenet_bak
mnist

View File

@ -0,0 +1,41 @@
#!/usr/bin/env python3
import pathlib, json
from tqdm import trange
from extra.datasets import fetch_mnist
from PIL import Image
import numpy as np
from multiprocessing import Pool
# Load MNIST once at import time; `act` reads these four arrays as module-level
# globals rather than receiving them through its argument tuple.
X_train, Y_train, X_test, Y_test = fetch_mnist()
def act(arg):
  """Write one MNIST sample as a 224x224 RGB .jpg in an ImageNet-style folder layout.

  `arg` is a `(basedir, i, train)` tuple: the dataset root, the sample index and
  whether the sample comes from the training split (otherwise the test split,
  which is stored under `val/`). The label becomes the sub-directory name.
  """
  basedir, i, train = arg
  images = X_train if train else X_test
  nm = f"train/{Y_train[i]}/{i}.jpg" if train else f"val/{Y_test[i]}/{i}.jpg"
  # each sample is reshaped to 28x28 before being handed to PIL
  digit = Image.fromarray(np.uint8(images[i]).reshape(28, 28))
  digit.resize((224, 224)).convert('RGB').save(basedir / nm)
def create_fake_mnist_imagenet(basedir:pathlib.Path):
  """Materialize MNIST on disk in an ImageNet-like directory layout under `basedir`.

  Creates `basedir` (including missing parents), writes `imagenet_class_index.json`
  mapping "0".."9" to `[str(i), str(i)]`, creates `train/<digit>` and `val/<digit>`
  directories, then renders every train and test sample through `act` using a
  process pool.
  """
  print(f"creating mock MNIST dataset at {basedir}")
  # parents=True matches the sub-directory mkdirs below; without it a nested
  # target whose parent is missing fails before anything is written
  basedir.mkdir(parents=True, exist_ok=True)
  with (basedir / "imagenet_class_index.json").open('w') as f:
    json.dump({str(i): [str(i), str(i)] for i in range(10)}, f)
  for split in ("train", "val"):
    for i in range(10):
      (basedir / f"{split}/{i}").mkdir(parents=True, exist_ok=True)

  def gen(train):
    # yields one work item per sample; trange doubles as the progress bar
    for idx in trange(X_train.shape[0] if train else X_test.shape[0]):
      yield (basedir, idx, train)

  with Pool(64) as p:
    # train split first, then the test split (written under val/)
    for train in (True, False):
      for _ in p.imap_unordered(act, gen(train)): pass
if __name__ == "__main__":
  # Script entry point: generate the dataset into ./mnist (relative to the cwd).
  output_dir = pathlib.Path("./mnist")
  create_fake_mnist_imagenet(output_dir)

View File

@ -5,17 +5,29 @@ from PIL import Image
import functools, pathlib
from tinygrad.helpers import diskcache, getenv
BASEDIR = pathlib.Path(__file__).parent / "imagenet"
@functools.lru_cache(None)
def get_imagenet_categories():
  """Return a dict mapping the first element of each class-index entry to its integer key.

  Reads `imagenet_class_index.json` from BASEDIR. The result is memoized by
  `functools.lru_cache`, so the file is parsed at most once per process.
  """
  # use a context manager so the file handle is closed instead of leaked
  with open(BASEDIR / "imagenet_class_index.json") as f:
    ci = json.load(f)
  return {v[0]: int(k) for k,v in ci.items()}
@diskcache
def get_train_files():
  """Return every path matching `train/*/*` under BASEDIR.

  Raises FileNotFoundError when the glob matches nothing.
  NOTE(review): `diskcache` presumably persists the returned list across runs - confirm.
  """
  p = str(BASEDIR / "train/*/*")
  files = glob.glob(p)
  if not files:
    raise FileNotFoundError(f"No training files in {p}")
  return files
if getenv("MNISTMOCK"):
  # Mock mode: point the loader at the MNIST-derived dataset next to this file.
  BASEDIR = pathlib.Path(__file__).parent / "mnist"

  @functools.lru_cache(None)
  def get_train_files():
    """Return all `train/*/*` paths under the mock dataset, generating it first if BASEDIR is absent.

    Raises FileNotFoundError when the glob matches nothing.
    """
    if not BASEDIR.exists():
      # imported lazily, only when the dataset actually has to be generated
      from extra.datasets.fake_imagenet_from_mnist import create_fake_mnist_imagenet
      create_fake_mnist_imagenet(BASEDIR)
    p = str(BASEDIR / "train/*/*")
    files = glob.glob(p)
    if not files:
      raise FileNotFoundError(f"No training files in {p}")
    return files
else:
  # Default mode: use the imagenet directory next to this file.
  BASEDIR = pathlib.Path(__file__).parent / "imagenet"

  @diskcache
  def get_train_files():
    """Return all `train/*/*` paths under BASEDIR.

    Raises FileNotFoundError when the glob matches nothing.
    """
    p = str(BASEDIR / "train/*/*")
    files = glob.glob(p)
    if not files:
      raise FileNotFoundError(f"No training files in {p}")
    return files
@functools.lru_cache(None)
def get_val_files():