carrotpilot/tinygrad_repo/examples/beautiful_mnist.py
carrot 77a8919349 TR16 Model, fix radar routine (#211)
2025-09-05 15:43:10 +09:00

# model based off https://medium.com/data-science/going-beyond-99-mnist-handwritten-digits-recognition-cfff96337392
from typing import Callable
from tinygrad import Tensor, TinyJit, nn, GlobalCounters
from tinygrad.helpers import getenv, colored, trange
from tinygrad.nn.datasets import mnist
class Model:
  def __init__(self):
    self.layers: list[Callable[[Tensor], Tensor]] = [
      nn.Conv2d(1, 32, 5), Tensor.relu,
      nn.Conv2d(32, 32, 5), Tensor.relu,
      nn.BatchNorm(32), Tensor.max_pool2d,
      nn.Conv2d(32, 64, 3), Tensor.relu,
      nn.Conv2d(64, 64, 3), Tensor.relu,
      nn.BatchNorm(64), Tensor.max_pool2d,
      lambda x: x.flatten(1), nn.Linear(576, 10)]

  def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)

if __name__ == "__main__":
  X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))

  model = Model()
  opt = (nn.optim.Adam if not getenv("MUON") else nn.optim.Muon)(nn.state.get_parameters(model))

  @TinyJit
  @Tensor.train()
  def train_step() -> Tensor:
    opt.zero_grad()
    samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
    loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
    return loss.realize(*opt.schedule_step())

  @TinyJit
  def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100

  test_acc = float('nan')
  for i in (t:=trange(getenv("STEPS", 70))):
    GlobalCounters.reset()   # NOTE: this makes it nice for DEBUG=2 timing
    loss = train_step()
    if i%10 == 9: test_acc = get_test_acc().item()
    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")

  # verify eval acc
  if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
    if test_acc >= target and test_acc != 100.0: print(colored(f"{test_acc=} >= {target}", "green"))
    else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))
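
The `nn.Linear(576, 10)` input width in `Model` follows from the layer shapes. The sketch below reproduces that arithmetic in plain Python, so it runs without tinygrad. It assumes 28x28 MNIST inputs, `nn.Conv2d` with stride 1 and no padding, and `Tensor.max_pool2d` with a 2x2 window and stride 2. The helper names `conv_out` and `flattened_features` are hypothetical and not part of tinygrad.

```python
# Hypothetical helpers for illustration only; pure Python, no tinygrad needed.
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
  # Output side length of a convolution or pooling window along one axis.
  return (size + 2 * padding - kernel) // stride + 1

def flattened_features(side: int = 28, channels: int = 64) -> int:
  # Assumes a square input of `side` pixels (28 for MNIST).
  for kernel in (5, 5): side = conv_out(side, kernel)  # two 5x5 convs: 28 -> 24 -> 20
  side = conv_out(side, 2, stride=2)                   # 2x2 max pool: 20 -> 10
  for kernel in (3, 3): side = conv_out(side, kernel)  # two 3x3 convs: 10 -> 8 -> 6
  side = conv_out(side, 2, stride=2)                   # 2x2 max pool: 6 -> 3
  return channels * side * side                        # 64 channels * 3 * 3

if __name__ == "__main__":
  print(flattened_features())  # 576
```

If the input resolution or any kernel size changes, the same calculation gives the new width to pass to `nn.Linear`.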