From 960c495755f0c964381e63ca909c34004d752a1a Mon Sep 17 00:00:00 2001 From: Kinvert Date: Thu, 10 Oct 2024 00:01:07 -0400 Subject: [PATCH] added beautiful fashion mnist and example (#6961) * added beautiful fashion mnist and example * fixing whitespace * refactor Fashion MNIST to fewer lines * fix newline to reduce diff * Update beautiful_mnist.py * Update beautiful_mnist.py --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- examples/beautiful_mnist.py | 4 ++-- tinygrad/nn/datasets.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/beautiful_mnist.py b/examples/beautiful_mnist.py index 7f003f6f..4f43e414 100644 --- a/examples/beautiful_mnist.py +++ b/examples/beautiful_mnist.py @@ -18,7 +18,7 @@ class Model: def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers) if __name__ == "__main__": - X_train, Y_train, X_test, Y_test = mnist() + X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION")) model = Model() opt = nn.optim.Adam(nn.state.get_parameters(model)) @@ -47,4 +47,4 @@ if __name__ == "__main__": # verify eval acc if target := getenv("TARGET_EVAL_ACC_PCT", 0.0): if test_acc >= target and test_acc != 100.0: print(colored(f"{test_acc=} >= {target}", "green")) - else: raise ValueError(colored(f"{test_acc=} < {target}", "red")) \ No newline at end of file + else: raise ValueError(colored(f"{test_acc=} < {target}", "red")) diff --git a/tinygrad/nn/datasets.py b/tinygrad/nn/datasets.py index 6fac195c..44ffafe5 100644 --- a/tinygrad/nn/datasets.py +++ b/tinygrad/nn/datasets.py @@ -2,8 +2,9 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import fetch from tinygrad.nn.state import tar_extract -def _mnist(file): return Tensor(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file, gunzip=True)) -def mnist(device=None): +def mnist(device=None, fashion=False): + base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" if fashion else "https://storage.googleapis.com/cvdf-datasets/mnist/" + def _mnist(file): return Tensor(fetch(base_url+file, gunzip=True)) return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \ _mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)