mirror of https://github.com/commaai/tinygrad.git
added beautiful fashion mnist and example (#6961)
* added beautiful fashion mnist and example * fixing whitespace * refactor Fashion MNIST to fewer lines * fix newline to reduce diff * Update beautiful_mnist.py * Update beautiful_mnist.py --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent
b5546912e2
commit
960c495755
|
@ -18,7 +18,7 @@ class Model:
|
||||||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
X_train, Y_train, X_test, Y_test = mnist()
|
X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))
|
||||||
|
|
||||||
model = Model()
|
model = Model()
|
||||||
opt = nn.optim.Adam(nn.state.get_parameters(model))
|
opt = nn.optim.Adam(nn.state.get_parameters(model))
|
||||||
|
@ -47,4 +47,4 @@ if __name__ == "__main__":
|
||||||
# verify eval acc
|
# verify eval acc
|
||||||
if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
|
if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
|
||||||
if test_acc >= target and test_acc != 100.0: print(colored(f"{test_acc=} >= {target}", "green"))
|
if test_acc >= target and test_acc != 100.0: print(colored(f"{test_acc=} >= {target}", "green"))
|
||||||
else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))
|
else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))
|
||||||
|
|
|
@ -2,8 +2,9 @@ from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import fetch
|
from tinygrad.helpers import fetch
|
||||||
from tinygrad.nn.state import tar_extract
|
from tinygrad.nn.state import tar_extract
|
||||||
|
|
||||||
def _mnist(file): return Tensor(fetch("https://storage.googleapis.com/cvdf-datasets/mnist/"+file, gunzip=True))
|
def mnist(device=None, fashion=False):
|
||||||
def mnist(device=None):
|
base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" if fashion else "https://storage.googleapis.com/cvdf-datasets/mnist/"
|
||||||
|
def _mnist(file): return Tensor(fetch(base_url+file, gunzip=True))
|
||||||
return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \
|
return _mnist("train-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("train-labels-idx1-ubyte.gz")[8:].to(device), \
|
||||||
_mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)
|
_mnist("t10k-images-idx3-ubyte.gz")[0x10:].reshape(-1,1,28,28).to(device), _mnist("t10k-labels-idx1-ubyte.gz")[8:].to(device)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue