2024-04-06 23:08:40 +08:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
import pathlib, json
|
2024-06-23 02:45:06 +08:00
|
|
|
from tinygrad.helpers import trange
|
2024-04-06 23:08:40 +08:00
|
|
|
from extra.datasets import fetch_mnist
|
|
|
|
from PIL import Image
|
|
|
|
import numpy as np
|
|
|
|
from multiprocessing import Pool
|
|
|
|
|
|
|
|
# Load the MNIST splits once at module level; `act` and
# `create_fake_mnist_imagenet` read these arrays as module globals.
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
|
|
|
|
|
|
|
def act(arg):
  """Write a single MNIST sample to disk as an upscaled RGB image.

  `arg` is a `(basedir, i, train)` tuple (packed into one argument so the
  function can be mapped over by a process pool):
    basedir -- root directory of the generated dataset
    i       -- index of the sample inside the selected split
    train   -- truthy to read from the training arrays, falsy for the test arrays

  The image is saved to `train/<label>/<i>.jpg` or `val/<label>/<i>.jpg`
  below `basedir`; those class directories must already exist.
  """
  basedir, index, is_train = arg
  if not is_train:
    pixels = X_test[index]
    rel_path = f"val/{Y_test[index]}/{index}.jpg"
  else:
    pixels = X_train[index]
    rel_path = f"train/{Y_train[index]}/{index}.jpg"
  # Each sample is viewed as a 28x28 uint8 bitmap before handing it to PIL.
  digit = np.uint8(pixels).reshape(28, 28)
  picture = Image.fromarray(digit)
  # Upscale to 224x224 and expand to three channels.
  picture = picture.resize((224, 224))
  picture = picture.convert('RGB')
  picture.save(basedir / rel_path)
|
|
|
|
|
|
|
|
def create_fake_mnist_imagenet(basedir:pathlib.Path):
  """Export MNIST as an ImageNet-style directory tree rooted at `basedir`.

  Layout produced:
    basedir/imagenet_class_index.json  -- {"0": ["0", "0"], ..., "9": ["9", "9"]}
    basedir/train/<digit>/<idx>.jpg    -- one image per sample of X_train
    basedir/val/<digit>/<idx>.jpg      -- one image per sample of X_test

  The images themselves are written by `act`, fanned out over a pool of
  64 worker processes.

  Args:
    basedir: root directory of the dataset. It is created if missing,
      including any missing parent directories.
  """
  print(f"creating mock MNIST dataset at {basedir}")
  # parents=True so a nested target such as "data/mnist" works even when
  # "data" does not exist yet (matches the class-directory mkdirs below).
  basedir.mkdir(parents=True, exist_ok=True)

  # Class index file: every digit maps to a pair of identical string labels.
  with (basedir / "imagenet_class_index.json").open('w') as f:
    json.dump({str(i): [str(i), str(i)] for i in range(10)}, f)

  # One directory per digit class in both splits; `act` only writes files.
  for i in range(10):
    (basedir / f"train/{i}").mkdir(parents=True, exist_ok=True)
    (basedir / f"val/{i}").mkdir(parents=True, exist_ok=True)

  def gen(train):
    """Yield one `(basedir, idx, train)` work item per sample of the chosen split."""
    # trange is presumably a progress-reporting range -- it is consumed as
    # work items are handed to the pool.
    for idx in trange(X_train.shape[0] if train else X_test.shape[0]):
      yield (basedir, idx, train)

  with Pool(64) as p:
    # Results are discarded; the loops only drain the iterators so that
    # every work item is executed before the pool is torn down.
    for _ in p.imap_unordered(act, gen(True)): pass
    for _ in p.imap_unordered(act, gen(False)): pass
|
|
|
|
|
|
|
|
# Script entry point: build the fake dataset in ./mnist relative to the
# current working directory.
if __name__ == "__main__":
  target_dir = pathlib.Path("./mnist")
  create_fake_mnist_imagenet(target_dir)
|