mirror of https://github.com/commaai/tinygrad.git
Merge remote-tracking branch 'upstream/master'
commit d33885f496
@ -38,6 +38,7 @@ print(y.grad) # dz/dy
### TODO (to make real neural network library)

* Implement gradcheck (numeric)
* Implement convolutions
* Implement Adam optimizer
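The first TODO item is a numeric gradcheck. A minimal sketch of what that could look like against the ops used in this diff (dot, relu, mean); the helper name, eps, and atol are illustrative and not part of the commit:

```python
import numpy as np
from tinygrad.tensor import Tensor

def numeric_gradcheck(f, x, eps=1e-4, atol=1e-3):
  # analytic gradient from backward()
  xt = Tensor(x)
  f(xt).backward()
  analytic = xt.grad

  # central finite differences, one input element at a time
  numeric = np.zeros_like(x)
  for i in range(x.size):
    xp, xm = x.copy(), x.copy()
    xp.flat[i] += eps
    xm.flat[i] -= eps
    numeric.flat[i] = (f(Tensor(xp)).data - f(Tensor(xm)).data).sum() / (2*eps)

  return np.allclose(analytic, numeric, atol=atol)

x = np.random.uniform(-1., 1., size=(1, 5)).astype(np.float32)
W = Tensor(np.random.uniform(-1., 1., size=(5, 3)).astype(np.float32))
print(numeric_gradcheck(lambda t: t.dot(W).relu().mean(), x))
```

Since this Tensor's backward() overwrites grads rather than accumulating them, the same tensors can be reused across the analytic and finite-difference passes.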
@ -1,32 +1,17 @@
#!/usr/bin/env python
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import layer_init, SGD
from tinygrad.utils import fetch_mnist

from tqdm import trange

# load the mnist dataset

def fetch(url):
  import requests, gzip, os, hashlib, numpy
  fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
  if os.path.isfile(fp):
    with open(fp, "rb") as f:
      dat = f.read()
  else:
    with open(fp, "wb") as f:
      dat = requests.get(url).content
      f.write(dat)
  return numpy.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()

X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]

X_train, Y_train, X_test, Y_test = fetch_mnist()

# train a model

def layer_init(m, h):
  ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
  return ret.astype(np.float32)

class TinyBobNet:
  def __init__(self):
    self.l1 = Tensor(layer_init(784, 128))
@ -35,9 +20,12 @@ class TinyBobNet:
  def forward(self, x):
    return x.dot(self.l1).relu().dot(self.l2).logsoftmax()

model = TinyBobNet()

# optimizer

model = TinyBobNet()
optim = SGD([model.l1, model.l2], lr=0.01)

lr = 0.01
BS = 128
losses, accuracies = [], []
for i in (t := trange(1000)):
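The diff skips the unchanged top of the loop body. Roughly, each iteration samples a batch, builds a negative one-hot target, and runs the forward pass before the lines in the next hunk; a hedged reconstruction (not part of this commit):

```python
  # hypothetical sketch of the unchanged loop body that precedes the next hunk
  samp = np.random.randint(0, X_train.shape[0], size=(BS,))
  x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32))
  Y = Y_train[samp]
  y = np.zeros((BS, 10), dtype=np.float32)
  y[range(BS), Y] = -1.0  # -1 at the true class so mul + mean gives an NLL-style loss
  y = Tensor(y)
  outs = model.forward(x)
```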
@ -55,13 +43,11 @@ for i in (t := trange(1000)):
  # NLL loss function
  loss = outs.mul(y).mean()
  loss.backward()
  optim.step()

  cat = np.argmax(outs.data, axis=1)
  accuracy = (cat == Y).mean()

  # SGD
  model.l1.data = model.l1.data - lr*model.l1.grad
  model.l2.data = model.l2.data - lr*model.l2.grad

  # printing
  loss = loss.data
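Why outs.mul(y).mean() is an NLL loss: logsoftmax produces log-probabilities, and with -1.0 placed at the true class the elementwise product keeps only the correct-class log-probability, negated and scaled by the mean. A small standalone numpy check with toy values:

```python
import numpy as np

logits = np.array([[1.0, 2.0, 0.5]], dtype=np.float32)
logsm = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))  # logsoftmax
y = np.zeros((1, 3), dtype=np.float32)
y[0, 2] = -1.0                               # true class is 2
print((logsm * y).mean(), -logsm[0, 2] / 3)  # both equal the scaled NLL for this sample
```

The manual `model.l1.data - lr*model.l1.grad` lines perform the same update that optim.step() now applies through the SGD class added in tinygrad/nn.py below.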
@ -0,0 +1,15 @@
import numpy as np

def layer_init(m, h):
  ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
  return ret.astype(np.float32)

class SGD:
  def __init__(self, tensors, lr):
    self.tensors = tensors
    self.lr = lr

  def step(self):
    for t in self.tensors:
      t.data -= self.lr * t.grad
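The README TODO above lists an Adam optimizer; in the same style as this SGD class, a sketch could look like the following (hypothetical, not part of this commit):

```python
import numpy as np

class Adam:
  def __init__(self, tensors, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    self.tensors, self.lr, self.b1, self.b2, self.eps = tensors, lr, b1, b2, eps
    self.m = [np.zeros_like(t.data) for t in tensors]  # first-moment estimates
    self.v = [np.zeros_like(t.data) for t in tensors]  # second-moment estimates
    self.t = 0

  def step(self):
    self.t += 1
    for i, t in enumerate(self.tensors):
      self.m[i] = self.b1*self.m[i] + (1-self.b1)*t.grad
      self.v[i] = self.b2*self.v[i] + (1-self.b2)*(t.grad**2)
      mhat = self.m[i] / (1 - self.b1**self.t)   # bias correction
      vhat = self.v[i] / (1 - self.b2**self.t)
      t.data -= self.lr * mhat / (np.sqrt(vhat) + self.eps)
```

Like SGD.step(), this assumes backward() has already populated t.grad for every tensor handed to the optimizer.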
@ -2,16 +2,7 @@
from functools import partialmethod
import numpy as np

# **** start with three base classes ****

class Context:
  def __init__(self, arg, *tensors):
    self.arg = arg
    self.parents = tensors
    self.saved_tensors = []

  def save_for_backward(self, *x):
    self.saved_tensors.extend(x)

# **** start with two base classes ****

class Tensor:
  def __init__(self, data):
|
||||
|
@ -35,17 +26,19 @@ class Tensor:
    if self.grad is None and allow_fill:
      # fill in the first grad with one
      # this is "implicit gradient creation"
      assert self.data.size == 1
      self.grad = np.ones_like(self.data)

    assert(self.grad is not None)

    grads = self._ctx.arg.backward(self._ctx, self.grad)
    grads = self._ctx.backward(self._ctx, self.grad)
    if len(self._ctx.parents) == 1:
      grads = [grads]
    for t,g in zip(self._ctx.parents, grads):
      if g.shape != t.data.shape:
        print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx.arg, g.shape, t.data.shape))
        print("grad shape must match tensor shape in %r, %r != %r" %
          (self._ctx, g.shape, t.data.shape))
        assert(False)
      t.grad = g
      t.backward(False)
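backward() only seeds a gradient of ones when the tensor holds a single value ("implicit gradient creation"), which is why the training code reduces the loss to a scalar with mean() before calling it. A minimal illustration using the ops from this diff:

```python
import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.random.uniform(-1., 1., size=(3, 3)).astype(np.float32))
loss = x.relu().mean()   # reduce to one element so the size == 1 assert passes
loss.backward()
print(x.grad.shape)      # (3, 3), matching x.data.shape as the shape check requires
```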
@ -54,9 +47,18 @@ class Tensor:
    div = Tensor(np.array([1/self.data.size]))
    return self.sum().mul(div)

# An instantiation of the Function is the Context
class Function:
  def __init__(self, *tensors):
    self.parents = tensors
    self.saved_tensors = []

  def save_for_backward(self, *x):
    self.saved_tensors.extend(x)

  # note that due to how partialmethod works, self and arg are switched
  def apply(self, arg, *x):
    ctx = Context(arg, self, *x)
    ctx = arg(self, *x)
    ret = Tensor(arg.forward(ctx, self.data, *[t.data for t in x]))
    ret._ctx = ctx
    return ret
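The rewritten apply() relies on the op subclasses defined further down in tensor.py (outside this hunk): each op subclasses Function with forward/backward staticmethods and is attached to Tensor through partialmethod, which is why self and arg arrive switched. A sketch consistent with the code above:

```python
# sketch of an op and its registration (the real ops live below this hunk in tensor.py)
class Mul(Function):
  @staticmethod
  def forward(ctx, x, y):
    ctx.save_for_backward(x, y)
    return x*y

  @staticmethod
  def backward(ctx, grad_output):
    x, y = ctx.saved_tensors
    return y*grad_output, x*grad_output

def register(name, fxn):
  # a.mul(b) then calls fxn.apply(a, fxn, b): self is the Tensor, arg is the Function class
  setattr(Tensor, name, partialmethod(fxn.apply, fxn))
register('mul', Mul)
```

With Context merged into Function, `ctx = arg(self, *x)` makes the freshly instantiated op serve as its own context, which is what backward() walks via ret._ctx.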
|
||||
|
|
|
@ -0,0 +1,17 @@
def fetch_mnist():
  def fetch(url):
    import requests, gzip, os, hashlib, numpy
    fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
    if os.path.isfile(fp):
      with open(fp, "rb") as f:
        dat = f.read()
    else:
      with open(fp, "wb") as f:
        dat = requests.get(url).content
        f.write(dat)
    return numpy.frombuffer(gzip.decompress(dat), dtype=numpy.uint8).copy()
  X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
  Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
  X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
  Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
  return X_train, Y_train, X_test, Y_test
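The 0x10 and 8 offsets skip the IDX file headers (16 bytes of magic/count/rows/cols for the image files, 8 bytes of magic/count for the label files). Typical use, mirroring the updated MNIST example above:

```python
from tinygrad.utils import fetch_mnist

X_train, Y_train, X_test, Y_test = fetch_mnist()
print(X_train.shape, Y_train.shape)  # (60000, 28, 28) uint8 images and (60000,) uint8 labels

# flatten to (N, 784) before feeding TinyBobNet's 784-wide first layer
X = X_train.reshape((-1, 28*28)).astype("float32")
```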