diff --git a/test/test_mnist.py b/test/test_mnist.py index 4e06cc1f..8a2d6a4d 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -3,10 +3,20 @@ import os import unittest import numpy as np from tinygrad.tensor import Tensor, GPU -from tinygrad.utils import layer_init_uniform, fetch_mnist +from tinygrad.utils import layer_init_uniform, fetch import tinygrad.optim as optim from tqdm import trange +# mnist loader +def fetch_mnist(): + import gzip + parse = lambda dat: np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy() + X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28)) + Y_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"))[8:] + X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28)) + Y_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"))[8:] + return X_train, Y_train, X_test, Y_test + # load the mnist dataset X_train, Y_train, X_test, Y_test = fetch_mnist() diff --git a/tinygrad/utils.py b/tinygrad/utils.py index f6e5d03f..b66888a6 100644 --- a/tinygrad/utils.py +++ b/tinygrad/utils.py @@ -23,12 +23,3 @@ def fetch(url): os.rename(fp+".tmp", fp) return dat -def fetch_mnist(): - import gzip - parse = lambda dat: np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy() - X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28)) - Y_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"))[8:] - X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28)) - Y_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"))[8:] - return X_train, Y_train, X_test, Y_test -