mirror of https://github.com/commaai/tinygrad.git
move the mnist loader out of tinygrad proper
This commit is contained in:
parent
498b4d2f27
commit
52ee913c98
|
@ -3,10 +3,20 @@ import os
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, GPU
|
||||
from tinygrad.utils import layer_init_uniform, fetch_mnist
|
||||
from tinygrad.utils import layer_init_uniform, fetch
|
||||
import tinygrad.optim as optim
|
||||
from tqdm import trange
|
||||
|
||||
# mnist loader
|
||||
def fetch_mnist():
|
||||
import gzip
|
||||
parse = lambda dat: np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
|
||||
X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
|
||||
Y_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"))[8:]
|
||||
X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
|
||||
Y_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"))[8:]
|
||||
return X_train, Y_train, X_test, Y_test
|
||||
|
||||
# load the mnist dataset
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
|
||||
|
|
|
@ -23,12 +23,3 @@ def fetch(url):
|
|||
os.rename(fp+".tmp", fp)
|
||||
return dat
|
||||
|
||||
def fetch_mnist():
|
||||
import gzip
|
||||
parse = lambda dat: np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
|
||||
X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
|
||||
Y_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"))[8:]
|
||||
X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
|
||||
Y_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"))[8:]
|
||||
return X_train, Y_train, X_test, Y_test
|
||||
|
||||
|
|
Loading…
Reference in New Issue