A beautiful MNIST training example (#2272)

* beautiful mnist

* beautiful mnist example

* from tinygrad import Tensor

* more beautiful

* the jit is super core tinygrad

* globalcounters reset on jit run

* symlinks and exclude

* beautiful_cartpole

* evaluate is its own function

* no symlinks

* more beautiful

* jit reset for double speed

* type hinting for JIT

* beautiful_mnist gets 98%

* beautiful_mnist < 4s with BEAM=2

* better cartpole

* use actor critic

* zero_grad got lost

* delete double relu

* stable cartpole with PPO

* beautiful_cartpole is more beautiful

* REPLAY_BUFFER

* beautiful stuff typechecks

* None support in shape

* hp tuning
George Hotz 2023-11-17 19:42:43 -08:00 committed by GitHub
parent 74e6b6c9fc
commit c7b38b324b
12 changed files with 199 additions and 23 deletions

examples/beautiful_cartpole.py Normal file

@@ -0,0 +1,114 @@
from typing import Tuple
import time
from tinygrad import Tensor, TinyJit, nn, Variable
from tinygrad.helpers import dtypes  # TODO: wouldn't need this if argmax returned the right dtype
import gymnasium as gym
from tqdm import trange
import numpy as np  # TODO: remove numpy import

class ActorCritic:
  def __init__(self, in_features, out_features, hidden_state=32):
    self.l1 = nn.Linear(in_features, hidden_state)
    self.l2 = nn.Linear(hidden_state, out_features)

    self.c1 = nn.Linear(in_features, hidden_state)
    self.c2 = nn.Linear(hidden_state, 1)

  def __call__(self, obs:Tensor) -> Tuple[Tensor, Tensor]:
    x = self.l1(obs).tanh()
    act = self.l2(x).log_softmax()
    x = self.c1(obs).relu()
    return act, self.c2(x)

def evaluate(model:ActorCritic, test_env:gym.Env) -> float:
  (obs, _), terminated, truncated = test_env.reset(), False, False
  total_rew = 0.0
  while not terminated and not truncated:
    act = model(Tensor(obs))[0].argmax().cast(dtypes.int32).item()
    obs, rew, terminated, truncated, _ = test_env.step(act)
    total_rew += float(rew)
  return total_rew

# TODO: time should be < 5s on M1 Max
if __name__ == "__main__":
  env = gym.make('CartPole-v1')

  model = ActorCritic(env.observation_space.shape[0], int(env.action_space.n))  # type: ignore
  opt = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-2)

  @TinyJit
  def train_step(x:Tensor, selected_action:Tensor, reward:Tensor, old_log_dist:Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    with Tensor.train():
      log_dist, value = model(x)

      # get advantage
      advantage = reward.reshape(-1, 1) - value
      mask = selected_action.reshape(-1, 1) == Tensor.arange(log_dist.shape[1]).reshape(1, -1).expand(selected_action.shape[0], -1)
      masked_advantage = mask * advantage.detach()

      # PPO
      ratios = (log_dist - old_log_dist).exp() * masked_advantage
      clipped_ratios = ratios.clip(1-0.2, 1+0.2) * masked_advantage
      action_loss = -ratios.minimum(clipped_ratios).sum(-1).mean()

      entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean()  # this encourages diversity
      critic_loss = advantage.square().mean()
      opt.zero_grad()
      (action_loss + entropy_loss*0.0005 + critic_loss).backward()
      opt.step()
      return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()

  @TinyJit
  def get_action_dist(obs:Tensor) -> Tensor:
    # TODO: with no_grad
    Tensor.no_grad = True
    ret = model(obs)[0].exp().realize()
    Tensor.no_grad = False
    return ret

  BS = 256
  MAX_REPLAY_BUFFER = 2000
  st, steps = time.perf_counter(), 0
  Xn, An, Rn = [], [], []
  for i in (t:=trange(40)):
    get_action_dist.reset()  # NOTE: if you don't reset the jit here it captures the wrong model on the first run through

    obs:np.ndarray = env.reset()[0]
    rews, terminated, truncated = [], False, False
    # NOTE: we don't want to early stop since then the rewards are wrong for the last episode
    while not terminated and not truncated:
      # pick actions
      # TODO: move the multinomial into jitted tinygrad when JIT rand works
      # TODO: what's the temperature here?
      act = get_action_dist(Tensor(obs)).multinomial().item()

      # save this state action pair
      # TODO: don't use np.copy here on the CPU, what's the tinygrad way to do this and keep on device? need __setitem__ assignment
      Xn.append(np.copy(obs))
      An.append(act)

      obs, rew, terminated, truncated, _ = env.step(act)
      rews.append(float(rew))
    steps += len(rews)

    # reward to go
    # TODO: move this into tinygrad
    discounts = np.power(0.99, np.arange(len(rews)))
    Rn += [np.sum(rews[i:] * discounts[:len(rews)-i]) for i in range(len(rews))]

    Xn, An, Rn = Xn[-MAX_REPLAY_BUFFER:], An[-MAX_REPLAY_BUFFER:], Rn[-MAX_REPLAY_BUFFER:]
    X, A, R = Tensor(Xn), Tensor(An), Tensor(Rn)

    # TODO: make this work
    #vsz = Variable("sz", 1, MAX_REPLAY_BUFFER-1).bind(len(Xn))
    #X, A, R = Tensor(Xn).reshape(vsz, None), Tensor(An).reshape(vsz), Tensor(Rn).reshape(vsz)

    old_log_dist = model(X)[0]  # TODO: could save these instead of recomputing
    for i in range(5):
      samples = Tensor.randint(BS, high=X.shape[0]).realize()  # TODO: remove the need for this
      # TODO: is this recompiling based on the shape?
      action_loss, entropy_loss, critic_loss = train_step(X[samples], A[samples], R[samples], old_log_dist[samples])
      t.set_description(f"sz: {len(Xn):5d} steps/s: {steps/(time.perf_counter()-st):7.2f} action_loss: {action_loss.item():7.2f} entropy_loss: {entropy_loss.item():7.2f} critic_loss: {critic_loss.item():7.2f} reward: {sum(rews):6.2f}")

  test_rew = evaluate(model, gym.make('CartPole-v1', render_mode='human'))
  print(f"test reward: {test_rew}")

examples/beautiful_mnist.py Normal file

@@ -0,0 +1,47 @@
# model based off https://towardsdatascience.com/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import List, Callable
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import GlobalCounters
from extra.datasets import fetch_mnist
from tqdm import trange

class Model:
  def __init__(self):
    self.layers: List[Callable[[Tensor], Tensor]] = [
      nn.Conv2d(1, 32, 5), Tensor.relu,
      nn.Conv2d(32, 32, 5, bias=False),
      nn.BatchNorm2d(32), Tensor.relu, Tensor.max_pool2d,
      nn.Conv2d(32, 64, 3), Tensor.relu,
      nn.Conv2d(64, 64, 3, bias=False),
      nn.BatchNorm2d(64), Tensor.relu, Tensor.max_pool2d,
      lambda x: x.flatten(1), nn.Linear(576, 10)]

  def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)

if __name__ == "__main__":
  X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)

  model = Model()
  opt = nn.optim.Adam(nn.state.get_parameters(model))

  # TODO: there's a compiler error if you comment out TinyJit since randint isn't being realized and there's something weird with int
  @TinyJit
  def train_step(samples:Tensor) -> Tensor:
    with Tensor.train():
      opt.zero_grad()
      # TODO: this "gather" of samples is very slow and not the desired way to do things in practice
      # will be under 5s when this is fixed
      loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
      opt.step()
      return loss.realize()

  @TinyJit
  def get_test_acc() -> Tensor: return ((model(X_test).argmax(axis=1) == Y_test).mean()*100).realize()

  test_acc = float('nan')
  for i in (t:=trange(70)):
    GlobalCounters.reset()
    samples = Tensor.randint(512, high=X_train.shape[0])  # TODO: put this in the JIT when rand is fixed
    loss = train_step(samples)
    if i%10 == 9: test_acc = get_test_acc().item()
    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")

extra/datasets/__init__.py

@@ -5,14 +5,15 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from extra.utils import download_file
-def fetch_mnist():
+def fetch_mnist(tensors=False):
  parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
  dirname = Path(__file__).parent.resolve()
  X_train = parse(dirname / "mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
  Y_train = parse(dirname / "mnist/train-labels-idx1-ubyte.gz")[8:]
  X_test = parse(dirname / "mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
  Y_test = parse(dirname / "mnist/t10k-labels-idx1-ubyte.gz")[8:]
-  return X_train, Y_train, X_test, Y_test
+  if tensors: return Tensor(X_train).reshape(-1, 1, 28, 28), Tensor(Y_train), Tensor(X_test).reshape(-1, 1, 28, 28), Tensor(Y_test)
+  else: return X_train, Y_train, X_test, Y_test
cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
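A quick look at the new flag (run from the repo root so `extra` is importable; shapes follow the reshape above):

from extra.datasets import fetch_mnist

X_train, Y_train, X_test, Y_test = fetch_mnist(tensors=True)
print(X_train.shape, X_test.shape)  # (60000, 1, 28, 28) (10000, 1, 28, 28)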

test/models/test_mnist.py

@@ -23,7 +23,7 @@ class TinyBobNet:
    return get_parameters(self)

  def forward(self, x):
-    return x.dot(self.l1).relu().dot(self.l2).log_softmax()
+    return x.dot(self.l1).relu().dot(self.l2)
# create a model with a conv layer
class TinyConvNet:
@@ -49,7 +49,7 @@ class TinyConvNet:
    x = self.bn1(x.conv2d(self.c1)).relu().max_pool2d()
    x = self.bn2(x.conv2d(self.c2)).relu().max_pool2d()
    x = x.reshape(shape=[x.shape[0], -1])
-    return x.dot(self.l1).log_softmax()
+    return x.dot(self.l1)
class TestMNIST(unittest.TestCase):
  def test_sgd_onestep(self):

test/test_nn.py

@@ -19,8 +19,8 @@ class TestNN(unittest.TestCase):
    loss = loss_fun(input, target)

    input_tiny = Tensor(input.detach().numpy())
-    taret_tiny = Tensor(target.detach().numpy())
-    loss_tiny = input_tiny.sparse_categorical_crossentropy(taret_tiny)
+    target_tiny = Tensor(target.detach().numpy())
+    loss_tiny = input_tiny.sparse_categorical_crossentropy(target_tiny)

    np.testing.assert_allclose(loss_tiny.numpy(), loss.detach().numpy(), atol=1e-5, rtol=1e-6)

tinygrad/__init__.py Normal file

@@ -0,0 +1,3 @@
from tinygrad.tensor import Tensor # noqa: F401
from tinygrad.jit import TinyJit # noqa: F401
from tinygrad.shape.symbolic import Variable # noqa: F401

tinygrad/features/search.py

@@ -22,7 +22,7 @@ actions += [
# returns time in seconds
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
-  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size}
+  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2}
  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

  # Set the midpoint value for var_vals to optimize shapes.

tinygrad/helpers.py

@@ -174,7 +174,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
CACHELEVEL = getenv("CACHELEVEL", 2)

-VERSION = 7
+VERSION = 8
_db_connection = None
def db_connection():
  global _db_connection

tinygrad/jit.py

@@ -1,5 +1,5 @@
from __future__ import annotations
-from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
+from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.ops import RawBuffer, Device, ASTRunner, BatchExecutor, JitItem
@@ -8,12 +8,17 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
from weakref import ref, WeakKeyDictionary
-class TinyJit:
-  def __init__(self, fxn:Callable):
-    self.fxn: Callable = fxn
+ReturnType = TypeVar('ReturnType')
+class TinyJit(Generic[ReturnType]):
+  def __init__(self, fxn:Callable[..., ReturnType]):
+    self.fxn = fxn
+    self.reset()
+
+  def reset(self):
    self.jit_fxn: Optional[BatchExecutor] = None
    self.cnt: int = 0
-    self.ret: Any = None
+    self.ret: Optional[ReturnType] = None
    self.expected_vals: Optional[Tuple[Variable, ...]] = None
    self.expected_sts_dtype: Optional[Tuple[Tuple[ShapeTracker, DType], ...]] = None
@@ -25,14 +30,13 @@ class TinyJit:
  # add support for instance methods
  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)

-  def __call__(self, *args, **kwargs) -> Any:
+  def __call__(self, *args, **kwargs) -> ReturnType:
    # all inputs are realized
    input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
    expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()])

    # get rawbuffers
    input_rawbuffers: Dict[Union[int, str], RawBuffer] = {k:cast(RawBuffer, v.lazydata.realized) for k,v in input_tensors.items()}
    assert len(input_rawbuffers) != 0, "no inputs to JIT"
    assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"

    # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
@@ -41,7 +45,7 @@
    if self.cnt >= 2:
      assert self.expected_vals == expected_vals, "mismatch of var_vals"
-      assert self.expected_sts_dtype == expected_sts_dtype, "mismatch of sts"
+      assert self.expected_sts_dtype == expected_sts_dtype, f"mismatch of sts, expected {self.expected_sts_dtype} got {expected_sts_dtype}"
      assert self.jit_fxn, "didn't get jitted?"
      self.jit_fxn(input_rawbuffers, var_vals, DEBUG>=2)
    elif self.cnt == 1:
@@ -58,7 +62,7 @@
      self.ret = self.fxn(*args, **kwargs)

    self.cnt += 1
-    return self.ret
+    return cast(ReturnType, self.ret)

class PlaceHolder:
  def __init__(self, buf:RawBuffer): self.size, self.dtype, self._device, self.ref, self.buftype, self.bufid = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf), id(buf._buf)
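For context on the `cnt` logic above: the first call runs the function eagerly, the second call captures the kernels, and later calls replay them. A minimal sketch of the now type-hinted usage (toy function, not from the diff; input shapes and dtypes must stay the same across calls, per the `expected_sts_dtype` assert):

from tinygrad import Tensor, TinyJit

@TinyJit
def double(x:Tensor) -> Tensor:  # TinyJit[Tensor] thanks to Generic[ReturnType]
  return (x * 2).realize()

for _ in range(3):  # call 1 runs eagerly, call 2 captures, call 3 replays
  out = double(Tensor.rand(4, 4))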

tinygrad/nn/__init__.py

@@ -2,9 +2,10 @@ import math
from typing import Optional, Union, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, all_int
+from tinygrad.nn import optim, state  # noqa: F401

class BatchNorm2d:
-  def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
+  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

    if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)

@@ -361,7 +361,7 @@ def get_optimized_program(linearizer_opts:LinearizerOptions, to_program, ast:Laz
    if used_tensor_cores:
      lins.append(("hc", Linearizer(ast, linearizer_opts)))
      lins[-1][1].hand_coded_optimizations()
-    timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
+    timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
    if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
    k = timed[0][1]
  else:

tinygrad/tensor.py

@@ -181,6 +181,10 @@ class Tensor:
    src = Tensor.rand(2, *shape, **kwargs)
    return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)

+  @staticmethod
+  def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
+    return (Tensor.rand(*shape, **kwargs)*(high-low)+low).cast(dtypes.int32)

  @staticmethod
  def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
@@ -213,9 +217,9 @@ class Tensor:
    assert replacement or num_samples == 1, "supported only with replacement"
    p = self.unsqueeze(0) if self.ndim == 1 else self
    cdf = p.cumsum(1)
-    cdf /= cdf[:, -1].unsqueeze(1)
+    cdf_normalized = cdf / cdf[:, -1].unsqueeze(1)
    unif_samples = Tensor.rand(num_samples, p.shape[0], 1)
-    indices = (unif_samples.expand((-1, -1, p.shape[1])) >= cdf).sum(2).permute((1, 0))
+    indices = (unif_samples.expand((-1, -1, p.shape[1])) >= cdf_normalized).sum(2).permute((1, 0))
    if self.ndim == 1: indices = indices.squeeze(0)
    return indices.cast(dtypes.int32)
@@ -230,7 +234,7 @@
      return nodes
    return _deepwalk(self, set(), [])

-  def backward(self):
+  def backward(self) -> Tensor:
    assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape}"

    # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
@@ -247,11 +251,12 @@
          assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
          t.grad = g if t.grad is None else (t.grad + g)
      del t0._ctx
+    return self
  # ***** movement mlops *****

  def reshape(self, shape, *args) -> Tensor:
    new_shape = argfix(shape, *args)
-    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
+    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)]))
  def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
  def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
  def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
@@ -770,6 +775,7 @@ class Tensor:
    return (self.maximum(0) - y * self + (1 + self.abs().__neg__().exp()).log()).mean()

  def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
+    # NOTE: self is a logits input
    loss_mask = Y != ignore_index
    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
    y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
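Two of the new Tensor helpers from this hunk in action, a minimal sketch with arbitrary values: `reshape` now accepts None to keep a dimension's original size, and `randint` samples int32 values in [low, high).

from tinygrad import Tensor

t = Tensor.rand(4, 8)
print(t.reshape(None, 2, 4).shape)  # None keeps the original size: (4, 2, 4)
print(Tensor.randint(3, low=0, high=10).numpy())  # 3 int32 values in [0, 10)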