mirror of https://github.com/commaai/tinygrad.git
Add a quick start guide (#900)
* feat: initial quick start guide * fix: fix link * feat: add note about jit * feat: add note about load/store ops * feat: add link to discord * feat: add note about saving and loading models * fix: correct code for saving and loading * feat: overhaul docs * fix: fix link * feat: wording * feat: add link to discord * feat: contributing guidelines * feat: make contributing section more doc focused * feat: add link to env_vars from readme * fix: wording * feat: move community to bottom * feat: showcase * feat: linebreak * feat: redesigned header * feat: tweaks * feat: tweaks * feat: badge for lines of code * feat: move installation instructions to repo readme * feat: readme overhaul number 2 * feat: move visualization to quick start guide * feat: readme 2 electric boogaloo * fix: grammar * fix: formatting * feat: no ugly line * feat: add line back * feat: new load method * feat: split adding accelerator docs out * feat: showcase whisper * feat: smaller tweaks * feat: bring back oneliner
This commit is contained in:
parent
d429553730
commit
e9c1ae3825
|
@ -0,0 +1,4 @@
|
|||
*
|
||||
!*/
|
||||
|
||||
!tinygrad/**
|
244
README.md
244
README.md
|
@ -1,91 +1,57 @@
|
|||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/geohot/tinygrad/master/docs/logo.png">
|
||||
</p>
|
||||
<div align="center">
|
||||
|
||||
--------------------------------------------------------------------
|
||||
[![logo](https://raw.githubusercontent.com/geohot/tinygrad/master/docs/logo.png)](https://tinygrad.org)
|
||||
|
||||
![Unit Tests](https://github.com/geohot/tinygrad/workflows/Unit%20Tests/badge.svg)
|
||||
tinygrad: For something between [PyTorch](https://github.com/pytorch/pytorch) and [karpathy/micrograd](https://github.com/karpathy/micrograd). Maintained by [tiny corp](https://tinygrad.org).
|
||||
|
||||
[![tinygrad discord](https://discordapp.com/api/guilds/1068976834382925865/widget.png?style=banner2)](https://discord.gg/ZjZadyC7PK)
|
||||
<h3>
|
||||
|
||||
For something in between a [pytorch](https://github.com/pytorch/pytorch) and a [karpathy/micrograd](https://github.com/karpathy/micrograd)
|
||||
[Homepage](https://github.com/geohot/tinygrad) | [Documentation](/docs) | [Examples](/examples) | [Showcase](/docs/showcase.md) | [Discord](https://discord.gg/ZjZadyC7PK)
|
||||
|
||||
</h3>
|
||||
|
||||
[![GitHub Repo stars](https://img.shields.io/github/stars/geohot/tinygrad)](https://github.com/geohot/tinygrad/stargazers)
|
||||
[![Unit Tests](https://github.com/geohot/tinygrad/actions/workflows/test.yml/badge.svg)](https://github.com/geohot/tinygrad/actions/workflows/test.yml)
|
||||
[![Discord](https://img.shields.io/discord/1068976834382925865)](https://discord.gg/ZjZadyC7PK)
|
||||
[![Lines of code](https://img.shields.io/tokei/lines/github/geohot/tinygrad)](https://github.com/geohot/tinygrad)
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
This may not be the best deep learning framework, but it is a deep learning framework.
|
||||
|
||||
The sub 1000 line core of it is in `tinygrad/`
|
||||
Due to its extreme simplicity, it aims to be the easiest framework to add new accelerators to, with support for both inference and training.
|
||||
|
||||
Due to its extreme simplicity, it aims to be the easiest framework to add new accelerators to, with support for both inference and training. Support the simple basic ops, and you get SOTA [vision](https://arxiv.org/abs/1905.11946) `models/efficientnet.py` and [language](https://arxiv.org/abs/1706.03762) `models/transformer.py` models.
|
||||
Eventually, we will have a [tinygrad accelerator](https://geohot.github.io/blog/jekyll/update/2021/06/13/a-breakdown-of-ai-chip-companies.html), then tinygrad will be ***fast***. But, for now, it is slow.
|
||||
|
||||
We are working on support for the Apple Neural Engine and the Google TPU in the `accel/` folder. Eventually, [we will build custom hardware](https://geohot.github.io/blog/jekyll/update/2021/06/13/a-breakdown-of-ai-chip-companies.html) for tinygrad, and it will be blindingly fast. Now, it is slow.
|
||||
## Features
|
||||
|
||||
This project is maintained by [tiny corp](https://tinygrad.org/).
|
||||
### LLaMA and Stable Diffusion
|
||||
|
||||
### Installation
|
||||
tinygrad can run [LLaMA](/docs/showcase.md#llama) and [Stable Diffusion](/docs/showcase.md#stable-diffusion)!
|
||||
|
||||
```bash
|
||||
git clone https://github.com/geohot/tinygrad.git
|
||||
cd tinygrad
|
||||
python3 -m pip install -e .
|
||||
```
|
||||
|
||||
### Contributing
|
||||
|
||||
There's a lot of interest in tinygrad lately. Here's some guidelines for contributing:
|
||||
|
||||
* Bugfixes are the best and always welcome! Like [this one](https://github.com/geohot/tinygrad/pull/421/files).
|
||||
* If you don't understand the code you are changing, don't change it!
|
||||
* All code golf PRs will be closed, but [conceptual cleanups](https://github.com/geohot/tinygrad/pull/372/files) are great.
|
||||
* Features are welcome. Though if you are adding a feature, you need to include tests.
|
||||
* Improving test coverage is great, with reliable non brittle tests.
|
||||
|
||||
### Example
|
||||
|
||||
```python
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
x = Tensor.eye(3, requires_grad=True)
|
||||
y = Tensor([[2.0,0,-2.0]], requires_grad=True)
|
||||
z = y.matmul(x).sum()
|
||||
z.backward()
|
||||
|
||||
print(x.grad.numpy()) # dz/dx
|
||||
print(y.grad.numpy()) # dz/dy
|
||||
```
|
||||
|
||||
### Same example in torch
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
x = torch.eye(3, requires_grad=True)
|
||||
y = torch.tensor([[2.0,0,-2.0]], requires_grad=True)
|
||||
z = y.matmul(x).sum()
|
||||
z.backward()
|
||||
|
||||
print(x.grad) # dz/dx
|
||||
print(y.grad) # dz/dy
|
||||
```
|
||||
|
||||
## Is tinygrad fast?
|
||||
### Laziness
|
||||
|
||||
Try a matmul. See how, despite the style, it is fused into one kernel with the power of laziness.
|
||||
|
||||
```python
|
||||
```sh
|
||||
DEBUG=3 OPTLOCAL=1 python3 -c "from tinygrad.tensor import Tensor;
|
||||
N = 1024; a, b = Tensor.randn(N, N), Tensor.randn(N, N);
|
||||
c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).sum(axis=2);
|
||||
print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
|
||||
```
|
||||
|
||||
Change to `DEBUG=4` to see the generated code.
|
||||
And we can change `DEBUG` to `4` to see the generated code.
|
||||
|
||||
## Neural networks?
|
||||
### Neural networks
|
||||
|
||||
It turns out, a decent autograd tensor library is 90% of what you need for neural networks. Add an optimizer (SGD, Adam, AdamW implemented) from tinygrad.nn.optim, write some boilerplate minibatching code, and you have all you need.
|
||||
As it turns out, 90% of what you need for neural networks are a decent autograd/tensor library.
|
||||
Throw in an optimizer, a data loader, and some compute, and you have all you need.
|
||||
|
||||
### Neural network example (from test/models/test_mnist.py)
|
||||
#### Neural network example (from test/models/test_mnist.py)
|
||||
|
||||
```python
|
||||
```py
|
||||
from tinygrad.tensor import Tensor
|
||||
import tinygrad.nn.optim as optim
|
||||
|
||||
|
@ -100,7 +66,7 @@ class TinyBobNet:
|
|||
model = TinyBobNet()
|
||||
optim = optim.SGD([model.l1, model.l2], lr=0.001)
|
||||
|
||||
# ... and complete like pytorch, with (x,y) data
|
||||
# ... complete data loader here
|
||||
|
||||
out = model.forward(x)
|
||||
loss = out.mul(y).mean()
|
||||
|
@ -109,114 +75,86 @@ loss.backward()
|
|||
optim.step()
|
||||
```
|
||||
|
||||
## GPU and Accelerator Support
|
||||
## Accelerators
|
||||
|
||||
tinygrad supports GPUs through PyOpenCL.
|
||||
tinygrad already supports numerous accelerators, including:
|
||||
|
||||
```python
|
||||
- [x] CPU
|
||||
- [x] GPU (OpenCL)
|
||||
- [x] C Code (Clang)
|
||||
- [x] LLVM
|
||||
- [x] METAL
|
||||
- [x] CUDA
|
||||
- [x] Triton
|
||||
- [x] PyTorch
|
||||
|
||||
And it is easy to add more! Your accelerator of choice only needs to support a total of 20 (optionally 21) low level ops.
|
||||
More information can be found in the [documentation for adding new accelerators](/docs/adding_new_accelerators.md).
|
||||
|
||||
## Installation
|
||||
|
||||
The current recommended way to install tinygrad is from source.
|
||||
|
||||
### From source
|
||||
|
||||
```sh
|
||||
git clone https://github.com/geohot/tinygrad.git
|
||||
cd tinygrad
|
||||
python3 -m pip install -e . # or `py3 -m pip install -e .` if you are on windows
|
||||
```
|
||||
Don't forget the `.` at the end!
|
||||
|
||||
## Documentation
|
||||
|
||||
Documentation along with a quick start guide can be found in the [docs/](/docs) directory.
|
||||
|
||||
### Quick example comparing to PyTorch
|
||||
|
||||
```py
|
||||
from tinygrad.tensor import Tensor
|
||||
(Tensor.ones(4,4).gpu() + Tensor.ones(4,4).gpu()).cpu()
|
||||
|
||||
x = Tensor.eye(3, requires_grad=True)
|
||||
y = Tensor([[2.0,0,-2.0]], requires_grad=True)
|
||||
z = y.matmul(x).sum()
|
||||
z.backward()
|
||||
|
||||
print(x.grad.numpy()) # dz/dx
|
||||
print(y.grad.numpy()) # dz/dy
|
||||
```
|
||||
|
||||
### hlops (in tensor.py)
|
||||
The same thing but in PyTorch:
|
||||
```py
|
||||
import torch
|
||||
|
||||
hlops are syntactic sugar around mlops. They support most things torch does.
|
||||
x = torch.eye(3, requires_grad=True)
|
||||
y = torch.tensor([[2.0,0,-2.0]], requires_grad=True)
|
||||
z = y.matmul(x).sum()
|
||||
z.backward()
|
||||
|
||||
### mlops
|
||||
|
||||
mlops are mid level ops. They understand derivatives. They are very simple.
|
||||
|
||||
```
|
||||
Relu, Log, Exp, Sin # unary ops
|
||||
Sum, Max # reduce ops (with axis argument)
|
||||
Maximum, Add, Sub, Mul, Pow, Div, Equal # binary ops (no broadcasting, use expand)
|
||||
Expand, Reshape, Permute, Pad, Shrink, Flip # movement ops
|
||||
print(x.grad.numpy()) # dz/dx
|
||||
print(y.grad.numpy()) # dz/dy
|
||||
```
|
||||
|
||||
You no longer need to write mlops for a new accelerator
|
||||
## Contributing
|
||||
|
||||
### Adding an accelerator (llops)
|
||||
There has been a lot of interest in tinygrad lately. Here are some basic guidelines for contributing:
|
||||
|
||||
The autodiff stuff is all in mlops now so you can focus on the raw operations
|
||||
- Bug fixes are the best and always welcome! Like [this one](https://github.com/geohot/tinygrad/pull/421/files).
|
||||
- If you don't understand the code you are changing, don't change it!
|
||||
- All code golf PRs will be closed, but [conceptual cleanups](https://github.com/geohot/tinygrad/pull/372/files) are great.
|
||||
- Features are welcome. Though if you are adding a feature, you need to include tests.
|
||||
- Improving test coverage is great, with reliable non-brittle tests.
|
||||
|
||||
```
|
||||
Buffer # class of memory on this device
|
||||
unary_op (NOOP, EXP, LOG, CAST, SIN) # A -> A
|
||||
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
|
||||
binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size)
|
||||
movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE) # A -> B (different size)
|
||||
fused_op [[optional]] (MULACC) # A * A -> B
|
||||
```
|
||||
|
||||
## ImageNet inference
|
||||
|
||||
Despite being tiny, tinygrad supports the full EfficientNet. Pass in a picture to discover what it is.
|
||||
|
||||
```bash
|
||||
python3 examples/efficientnet.py https://media.istockphoto.com/photos/hen-picture-id831791190
|
||||
```
|
||||
|
||||
Or, if you have a webcam and cv2 installed
|
||||
|
||||
```bash
|
||||
python3 examples/efficientnet.py webcam
|
||||
```
|
||||
|
||||
PROTIP: Set "DEBUG=2" environment variable if you want to see why it's slow.
|
||||
|
||||
### tinygrad supports Stable Diffusion!
|
||||
|
||||
You might need to download the [weight](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt) of Stable Diffusion and put it into weights/
|
||||
|
||||
Run `python3 examples/stable_diffusion.py`
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/geohot/tinygrad/master/docs/stable_diffusion_by_tinygrad.jpg">
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
"a horse sized cat eating a bagel"
|
||||
</p>
|
||||
|
||||
### tinygrad supports LLaMA
|
||||
|
||||
After putting the weights in weights/LLaMA, you can have a chat with Stacy. She lives inside tinygrad.
|
||||
|
||||
```bash
|
||||
python3 examples/llama.py
|
||||
```
|
||||
|
||||
### tinygrad supports GANs
|
||||
|
||||
See `examples/mnist_gan.py`
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/geohot/tinygrad/master/docs/mnist_by_tinygrad.jpg">
|
||||
</p>
|
||||
|
||||
### tinygrad supports yolo
|
||||
|
||||
See `examples/yolov3.py`
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/geohot/tinygrad/master/docs/yolo_by_tinygrad.jpg">
|
||||
</p>
|
||||
|
||||
### Drawing Execution Graph
|
||||
|
||||
```bash
|
||||
GRAPH=1 python3 test/models/test_mnist.py TestMNIST.test_sgd_onestep
|
||||
# requires dot, outputs /tmp/net.svg
|
||||
```
|
||||
Additional guidelines can be found in [CONTRIBUTING.md](/CONTRIBUTING.md).
|
||||
|
||||
### Running tests
|
||||
|
||||
For more examples on how to run the full test suite please refer to the [CI workflow](.github/workflows/test.yml).
|
||||
|
||||
```bash
|
||||
Some examples:
|
||||
```sh
|
||||
python3 -m pip install -e '.[testing]'
|
||||
python3 -m pytest
|
||||
python3 -m pytest -v -k TestTrain
|
||||
python3 ./test/models/test_train.py TestTrain.test_efficientnet
|
||||
```
|
||||
|
||||
|
|
126
docs/README.md
126
docs/README.md
|
@ -1,125 +1,37 @@
|
|||
### Welcome to the tinygrad documentation
|
||||
# Welcome to the tinygrad documentation!
|
||||
|
||||
General instructions you will find in [README.md](https://github.com/geohot/tinygrad/blob/master/README.md)
|
||||
Here you will find documentation for tinygrad, as well as some examples and tutorials.
|
||||
|
||||
[abstraction.py](https://github.com/geohot/tinygrad/blob/master/docs/abstractions.py) is a well documented showcase of the abstraction stack.
|
||||
## Getting Started
|
||||
|
||||
There are plenty of [tests](https://github.com/geohot/tinygrad/tree/master/test) you can read through
|
||||
[Examples](https://github.com/geohot/tinygrad/tree/master/examples) contains tinygrad implementations of popular models (vision and language) and neural networks. LLama, Stable diffusion, GANs and Yolo to name a few
|
||||
Read the quick start guide [here](/docs/quickstart.md).
|
||||
|
||||
### Environment variables
|
||||
Here is a list of environment variables you can use with tinygrad.
|
||||
Most of these are self-explanatory, and used to enable an option at runtime.
|
||||
Example : `GPU=1 DEBUG=4 python3 -m pytest`
|
||||
Or if you want to jump right in to how tinygrad works, you can read the [abstraction stack](/docs/abstractions.py) documentation.
|
||||
|
||||
The columns are: Variable, Value and Description
|
||||
They are also grouped into either general tinygrad or specific files
|
||||
Or if you want to see some examples, you can look at the examples in the [examples](/examples) directory.
|
||||
|
||||
##### General tinygrad
|
||||
DEBUG: [1-4], enable debugging output, with 4 you get operations, timings, speed, generated code and more
|
||||
GPU: [1], enable the GPU backend
|
||||
CPU: [1], enable CPU backend
|
||||
MPS: [1], emable MPS device (for Mac M1 and after)
|
||||
METAL: [1], enable Metal backend (for Mac M1 and after)
|
||||
METAL_XCODE: [1], enable Metal using MacOS Xcode sdk
|
||||
TORCH: [1], enable Torch backend
|
||||
CLANG: [1], enable Clang backend
|
||||
LLVM: [1], enable LLVM backend
|
||||
LLVMOPT: [1], enable LLVM optimization
|
||||
LAZY: [1], enable lazy operations
|
||||
OPT: [1-4], enable optimization
|
||||
OPTLOCAL: [1], enable local optimization
|
||||
JIT: [1], enable Jit
|
||||
GRAPH: [1], Create a graph of all operations
|
||||
GRAPHPATH: [/path/to], what path to generate the graph image
|
||||
PRUNEGRAPH, [1], prune movementops and loadops from the graph
|
||||
PRINT_PRG: [1], print program
|
||||
FLOAT16: [1], use float16 instead of float32
|
||||
ENABLE_METHOD_CACHE: [1], enable method cache
|
||||
EARLY_STOPPING: [1], stop early
|
||||
DISALLOW_ASSIGN: [1], enable not assigning the realized lazydata to the lazy output buffer
|
||||
Or if you just want to see some of the things tinygrad can do, check out the [showcase](/docs/showcase.md).
|
||||
|
||||
##### tinygrad/codegen/cstyle.py
|
||||
NATIVE_EXPLOG: [1], enable using native explog
|
||||
## API
|
||||
|
||||
##### accel/ane/2_compile/hwx_parse.py
|
||||
PRINTALL: [1], print all ane registers
|
||||
This is currently a big work in progress.
|
||||
|
||||
##### extra/onnx.py
|
||||
ONNXLIMIT: [ ], set a limit for Onnx
|
||||
DEBUGONNX: [1], enable Onnx debugging
|
||||
## Resources
|
||||
|
||||
##### extra/thneed.py
|
||||
DEBUGCL: [1-4], enable Debugging for OpenCL
|
||||
PRINT_KERNEL: [1], Print OpenCL Kernels
|
||||
### Environment Variables
|
||||
|
||||
##### extra/kernel_search.py
|
||||
OP: [1-3], different operations
|
||||
NOTEST: [1], enable not testing ast
|
||||
DUMP: [1], enable dumping of intervention cache
|
||||
REDUCE: [1], enable reduce operations
|
||||
SIMPLE_REDUCE: [1], enable simpler reduce operations
|
||||
BC: [1], enable big conv operations
|
||||
CONVW: [1], enable convw operations
|
||||
FASTCONV: [1], enable faster conv operations
|
||||
GEMM: [1], enable general matrix multiply operations
|
||||
BROKEN: [1], enable a kind of operation
|
||||
BROKEN3: [1], enable a kind of operation
|
||||
[env_vars.md](/docs/env_vars.md)
|
||||
|
||||
##### examples/vit.py
|
||||
LARGE: [1], enable larger dimension model
|
||||
### Adding New Accelerators
|
||||
|
||||
##### examples/llama.py
|
||||
WEIGHTS: [1], enable using weights
|
||||
[adding_new_accelerators.md](/docs/adding_new_accelerators.md)
|
||||
|
||||
##### examples/mlperf
|
||||
MODEL: [resnet,retinanet,unet3d,rnnt,bert,maskrcnn], what models to use
|
||||
### Community
|
||||
|
||||
##### examples/benchmark_train_efficientnet.py
|
||||
CNT: [10], the amount of times to loop the benchmark
|
||||
BACKWARD: [1], enable backward call
|
||||
TRAINING: [1], set Tensor.training
|
||||
CLCACHE: [1], enable Cache for OpenCL
|
||||
[![tinygrad discord](https://discordapp.com/api/guilds/1068976834382925865/widget.png?style=banner2)](https://discord.gg/ZjZadyC7PK)
|
||||
|
||||
##### examples/hlb_cifar10.py
|
||||
TORCHWEIGHTS: [1], use torch to initialize weights
|
||||
DISABLE_BACKWARD: [1], dont use backward operations
|
||||
## Contributing
|
||||
|
||||
##### examples/benchmark_train_efficientnet.py & examples/hlb_cifar10.py
|
||||
ADAM: [1], enable Adam optimization
|
||||
|
||||
##### examples/hlb_cifar10.py & xamples/hlb_cifar10_torch.py
|
||||
STEPS: [0-10], number of steps
|
||||
FAKEDATA: [1], enable to use random data
|
||||
|
||||
##### examples/train_efficientnet.py
|
||||
STEPS: [1024 dividable], number of steps
|
||||
TINY: [1], use a tiny convolution network
|
||||
IMAGENET: [1], use imagenet for training
|
||||
|
||||
##### examples/train_efficientnet.py & examples/train_resnet.py
|
||||
TRANSFER: [1], enable to use pretrained data
|
||||
|
||||
##### examples & test/external/external_test_opt.py
|
||||
NUM: [18, 2], what ResNet[18] / EfficientNet[2] to train
|
||||
|
||||
##### test/test_ops.py
|
||||
PRINT_TENSORS: [1], print tensors
|
||||
FORWARD_ONLY: [1], use forward operations only
|
||||
|
||||
##### test/test_speed_v_torch.py
|
||||
TORCHCUDA: [1], enable the torch cuda backend
|
||||
|
||||
##### test/external/external_test_gpu_ast.py
|
||||
KOPT: [1], enable kernel optimization
|
||||
KCACHE: [1], enable kernel cache
|
||||
|
||||
##### test/external/external_test_opt.py
|
||||
ENET_NUM: [-2,-1], what EfficientNet to use
|
||||
|
||||
##### test/test_dtype.py & test/extra/test_utils.py & extra/training.py
|
||||
CI: [1], enable to avoid some tests to run in CI
|
||||
|
||||
##### examples & extra & test
|
||||
BS: [8, 16, 32, 64, 128], bytesize
|
||||
The documentation mainly follows the core contributing guidelines in the [README.md](/README.md#contributing).
|
||||
|
||||
Additionally, we always welcome documentation contributions, especially for features that are currently under documented.
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# Adding a new accelerator to tinygrad
|
||||
|
||||
It's pretty easy to add a new accelerator to tinygrad. All you need to do is implement a total of 20 (optionally 21) low level ops. Then tinygrad takes care of the rest, handling derivatives and syntactic sugar.
|
||||
|
||||
## llops
|
||||
|
||||
These are the ops that you must implement for your accelerator of choice.
|
||||
```
|
||||
Buffer # class of memory on this device
|
||||
unary_op (NOOP, EXP, LOG, CAST, SIN) # A -> A
|
||||
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
|
||||
binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size)
|
||||
movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE) # A -> B (different size)
|
||||
fused_op [[optional]] (MULACC) # A * A -> B
|
||||
```
|
||||
|
||||
## mlops
|
||||
|
||||
These are the mid level ops that handle the derivatives.
|
||||
```
|
||||
Relu, Log, Exp, Sin # unary ops
|
||||
Sum, Max # reduce ops (with axis argument)
|
||||
Maximum, Add, Sub, Mul, Pow, Div, Equal # binary ops (no broadcasting, use expand)
|
||||
Expand, Reshape, Permute, Pad, Shrink, Flip # movement ops
|
||||
```
|
||||
These are implemented in [mlops.py](/tinygrad/mlops.py).
|
||||
|
||||
## hlops
|
||||
|
||||
These are the syntax sugar. They are built on top of the mlops and support most of the things that you could expect from a tensor library.
|
||||
|
||||
These are implemented in [tensor.py](/tinygrad/tensor.py).
|
|
@ -0,0 +1,186 @@
|
|||
# List of environment variables that control tinygrad behavior.
|
||||
|
||||
This is a list of environment variable that control the runtime behavior of tinygrad and its examples.
|
||||
Most of these are self-explanatory, and are usually used to set an option at runtime.
|
||||
|
||||
Example: `GPU=1 DEBUG=4 python3 -m pytest`
|
||||
|
||||
The columns are: Variable, Possible Value(s) and Description.
|
||||
|
||||
- A `#` means that the variable can take any integer value.
|
||||
|
||||
## Global Variables
|
||||
|
||||
These control the behavior of core tinygrad even when used as a library.
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
DEBUG | [1-4] | enable debugging output, with 4 you get operations, timings, speed, generated code and more
|
||||
GPU | [1] | enable the GPU backend
|
||||
CPU | [1] | enable CPU backend
|
||||
MPS | [1] | enable MPS device (for Mac M1 and after)
|
||||
METAL | [1] | enable Metal backend (for Mac M1 and after)
|
||||
METAL_XCODE | [1] | enable Metal using macOS Xcode SDK
|
||||
TORCH | [1] | enable PyTorch backend
|
||||
CLANG | [1] | enable Clang backend
|
||||
LLVM | [1] | enable LLVM backend
|
||||
LLVMOPT | [1] | enable slightly more expensive LLVM optimizations
|
||||
LAZY | [1] | enable lazy operations (this is the default)
|
||||
OPT | [1-4] | optimization level
|
||||
OPTLOCAL | [1-2] | enable local optimization
|
||||
GRAPH | [1] | create a graph of all operations (requires graphviz)
|
||||
GRAPHPATH | [/path/to] | where to put the generated graph
|
||||
PRUNEGRAPH | [1] | prune MovementOps and LoadOps from the graph
|
||||
PRINT_PRG | [1] | print program code
|
||||
IMAGE | [1] | enable 2d specific optimizations
|
||||
FLOAT16 | [1] | use float16 for images instead of float32
|
||||
ENABLE_METHOD_CACHE | [1] | enable method cache (this is the default)
|
||||
EARLY_STOPPING | [# > 0] | stop after this many kernels
|
||||
DISALLOW_ASSIGN | [1] | disallow assignment of tensors
|
||||
NATIVE_EXPLOG | [1] | enable using native exp and log
|
||||
|
||||
## File Specific Variables
|
||||
|
||||
These are variables that control the behavior of a specific file, these usually don't affect the library itself.
|
||||
Most of the time these will never be used, but they are here for completeness.
|
||||
|
||||
### accel/ane/2_compile/hwx_parse.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
PRINTALL | [1] | print all ANE registers
|
||||
|
||||
### extra/onnx.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
ONNXLIMIT | [#] | set a limit for ONNX
|
||||
DEBUGONNX | [1] | enable ONNX debugging
|
||||
|
||||
### extra/thneed.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
DEBUGCL | [1-4] | enable Debugging for OpenCL
|
||||
PRINT_KERNEL | [1] | Print OpenCL Kernels
|
||||
|
||||
### extra/kernel_search.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
OP | [1-3] | different operations
|
||||
NOTEST | [1] | enable not testing AST
|
||||
DUMP | [1] | enable dumping of intervention cache
|
||||
REDUCE | [1] | enable reduce operations
|
||||
SIMPLE_REDUCE | [1] | enable simpler reduce operations
|
||||
BC | [1] | enable big conv operations
|
||||
CONVW | [1] | enable convw operations
|
||||
FASTCONV | [1] | enable faster conv operations
|
||||
GEMM | [1] | enable general matrix multiply operations
|
||||
BROKEN | [1] | enable a kind of operation
|
||||
BROKEN3 | [1] | enable a kind of operation
|
||||
|
||||
### examples/vit.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
LARGE | [1] | enable larger dimension model
|
||||
|
||||
### examples/llama.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
WEIGHTS | [1] | enable loading weights
|
||||
|
||||
### examples/mlperf
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
MODEL | [resnet,retinanet,unet3d,rnnt,bert,maskrcnn] | what models to use
|
||||
|
||||
### examples/benchmark_train_efficientnet.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
CNT | [10] | the amount of times to loop the benchmark
|
||||
BACKWARD | [1] | enable backward pass
|
||||
TRAINING | [1] | set Tensor.training
|
||||
CLCACHE | [1] | enable cache for OpenCL
|
||||
|
||||
### examples/hlb_cifar10.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
TORCHWEIGHTS | [1] | use torch to initialize weights
|
||||
DISABLE_BACKWARD | [1] | don't do backward pass
|
||||
|
||||
### examples/benchmark_train_efficientnet.py & examples/hlb_cifar10.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
ADAM | [1] | use the Adam optimizer
|
||||
|
||||
### examples/hlb_cifar10.py & xamples/hlb_cifar10_torch.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
STEPS | [0-10] | number of steps
|
||||
FAKEDATA | [1] | enable to use random data
|
||||
|
||||
### examples/train_efficientnet.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
STEPS | [# % 1024] | number of steps
|
||||
TINY | [1] | use a tiny convolution network
|
||||
IMAGENET | [1] | use imagenet for training
|
||||
|
||||
### examples/train_efficientnet.py & examples/train_resnet.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
TRANSFER | [1] | enable to use pretrained data
|
||||
|
||||
### examples & test/external/external_test_opt.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
NUM | [18, 2] | what ResNet[18] / EfficientNet[2] to train
|
||||
|
||||
### test/test_ops.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
PRINT_TENSORS | [1] | print tensors
|
||||
FORWARD_ONLY | [1] | use forward operations only
|
||||
|
||||
### test/test_speed_v_torch.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
TORCHCUDA | [1] | enable the torch cuda backend
|
||||
|
||||
### test/external/external_test_gpu_ast.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
KOPT | [1] | enable kernel optimization
|
||||
KCACHE | [1] | enable kernel cache
|
||||
|
||||
### test/external/external_test_opt.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
ENET_NUM | [-2,-1] | what EfficientNet to use
|
||||
|
||||
### test/test_dtype.py & test/extra/test_utils.py & extra/training.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
CI | [1] | disables some tests for CI
|
||||
|
||||
### examples & extra & test
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
BS | [8, 16, 32, 64, 128] | batch size to use
|
|
@ -0,0 +1,300 @@
|
|||
# tinygrad Quick Start Guide
|
||||
|
||||
This guide assumes no prior knowledge of pytorch or any other deep learning framework, but does assume some basic knowledge of neural networks.
|
||||
It is intended to be a very quick overview of the high level API that tinygrad provides.
|
||||
|
||||
This guide is also structured as a tutorial which at the end of it you will have a working model that can classify handwritten digits.
|
||||
|
||||
We need some imports to get started:
|
||||
```py
|
||||
import numpy as np
|
||||
import time
|
||||
```
|
||||
|
||||
## Tensors
|
||||
|
||||
Tensors are the base data structure in tinygrad. They can be thought of as a multidimensional array of a specific data type.
|
||||
All high level operations in tinygrad operate on these tensors.
|
||||
|
||||
The tensor class can be imported like so:
|
||||
```py
|
||||
from tinygrad.tensor import Tensor
|
||||
```
|
||||
|
||||
Tensors can be created from an existing data structure like a python list or numpy ndarray:
|
||||
```py
|
||||
t1 = Tensor([1, 2, 3, 4, 5])
|
||||
na = np.array([1, 2, 3, 4, 5])
|
||||
t2 = Tensor(na)
|
||||
```
|
||||
|
||||
Tensors can also be created using one of the many factory methods:
|
||||
```py
|
||||
full = Tensor.full(shape=(2, 3), fill_value=5) # create a tensor of shape (2, 3) filled with 5
|
||||
zeros = Tensor.zeros(2, 3) # create a tensor of shape (2, 3) filled with 0
|
||||
ones = Tensor.ones(2, 3) # create a tensor of shape (2, 3) filled with 1
|
||||
|
||||
full_like = Tensor.full_like(full, fill_value=2) # create a tensor of the same shape as `full` filled with 2
|
||||
zeros_like = Tensor.zeros_like(full) # create a tensor of the same shape as `full` filled with 0
|
||||
ones_like = Tensor.ones_like(full) # create a tensor of the same shape as `full` filled with 1
|
||||
|
||||
eye = Tensor.eye(3) # create a 3x3 identity matrix
|
||||
arange = Tensor.arange(start=0, stop=10, step=1) # create a tensor of shape (10,) filled with values from 0 to 9
|
||||
|
||||
rand = Tensor.rand(2, 3) # create a tensor of shape (2, 3) filled with random values from a uniform distribution
|
||||
randn = Tensor.randn(2, 3) # create a tensor of shape (2, 3) filled with random values from a normal distribution
|
||||
uniform = Tensor.uniform(2, 3, low=0, high=10) # create a tensor of shape (2, 3) filled with random values from a uniform distribution between 0 and 10
|
||||
```
|
||||
There are even more of these factory methods, you can find them in the [tensor.py](/tinygrad/tensor.py) file.
|
||||
|
||||
All the tensors creation methods can take a `dtype` argument to specify the data type of the tensor.
|
||||
```py
|
||||
from tinygrad.helpers import dtypes
|
||||
|
||||
t3 = Tensor([1, 2, 3, 4, 5], dtype=dtypes.int32)
|
||||
```
|
||||
|
||||
Tensors allow you to perform operations on them like so:
|
||||
```py
|
||||
t4 = Tensor([1, 2, 3, 4, 5])
|
||||
t5 = (t4 + 1) * 2
|
||||
t6 = (t5 * t4).relu().log_softmax()
|
||||
```
|
||||
|
||||
All of these operations are lazy and are only executed when you realize the tensor using `.realize()` or `.numpy()`.
|
||||
```py
|
||||
print(t6.numpy())
|
||||
# [-56. -48. -36. -20. 0.]
|
||||
```
|
||||
|
||||
There are a lot more operations that can be performed on tensors, you can find them in the [tensor.py](/tinygrad/tensor.py) file.
|
||||
Additionally reading through [abstractions.py](/docs/abstractions.py) will help you understand how operations on these tensors make their way down to your hardware.
|
||||
|
||||
## Models
|
||||
|
||||
Neural networks in tinygrad are really just represented by the operations performed on tensors.
|
||||
These operations are commonly grouped into the `__call__` method of a class which allows modularization and reuse of these groups of operations.
|
||||
These classes do not need to inherit from any base class, in fact if they don't need any trainable parameters they don't even need to be a class!
|
||||
|
||||
An example of this would be the `nn.Linear` class which represents a linear layer in a neural network.
|
||||
```py
|
||||
# from tinygrad.nn import Linear
|
||||
class Linear:
|
||||
def __init__(self, in_features, out_features, bias=True, initialization: str='kaiming_uniform'):
|
||||
self.weight = getattr(Tensor, initialization)(out_features, in_features)
|
||||
self.bias = Tensor.zeros(out_features) if bias else None
|
||||
|
||||
def __call__(self, x):
|
||||
return x.linear(self.weight.transpose(), self.bias)
|
||||
```
|
||||
There are more neural network modules already implemented in [nn](/tinygrad/nn/__init__.py), and you can also implement your own.
|
||||
|
||||
We will be implementing a simple neural network that can classify handwritten digits from the MNIST dataset.
|
||||
Our classifier will be a simple 2 layer neural network with a Leaky ReLU activation function.
|
||||
It will use a hidden layer size of 128 and an output layer size of 10 (one for each digit) with no bias on either Linear layer.
|
||||
```py
|
||||
from tinygrad.nn import Linear
|
||||
|
||||
class TinyNet:
|
||||
def __init__(self):
|
||||
self.l1 = Linear(784, 128, bias=False)
|
||||
self.l2 = Linear(128, 10, bias=False)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.l1(x)
|
||||
x = x.leakyrelu()
|
||||
x = self.l2(x)
|
||||
return x.log_softmax()
|
||||
|
||||
net = TinyNet()
|
||||
```
|
||||
We can see that the forward pass of our neural network is just the sequence of operations performed on the input tensor `x`.
|
||||
We can also see that functional operations like `leakyrelu` and `log_softmax` are not defined as classes and instead are just methods we can just call.
|
||||
Finally, we just initialize an instance of our neural network, and we are ready to start training it.
|
||||
|
||||
## Training
|
||||
|
||||
Now that we have our neural network defined we can start training it.
|
||||
Training neural networks in tinygrad is super simple.
|
||||
All we need to do is define our neural network, define our loss function, and then call `.backward()` on the loss function to compute the gradients.
|
||||
They can then be used to update the parameters of our neural network using one of the many optimizers in [optim.py](/tinygrad/nn/optim.py).
|
||||
|
||||
First we need to set the training flag in `Tensor`:
|
||||
```py
|
||||
Tensor.training = True
|
||||
```
|
||||
|
||||
For our loss function we will be using cross entropy loss.
|
||||
```py
|
||||
# from extra.training import sparse_categorical_crossentropy
|
||||
def cross_entropy(out, Y):
|
||||
num_classes = out.shape[-1]
|
||||
YY = Y.flatten().astype(np.int32)
|
||||
y = np.zeros((YY.shape[0], num_classes), np.float32)
|
||||
y[range(y.shape[0]),YY] = -1.0*num_classes
|
||||
y = y.reshape(list(Y.shape)+[num_classes])
|
||||
y = Tensor(y)
|
||||
return out.mul(y).mean()
|
||||
```
|
||||
As we can see in this implementation of cross entropy loss, there are certain operations that tinygrad does not support.
|
||||
Namely, operations that are load/store like indexing a tensor with another tensor or assigning a value to a tensor at a certain index.
|
||||
Load/store ops are not supported in tinygrad because they add complexity when trying to port to different backends and 90% of the models out there don't use/need them.
|
||||
|
||||
For our optimizer we will be using the traditional stochastic gradient descent optimizer with a learning rate of 3e-4.
|
||||
```py
|
||||
from tinygrad.nn.optim import SGD
|
||||
|
||||
opt = SGD([net.l1.weight, net.l2.weight], lr=3e-4)
|
||||
```
|
||||
We can see that we are passing in the parameters of our neural network to the optimizer.
|
||||
This is due to the fact that the optimizer needs to know which parameters to update.
|
||||
There is a simpler way to do this just by using `get_parameters(net)` from `tinygrad.nn.optim` which will return a list of all the parameters in the neural network.
|
||||
The parameters are just listed out explicitly here for clarity.
|
||||
|
||||
Now that we have our network, loss function, and optimizer defined all we are missing is the data to train on!
|
||||
There are a couple of dataset loaders in tinygrad located in [/datasets](/datasets).
|
||||
We will be using the MNIST dataset loader.
|
||||
```py
|
||||
from datasets import fetch_mnist
|
||||
```
|
||||
|
||||
Now we have everything we need to start training our neural network.
|
||||
We will be training for 1000 steps with a batch size of 64.
|
||||
```py
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
|
||||
for step in range(1000):
|
||||
# random sample a batch
|
||||
samp = np.random.randint(0, X_train.shape[0], size=(64))
|
||||
batch = Tensor(X_train[samp], requires_grad=False)
|
||||
# get the corresponding labels
|
||||
labels = Y_train[samp]
|
||||
|
||||
# forward pass
|
||||
out = net(batch)
|
||||
|
||||
# compute loss
|
||||
loss = cross_entropy(out, labels)
|
||||
|
||||
# zero gradients
|
||||
opt.zero_grad()
|
||||
|
||||
# backward pass
|
||||
loss.backward()
|
||||
|
||||
# update parameters
|
||||
opt.step()
|
||||
|
||||
# calculate accuracy
|
||||
pred = np.argmax(out.numpy(), axis=-1)
|
||||
acc = (pred == labels).mean()
|
||||
|
||||
if step % 100 == 0:
|
||||
print(f"Step {step+1} | Loss: {loss.numpy()} | Accuracy: {acc}")
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
Now that we have trained our neural network we can evaluate it on the test set.
|
||||
We will be using the same batch size of 64 and will be evaluating for 1000 of those batches.
|
||||
```py
|
||||
# set training flag to false
|
||||
Tensor.training = False
|
||||
|
||||
st = time.perf_counter()
|
||||
avg_acc = 0
|
||||
for step in range(1000):
|
||||
# random sample a batch
|
||||
samp = np.random.randint(0, X_test.shape[0], size=(64))
|
||||
batch = Tensor(X_test[samp], requires_grad=False)
|
||||
# get the corresponding labels
|
||||
labels = Y_test[samp]
|
||||
|
||||
# forward pass
|
||||
out = net(batch)
|
||||
|
||||
# calculate accuracy
|
||||
pred = np.argmax(out.numpy(), axis=-1)
|
||||
avg_acc += (pred == labels).mean()
|
||||
print(f"Test Accuracy: {avg_acc / 1000}")
|
||||
print(f"Time: {time.perf_counter() - st}")
|
||||
```
|
||||
|
||||
## And that's it!
|
||||
|
||||
Highly recommend you check out the [examples/](/examples) folder for more examples of using tinygrad.
|
||||
Reading the source code of tinygrad is also a great way to learn how it works.
|
||||
Specifically the tests in [tests/](/tests) are a great place to see how to use and the semantics of the different operations.
|
||||
There are also a bunch of models implemented in [models/](/models) that you can use as a reference.
|
||||
|
||||
Additionally, feel free to ask questions in the `#learn-tinygrad` channel on the [discord](https://discord.gg/beYbxwxVdx). Don't ask to ask, just ask!
|
||||
|
||||
## Extras
|
||||
|
||||
### JIT
|
||||
|
||||
Additionally, it is possible to speed up the computation of certain neural networks by using the JIT.
|
||||
Currently, this does not support models with varying input sizes and non tinygrad operations.
|
||||
|
||||
To use the JIT we just need to add a function decorator to the forward pass of our neural network and ensure that the input and output are realized tensors.
|
||||
Or in this case we will create a wrapper function and decorate the wrapper function to speed up the evaluation of our neural network.
|
||||
```py
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
@TinyJit
|
||||
def jit(x):
|
||||
return net(x).realize()
|
||||
|
||||
st = time.perf_counter()
|
||||
avg_acc = 0
|
||||
for step in range(1000):
|
||||
# random sample a batch
|
||||
samp = np.random.randint(0, X_test.shape[0], size=(64))
|
||||
batch = Tensor(X_test[samp], requires_grad=False)
|
||||
# get the corresponding labels
|
||||
labels = Y_test[samp]
|
||||
|
||||
# forward pass with jit
|
||||
out = jit(batch)
|
||||
|
||||
# calculate accuracy
|
||||
pred = np.argmax(out.numpy(), axis=-1)
|
||||
avg_acc += (pred == labels).mean()
|
||||
print(f"Test Accuracy: {avg_acc / 1000}")
|
||||
print(f"Time: {time.perf_counter() - st}")
|
||||
```
|
||||
You will find that the evaluation time is much faster than before and that your accelerator utilization is much higher.
|
||||
|
||||
### Saving and Loading Models
|
||||
|
||||
The standard weight format for tinygrad is [safetensors](https://github.com/huggingface/safetensors). This means that you can load the weights of any model also using safetensors into tinygrad.
|
||||
There are functions in [state.py](/tinygrad/state.py) to save and load models to and from this format.
|
||||
```py
|
||||
from tinygrad.state import safe_save, safe_load, get_state_dict, load_state_dict
|
||||
|
||||
# first we need the state dict of our model
|
||||
state_dict = get_state_dict(net)
|
||||
|
||||
# then we can just save it to a file
|
||||
safe_save(state_dict, "model.safetensors")
|
||||
|
||||
# and load it back in
|
||||
state_dict = safe_load("model.safetensors")
|
||||
load_state_dict(net, state_dict)
|
||||
```
|
||||
|
||||
Many of the models in the [models/](/models) folder have a `load_from_pretrained` method that will download and load the weights for you. These usually are pytorch weights meaning that you would need pytorch installed to load them.
|
||||
|
||||
### Environment Variables
|
||||
|
||||
There exist a bunch of environment variables that control the runtime behavior of tinygrad.
|
||||
Some of the commons ones are `DEBUG` and the different backend enablement variables.
|
||||
|
||||
You can find a full list and their descriptions in [env_vars.md](/docs/env_vars.md).
|
||||
|
||||
### Visualizing the Computation Graph
|
||||
|
||||
It is possible to visualize the computation graph of a neural network using [graphviz](https://graphviz.org/).
|
||||
|
||||
This is easily done by running a single pass (forward or backward!) of the neural network with the environment variable `GRAPH` set to `1`.
|
||||
The graph will be saved to `/tmp/net.svg` by default.
|
|
@ -0,0 +1,61 @@
|
|||
# tinygrad Showcase
|
||||
|
||||
Despite being a tiny library, tinygrad is capable of doing a lot of things. From state-of-the-art [vision](https://arxiv.org/abs/1905.11946) to state-of-the-art [language](https://arxiv.org/abs/1706.03762) models.
|
||||
|
||||
## Vision
|
||||
|
||||
### EfficientNet
|
||||
|
||||
You can either pass in the URL of a picture to discover what it is:
|
||||
```sh
|
||||
python3 examples/efficientnet.py https://media.istockphoto.com/photos/hen-picture-id831791190
|
||||
```
|
||||
Or, if you have a camera and OpenCV installed, you can detect what is in front of you:
|
||||
```sh
|
||||
python3 examples/efficientnet.py webcam
|
||||
```
|
||||
|
||||
### YOLOv3
|
||||
|
||||
Take a look at [yolov3.py](/examples/yolov3.py).
|
||||
|
||||
![yolo by tinygrad](/docs/showcase/yolo_by_tinygrad.jpg)
|
||||
|
||||
## Audio
|
||||
|
||||
### Whisper
|
||||
|
||||
Take a look at [whisper.py](/examples/whisper.py). You need pyaudio and torchaudio installed.
|
||||
|
||||
```sh
|
||||
SMALL=1 python3 examples/whisper.py
|
||||
```
|
||||
|
||||
## Generative
|
||||
|
||||
### Generative Adversarial Networks
|
||||
|
||||
Take a look at [mnist_gan.py](/examples/mnist_gan.py).
|
||||
|
||||
![mnist gan by tinygrad](/docs/showcase/mnist_by_tinygrad.jpg)
|
||||
|
||||
### Stable Diffusion
|
||||
|
||||
You will need to download the [weights](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt) of Stable Diffusion and put it into the [weights/](/weights) directory.
|
||||
|
||||
```sh
|
||||
python3 examples/stable_diffusion.py
|
||||
```
|
||||
|
||||
![a horse sized cat eating a bagel](/docs/showcase/stable_diffusion_by_tinygrad.jpg)
|
||||
|
||||
*"a horse sized cat eating a bagel"*
|
||||
|
||||
### LLaMA
|
||||
|
||||
You will need to download and put the weights into the [weights/LLaMA](/weightsLLaMA) directory, which may need to be created.
|
||||
|
||||
Then you can have a chat with Stacy:
|
||||
```sh
|
||||
python3 examples/llama.py
|
||||
```
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
Before Width: | Height: | Size: 87 KiB After Width: | Height: | Size: 87 KiB |
Before Width: | Height: | Size: 131 KiB After Width: | Height: | Size: 131 KiB |
Loading…
Reference in New Issue