You like pytorch? You like micrograd? You love tinygrad! ❤️



For something in between a pytorch and a karpathy/micrograd

This may not be the best deep learning framework, but it is a deep learning framework.

The Tensor class is a wrapper around a numpy array, except it does Tensor things.

Installation

pip3 install tinygrad

Example

from tinygrad.tensor import Tensor

x = Tensor.eye(3)
y = Tensor([[2.0,0,-2.0]])
z = y.matmul(x).sum()
z.backward()

print(x.grad)  # dz/dx
print(y.grad)  # dz/dy

Same example in torch

import torch

x = torch.eye(3, requires_grad=True)
y = torch.tensor([[2.0,0,-2.0]], requires_grad=True)
z = y.matmul(x).sum()
z.backward()

print(x.grad)  # dz/dx
print(y.grad)  # dz/dy
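
What should both snippets print? Since z = sum over i,j of y[0][i] * x[i][j], working it out by hand gives dz/dx[i][j] = y[0][i] and dz/dy[0][i] = sum over j of x[i][j]. So x.grad should have rows of 2, 0 and -2, and y.grad should be all ones. The sketch below checks this with central finite differences in plain Python. The helper names `forward` and `numeric_grad` are made up for illustration and are not part of tinygrad or torch.

```python
# Finite-difference check of the gradients from the example above.
# Plain Python lists stand in for tensors; helper names are hypothetical.

def forward(x, y):
    """z = sum(y @ x) for y of shape (1, n) and x of shape (n, m)."""
    n, m = len(x), len(x[0])
    return sum(y[0][i] * x[i][j] for i in range(n) for j in range(m))

def numeric_grad(f, mat, eps=1e-4):
    """Central-difference gradient of scalar f() w.r.t. each entry of mat."""
    grad = [[0.0] * len(row) for row in mat]
    for i, row in enumerate(mat):
        for j in range(len(row)):
            old = row[j]
            row[j] = old + eps
            hi = f()
            row[j] = old - eps
            lo = f()
            row[j] = old  # restore the entry before moving on
            grad[i][j] = (hi - lo) / (2 * eps)
    return grad

x = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]  # eye(3)
y = [[2.0, 0.0, -2.0]]

dz_dx = numeric_grad(lambda: forward(x, y), x)
dz_dy = numeric_grad(lambda: forward(x, y), y)

print([[round(v, 6) for v in row] for row in dz_dx])
print([[round(v, 6) for v in row] for row in dz_dy])
```

Because z is linear in both x and y, the finite-difference estimate matches the analytic gradient up to floating-point rounding.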

Neural networks?

It turns out, a decent autograd tensor library is 90% of what you need for neural networks. Add an optimizer (SGD, RMSprop, and Adam implemented) from tinygrad.optim, write some boilerplate minibatching code, and you have all you need.
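
To make "an optimizer plus some minibatching boilerplate" concrete, here is a minimal sketch of both pieces on a toy one-parameter problem. It uses plain Python lists instead of tensors, and `sgd_step` and `minibatches` are hypothetical helpers written for this illustration. They are not the tinygrad.optim API.

```python
import random

def sgd_step(params, grads, lr=0.001):
    """In-place vanilla SGD update: p <- p - lr * dL/dp."""
    for p, g in zip(params, grads):
        for k in range(len(p)):
            p[k] -= lr * g[k]

def minibatches(n_samples, batch_size, steps, seed=0):
    """Yield `steps` lists of random sample indices (sampled with replacement)."""
    rng = random.Random(seed)
    for _ in range(steps):
        yield [rng.randrange(n_samples) for _ in range(batch_size)]

# Toy problem: fit w in y = w * x by minimizing mean squared error.
xs = [0.5, 1.0, 1.5, 2.0]
ys = [1.5, 3.0, 4.5, 6.0]  # generated with w = 3
w = [0.0]                  # a single flat parameter "tensor"

for batch in minibatches(len(xs), batch_size=2, steps=500):
    # d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
    grad = [sum(2 * (w[0] * xs[i] - ys[i]) * xs[i] for i in batch) / len(batch)]
    sgd_step([w], [grad], lr=0.05)

print(round(w[0], 3))
```

With an autograd library, the hand-written gradient line is the part that `backward()` does for you. The sampling loop and the update step stay about this small.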

Neural network example (from test/test_mnist.py)

from tinygrad.tensor import Tensor
import tinygrad.optim as optim
from tinygrad.utils import layer_init_uniform

class TinyBobNet:
  def __init__(self):
    self.l1 = Tensor(layer_init_uniform(784, 128))
    self.l2 = Tensor(layer_init_uniform(128, 10))

  def forward(self, x):
    return x.dot(self.l1).relu().dot(self.l2).logsoftmax()

model = TinyBobNet()
# named `optimizer` so the imported `tinygrad.optim` module is not shadowed
optimizer = optim.SGD([model.l1, model.l2], lr=0.001)

# ... and complete like pytorch, with (x,y) data

out = model.forward(x)
loss = out.mul(y).mean()
loss.backward()
optimizer.step()
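
The loss line `out.mul(y).mean()` is terse, so here is one way to read it. The sketch assumes that `y` holds -1.0 at the true class and 0.0 everywhere else; that encoding is an assumption, since the snippet above does not show how `y` is built. Under that assumption, multiplying the log-softmax output by `y` and taking the mean gives the negative log-likelihood, scaled down by the number of classes. The helpers `logsoftmax` and `mul_mean` are plain-Python stand-ins, not tinygrad functions.

```python
import math

def logsoftmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]

def mul_mean(out, y):
    """Elementwise product, then a mean over every entry, like out.mul(y).mean()."""
    total = sum(o * t for orow, yrow in zip(out, y) for o, t in zip(orow, yrow))
    return total / (len(out) * len(out[0]))

logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # batch of 2, 3 classes
labels = [0, 2]

out = [logsoftmax(r) for r in logits]
# Assumed target encoding: -1.0 at the true class, 0.0 elsewhere.
y = [[-1.0 if c == lab else 0.0 for c in range(3)] for lab in labels]

loss = mul_mean(out, y)
# Conventional NLL: average of -log p(true class) over the batch.
nll = -sum(out[i][lab] for i, lab in enumerate(labels)) / len(labels)

print(loss, nll / 3)  # the two agree up to rounding
```

The constant factor only rescales the gradient, so it can be folded into the learning rate or into the nonzero value stored in `y`.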

GPU Support?!

tinygrad supports GPUs through PyOpenCL. Not all ops are supported yet.

from tinygrad.tensor import Tensor
(Tensor.ones(4,4).cuda() + Tensor.ones(4,4).cuda()).cpu()

ImageNet inference (on the micrograd puppy)

python3 examples/efficientnet.py

The promise of small

tinygrad will always be below 1000 lines. If it isn't, we will revert commits until tinygrad becomes smaller.

Running tests

python3 -m pytest

TODO

  • Train an EfficientNet on ImageNet
    • Make broadcasting work on the backward pass (simple please)
    • EfficientNet backward pass
    • Tensors on GPU (GPU support, must support Mac)
  • Reduce code
  • Increase speed
  • Add features
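
On the broadcasting item above: one common way to handle it, not necessarily the one tinygrad will adopt, is to sum the incoming gradient over whichever axes were broadcast in the forward pass. The sketch below shows this for the simplest case, adding a length-m vector to an (n, m) matrix, in plain Python. The function names are hypothetical.

```python
def add_broadcast(a, b):
    """Forward: a has shape (n, m), b has shape (m,); b is broadcast over rows."""
    return [[a_ij + b_j for a_ij, b_j in zip(row, b)] for row in a]

def add_broadcast_backward(grad_out):
    """Backward: pass the gradient straight through to a, sum it over rows for b."""
    grad_a = [row[:] for row in grad_out]
    grad_b = [sum(col) for col in zip(*grad_out)]
    return grad_a, grad_b

a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
b = [10.0, 20.0, 30.0]
out = add_broadcast(a, b)

# The upstream gradient of out.sum() is all ones.
grad_out = [[1.0] * 3 for _ in range(2)]
grad_a, grad_b = add_broadcast_backward(grad_out)

print(grad_a)
print(grad_b)
```

Each entry of `b` was used once per row, so its gradient is the column sum of the upstream gradient, and `grad_b` ends up with the same shape as `b`.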