added beautiful fashion mnist and example (#6961)

* added beautiful fashion mnist and example

* fixing whitespace

* refactor Fashion MNIST to fewer lines

* fix newline to reduce diff

* Update beautiful_mnist.py

* Update beautiful_mnist.py

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Kinvert 2024-10-10 00:01:07 -04:00 committed by GitHub
parent b5546912e2
commit 960c495755
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 4 deletions

View File

@ -18,7 +18,7 @@ class Model:
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers) def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
if __name__ == "__main__": if __name__ == "__main__":
X_train, Y_train, X_test, Y_test = mnist() X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))
model = Model() model = Model()
opt = nn.optim.Adam(nn.state.get_parameters(model)) opt = nn.optim.Adam(nn.state.get_parameters(model))
@ -47,4 +47,4 @@ if __name__ == "__main__":
# verify eval acc # verify eval acc
if target := getenv("TARGET_EVAL_ACC_PCT", 0.0): if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
if test_acc >= target and test_acc != 100.0: print(colored(f"{test_acc=} >= {target}", "green")) if test_acc >= target and test_acc != 100.0: print(colored(f"{test_acc=} >= {target}", "green"))
else: raise ValueError(colored(f"{test_acc=} < {target}", "red")) else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))

View File

@ -2,8 +2,9 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import fetch from tinygrad.helpers import fetch
from tinygrad.nn.state import tar_extract from tinygrad.nn.state import tar_extract
def _mnist(file): return Tensor(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file, gunzip=True)) def mnist(device=None, fashion=False):
def mnist(device=None): base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" if fashion else "https://storage.googleapis.com/cvdf-datasets/mnist/"
def _mnist(file): return Tensor(fetch(base_url+file, gunzip=True))
return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \ return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \
_mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device) _mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)