From fffd9b05f52a9259692168a5bfbf939f0cd5f39e Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 6 Apr 2024 08:08:40 -0700
Subject: [PATCH] mock mnist data for imagenet trainer (#4095)

* mock mnist data for imagenet

* move print and test

* needed to reshape
---

Notes (review):
    - train_resnet() no longer special-cases MOCKDATA: the step counts are
      always derived from len(get_train_files()) / len(get_val_files()) and
      the training batches always come from batch_load_resnet().
    - When getenv("MNISTMOCK") is truthy, extra/datasets/imagenet.py points
      BASEDIR at extra/datasets/mnist, and get_train_files() calls
      create_fake_mnist_imagenet(BASEDIR) first if that directory does not
      exist. That variant is wrapped in functools.lru_cache(None); the
      non-mock variant keeps @diskcache.
    - fake_imagenet_from_mnist.py calls fetch_mnist() at module import time
      and writes the JPEGs through a fixed-size Pool(64).
    - The new script is added without a trailing newline.

 examples/mlperf/model_train.py             | 14 ++------
 extra/datasets/.gitignore                  |  1 +
 extra/datasets/fake_imagenet_from_mnist.py | 41 ++++++++++++++++++++++
 extra/datasets/imagenet.py                 | 24 +++++++++----
 4 files changed, 63 insertions(+), 17 deletions(-)
 create mode 100755 extra/datasets/fake_imagenet_from_mnist.py

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 89055b08..0d418938 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -56,11 +56,8 @@ def train_resnet():
   eval_start_epoch = getenv("EVAL_START_EPOCH", 0)
   eval_epochs = getenv("EVAL_EPOCHS", 1)
 
-  if getenv("MOCKDATA"):
-    steps_in_train_epoch, steps_in_val_epoch = 100, 0
-  else:
-    steps_in_train_epoch = config["steps_in_train_epoch"] = (len(get_train_files()) // BS)
-    steps_in_val_epoch = config["steps_in_val_epoch"] = (len(get_val_files()) // EVAL_BS)
+  steps_in_train_epoch = config["steps_in_train_epoch"] = (len(get_train_files()) // BS)
+  steps_in_val_epoch = config["steps_in_val_epoch"] = (len(get_val_files()) // EVAL_BS)
 
   config["DEFAULT_FLOAT"] = dtypes.default_float.name
   config["BEAM"] = BEAM.value
@@ -133,12 +130,7 @@ def train_resnet():
   for e in range(start_epoch, epochs):
     # ** train loop **
     Tensor.training = True
-    if getenv("MOCKDATA"):
-      def mockdata():
-        for _ in range(steps_in_train_epoch): yield Tensor.ones(BS,224,224,3,dtype=dtypes.uint8), [0]*BS, None
-      batch_loader = mockdata()
-    else:
-      batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e)
+    batch_loader = batch_load_resnet(batch_size=BS, val=False, shuffle=True, seed=seed*epochs + e)
     it = iter(tqdm(batch_loader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
     i, proc = 0, data_get(it)
     st = time.perf_counter()
diff --git a/extra/datasets/.gitignore b/extra/datasets/.gitignore
index 7b843be3..cce2cccf 100644
--- a/extra/datasets/.gitignore
+++ b/extra/datasets/.gitignore
@@ -1,2 +1,3 @@
 imagenet
 imagenet_bak
+mnist
diff --git a/extra/datasets/fake_imagenet_from_mnist.py b/extra/datasets/fake_imagenet_from_mnist.py
new file mode 100755
index 00000000..983ac384
--- /dev/null
+++ b/extra/datasets/fake_imagenet_from_mnist.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+import pathlib, json
+from tqdm import trange
+from extra.datasets import fetch_mnist
+from PIL import Image
+import numpy as np
+from multiprocessing import Pool
+
+X_train, Y_train, X_test, Y_test = fetch_mnist()
+
+def act(arg):
+  (basedir, i, train) = arg
+  if train:
+    img = np.uint8(X_train[i]).reshape(28, 28)
+    nm = f"train/{Y_train[i]}/{i}.jpg"
+  else:
+    img = np.uint8(X_test[i]).reshape(28, 28)
+    nm = f"val/{Y_test[i]}/{i}.jpg"
+  Image.fromarray(img).resize((224, 224)).convert('RGB').save(basedir / nm)
+
+def create_fake_mnist_imagenet(basedir:pathlib.Path):
+  print(f"creating mock MNIST dataset at {basedir}")
+  basedir.mkdir(exist_ok=True)
+
+  with (basedir / "imagenet_class_index.json").open('w') as f:
+    f.write(json.dumps({str(i):[str(i), str(i)] for i in range(10)}))
+
+  for i in range(10):
+    (basedir / f"train/{i}").mkdir(parents=True, exist_ok=True)
+    (basedir / f"val/{i}").mkdir(parents=True, exist_ok=True)
+
+  def gen(train):
+    for idx in trange(X_train.shape[0] if train else X_test.shape[0]):
+      yield (basedir, idx, train)
+
+  with Pool(64) as p:
+    for _ in p.imap_unordered(act, gen(True)): pass
+    for _ in p.imap_unordered(act, gen(False)): pass
+
+if __name__ == "__main__":
+  create_fake_mnist_imagenet(pathlib.Path("./mnist"))
\ No newline at end of file
diff --git a/extra/datasets/imagenet.py b/extra/datasets/imagenet.py
index f00cc657..53dd7d57 100644
--- a/extra/datasets/imagenet.py
+++ b/extra/datasets/imagenet.py
@@ -5,17 +5,29 @@ from PIL import Image
 import functools, pathlib
 from tinygrad.helpers import diskcache, getenv
 
-BASEDIR = pathlib.Path(__file__).parent / "imagenet"
-
 @functools.lru_cache(None)
 def get_imagenet_categories():
   ci = json.load(open(BASEDIR / "imagenet_class_index.json"))
   return {v[0]: int(k) for k,v in ci.items()}
 
-@diskcache
-def get_train_files():
-  if not (files:=glob.glob(p:=str(BASEDIR / "train/*/*"))): raise FileNotFoundError(f"No training files in {p}")
-  return files
+if getenv("MNISTMOCK"):
+  BASEDIR = pathlib.Path(__file__).parent / "mnist"
+
+  @functools.lru_cache(None)
+  def get_train_files():
+    if not BASEDIR.exists():
+      from extra.datasets.fake_imagenet_from_mnist import create_fake_mnist_imagenet
+      create_fake_mnist_imagenet(BASEDIR)
+
+    if not (files:=glob.glob(p:=str(BASEDIR / "train/*/*"))): raise FileNotFoundError(f"No training files in {p}")
+    return files
+else:
+  BASEDIR = pathlib.Path(__file__).parent / "imagenet"
+
+  @diskcache
+  def get_train_files():
+    if not (files:=glob.glob(p:=str(BASEDIR / "train/*/*"))): raise FileNotFoundError(f"No training files in {p}")
+    return files
 
 @functools.lru_cache(None)
 def get_val_files():