From f4e0cb5945f9327e0b2d3aac54d94f84081d8676 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 18 Oct 2020 13:19:19 -0700
Subject: [PATCH 1/3] refactor tinygrad to be more tiny

---
 README.md          |  1 +
 tinygrad/tensor.py | 27 ++++++++++++++-------------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index f7579bb4..a06e2a7d 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,7 @@ print(y.grad) # dz/dy
 
 ### TODO (to make real neural network library)
 
+* Implement gradcheck (numeric)
 * Implement convolutions
 * Implement Adam optimizer
 
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 959e4737..170cd9c9 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2,16 +2,7 @@
 from functools import partialmethod
 import numpy as np
 
-# **** start with three base classes ****
-
-class Context:
-  def __init__(self, arg, *tensors):
-    self.arg = arg
-    self.parents = tensors
-    self.saved_tensors = []
-
-  def save_for_backward(self, *x):
-    self.saved_tensors.extend(x)
+# **** start with two base classes ****
 
 class Tensor:
   def __init__(self, data):
@@ -35,17 +26,18 @@ class Tensor:
 
     if self.grad is None and allow_fill:
       # fill in the first grad with one
+      # this is "implicit gradient creation"
       assert self.data.size == 1
       self.grad = np.ones_like(self.data)
 
     assert(self.grad is not None)
 
-    grads = self._ctx.arg.backward(self._ctx, self.grad)
+    grads = self._ctx.backward(self._ctx, self.grad)
     if len(self._ctx.parents) == 1:
       grads = [grads]
     for t,g in zip(self._ctx.parents, grads):
       if g.shape != t.data.shape:
-        print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx.arg, g.shape, t.data.shape))
+        print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx, g.shape, t.data.shape))
         assert(False)
       t.grad = g
       t.backward(False)
@@ -54,9 +46,18 @@ class Tensor:
     div = Tensor(np.array([1/self.data.size]))
     return self.sum().mul(div)
 
+# The Function is the Context
 class Function:
+  def __init__(self, *tensors):
+    self.parents = tensors
+    self.saved_tensors = []
+
+  def save_for_backward(self, *x):
+    self.saved_tensors.extend(x)
+
+  # note that due to how partialmethod works, self and arg are switched
   def apply(self, arg, *x):
-    ctx = Context(arg, self, *x)
+    ctx = arg(self, *x)
     ret = Tensor(arg.forward(ctx, self.data, *[t.data for t in x]))
     ret._ctx = ctx
     return ret
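Aside, not part of the patch: with Context folded into Function, an op is simply a Function subclass, and the instance that apply() creates doubles as its own context. Below is a minimal sketch of what such an op could look like. The register() helper is an assumption about how ops get attached to Tensor as methods via partialmethod (the patch only hints at that wiring in the apply() comment), so treat the names as illustrative rather than as the library's actual API.

    # sketch only: assumes Tensor and Function as defined in tinygrad/tensor.py above
    from functools import partialmethod
    from tinygrad.tensor import Tensor, Function

    # hypothetical helper (not shown in the patch): attach fxn.apply to Tensor
    # under `name`, so x.mul(y) resolves to Mul.apply(x, Mul, y), which is the
    # self/arg switch the apply() comment refers to
    def register(name, fxn):
      setattr(Tensor, name, partialmethod(fxn.apply, fxn))

    class Mul(Function):
      @staticmethod
      def forward(ctx, x, y):
        ctx.save_for_backward(x, y)   # ctx is the Mul instance built by apply()
        return x * y

      @staticmethod
      def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return y * grad_output, x * grad_output

    register('mul', Mul)

Registered this way, a call like some_tensor.mul(other) builds the graph that the backward() method in the patch walks.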
From 118c2eebe3a81bad4ef76e178aab14978e071ac1 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 18 Oct 2020 13:27:59 -0700
Subject: [PATCH 2/3] write sgd class

---
 test/mnist.py      | 19 ++++++++++++++-----
 tinygrad/tensor.py |  5 +++--
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/test/mnist.py b/test/mnist.py
index 3bc1aa95..ffcead0a 100644
--- a/test/mnist.py
+++ b/test/mnist.py
@@ -35,9 +35,20 @@ class TinyBobNet:
   def forward(self, x):
     return x.dot(self.l1).relu().dot(self.l2).logsoftmax()
 
-model = TinyBobNet()
+# optimizer
+
+class SGD:
+  def __init__(self, tensors, lr):
+    self.tensors = tensors
+    self.lr = lr
+
+  def step(self):
+    for t in self.tensors:
+      t.data -= self.lr * t.grad
+
+model = TinyBobNet()
+optim = SGD([model.l1, model.l2], lr=0.01)
 
-lr = 0.01
 BS = 128
 losses, accuracies = [], []
 for i in (t := trange(1000)):
@@ -55,13 +66,11 @@ for i in (t := trange(1000)):
   # NLL loss function
   loss = outs.mul(y).mean()
   loss.backward()
+  optim.step()
 
   cat = np.argmax(outs.data, axis=1)
   accuracy = (cat == Y).mean()
 
-  # SGD
-  model.l1.data = model.l1.data - lr*model.l1.grad
-  model.l2.data = model.l2.data - lr*model.l2.grad
 
   # printing
   loss = loss.data
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 170cd9c9..13ef238b 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -37,7 +37,8 @@ class Tensor:
       grads = [grads]
     for t,g in zip(self._ctx.parents, grads):
       if g.shape != t.data.shape:
-        print("grad shape must match tensor shape in %r, %r != %r" % (self._ctx, g.shape, t.data.shape))
+        print("grad shape must match tensor shape in %r, %r != %r" %
+          (self._ctx, g.shape, t.data.shape))
         assert(False)
       t.grad = g
       t.backward(False)
@@ -46,7 +47,7 @@ class Tensor:
     div = Tensor(np.array([1/self.data.size]))
     return self.sum().mul(div)
 
-# The Function is the Context
+# An instantiation of the Function is the Context
 class Function:
   def __init__(self, *tensors):
     self.parents = tensors
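Aside, not part of the patch series: the README TODO from patch 1 also lists an Adam optimizer. Since the SGD class above fixes the optimizer interface to a list of tensors plus a step() method, a drop-in Adam could be sketched as follows; the hyperparameter defaults are the usual ones from the Adam paper, not values these patches prescribe.

    import numpy as np

    class Adam:
      def __init__(self, tensors, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
        self.tensors, self.lr, self.b1, self.b2, self.eps = tensors, lr, b1, b2, eps
        self.m = [np.zeros_like(t.data) for t in tensors]   # first moment estimates
        self.v = [np.zeros_like(t.data) for t in tensors]   # second moment estimates
        self.t = 0

      def step(self):
        self.t += 1
        for i, t in enumerate(self.tensors):
          self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * t.grad
          self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * (t.grad ** 2)
          mhat = self.m[i] / (1 - self.b1 ** self.t)   # bias correction
          vhat = self.v[i] / (1 - self.b2 ** self.t)
          t.data -= self.lr * mhat / (np.sqrt(vhat) + self.eps)

Swapping optim = SGD([model.l1, model.l2], lr=0.01) for optim = Adam([model.l1, model.l2]) would leave the rest of the training loop untouched.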
From 92fd23df66ff3d9618d908a87ec349871c618720 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 18 Oct 2020 13:30:25 -0700
Subject: [PATCH 3/3] refactor into a few files

---
 test/mnist.py     | 31 ++++---------------------------
 tinygrad/nn.py    | 15 +++++++++++++++
 tinygrad/utils.py | 17 +++++++++++++++++
 3 files changed, 36 insertions(+), 27 deletions(-)
 create mode 100644 tinygrad/nn.py
 create mode 100644 tinygrad/utils.py

diff --git a/test/mnist.py b/test/mnist.py
index ffcead0a..686b6fcf 100644
--- a/test/mnist.py
+++ b/test/mnist.py
@@ -1,32 +1,17 @@
 #!/usr/bin/env python
 import numpy as np
 from tinygrad.tensor import Tensor
+from tinygrad.nn import layer_init, SGD
+from tinygrad.utils import fetch_mnist
+
 from tqdm import trange
 
 # load the mnist dataset
 
-def fetch(url):
-  import requests, gzip, os, hashlib, numpy
-  fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
-  if os.path.isfile(fp):
-    with open(fp, "rb") as f:
-      dat = f.read()
-  else:
-    with open(fp, "wb") as f:
-      dat = requests.get(url).content
-      f.write(dat)
-  return numpy.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
-X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
-Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
-X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
-Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
+X_train, Y_train, X_test, Y_test = fetch_mnist()
 
 # train a model
 
-def layer_init(m, h):
-  ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
-  return ret.astype(np.float32)
-
 class TinyBobNet:
   def __init__(self):
     self.l1 = Tensor(layer_init(784, 128))
@@ -37,14 +22,6 @@ class TinyBobNet:
 
 # optimizer
 
-class SGD:
-  def __init__(self, tensors, lr):
-    self.tensors = tensors
-    self.lr = lr
-
-  def step(self):
-    for t in self.tensors:
-      t.data -= self.lr * t.grad
 
 model = TinyBobNet()
 optim = SGD([model.l1, model.l2], lr=0.01)
diff --git a/tinygrad/nn.py b/tinygrad/nn.py
new file mode 100644
index 00000000..e3f6e95f
--- /dev/null
+++ b/tinygrad/nn.py
@@ -0,0 +1,15 @@
+import numpy as np
+
+def layer_init(m, h):
+  ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
+  return ret.astype(np.float32)
+
+class SGD:
+  def __init__(self, tensors, lr):
+    self.tensors = tensors
+    self.lr = lr
+
+  def step(self):
+    for t in self.tensors:
+      t.data -= self.lr * t.grad
+
diff --git a/tinygrad/utils.py b/tinygrad/utils.py
new file mode 100644
index 00000000..bbef87f6
--- /dev/null
+++ b/tinygrad/utils.py
@@ -0,0 +1,17 @@
+def fetch_mnist():
+  def fetch(url):
+    import requests, gzip, os, hashlib, numpy
+    fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
+    if os.path.isfile(fp):
+      with open(fp, "rb") as f:
+        dat = f.read()
+    else:
+      with open(fp, "wb") as f:
+        dat = requests.get(url).content
+        f.write(dat)
+    return numpy.frombuffer(gzip.decompress(dat), dtype=numpy.uint8).copy()
+  X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
+  Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
+  X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
+  Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
+  return X_train, Y_train, X_test, Y_test
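Aside, not part of the patch series: the TODO added in patch 1 asks for a numeric gradcheck. A rough sketch against the API these patches leave behind (Tensor, backward(), and the sum/mul ops that mean() already relies on) could look like the following. The numeric_gradcheck name, its parameters, and the tolerances are made up for illustration, and the float32 weights from layer_init force fairly loose tolerances.

    import numpy as np
    from tinygrad.tensor import Tensor
    from tinygrad.nn import layer_init

    def numeric_gradcheck(f, tensors, eps=1e-3, atol=1e-2):
      # analytic gradients from the autograd engine; f must return a scalar Tensor
      out = f(*tensors)
      out.backward()
      for t in tensors:
        analytic = t.grad
        numeric = np.zeros_like(t.data)
        it = np.nditer(t.data, flags=['multi_index'])
        while not it.finished:
          idx = it.multi_index
          old = t.data[idx]
          t.data[idx] = old + eps
          fp = f(*tensors).data.sum()
          t.data[idx] = old - eps
          fm = f(*tensors).data.sum()
          t.data[idx] = old
          numeric[idx] = (fp - fm) / (2 * eps)   # central difference
          it.iternext()
        if not np.allclose(analytic, numeric, atol=atol):
          return False
      return True

    # e.g. check the gradient of an elementwise product on tiny random tensors
    a = Tensor(layer_init(1, 3))
    b = Tensor(layer_init(1, 3))
    print(numeric_gradcheck(lambda x, y: x.mul(y).sum(), [a, b]))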