move the mnist loader out of tinygrad proper

This commit is contained in:
George Hotz 2020-11-10 15:37:39 -08:00
parent 498b4d2f27
commit 52ee913c98
2 changed files with 11 additions and 10 deletions

View File

@ -3,10 +3,20 @@ import os
import unittest
import numpy as np
from tinygrad.tensor import Tensor, GPU
from tinygrad.utils import layer_init_uniform, fetch_mnist
from tinygrad.utils import layer_init_uniform, fetch
import tinygrad.optim as optim
from tqdm import trange
# mnist loader
def fetch_mnist():
import gzip
parse = lambda dat: np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
Y_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"))[8:]
X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
Y_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"))[8:]
return X_train, Y_train, X_test, Y_test
# load the mnist dataset
X_train, Y_train, X_test, Y_test = fetch_mnist()

View File

@ -23,12 +23,3 @@ def fetch(url):
os.rename(fp+".tmp", fp)
return dat
def fetch_mnist():
import gzip
parse = lambda dat: np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
Y_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"))[8:]
X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
Y_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"))[8:]
return X_train, Y_train, X_test, Y_test