Files
carrotpilot/tinygrad_repo/extra/optimization/test_net.py
carrot 77a8919349 TR16 Model, fix radar routine (#211)
* UV+DTR model

* DTR model.. again.

* fix naviGPS

* fix radar...

* fix..

* test

* fix..

* carrot serv

* fix..

* fix.. fleet

* fix.. radar

* fix atc

* Steam Powered model..

* fix.. radarLatFactor range.. 200->500

* fix.. dbc..

* side

* SP v2

* brake light

* fix brakelight

* fix..

* add datetime...

* fix..

* fix..

* fix..

* fix..

* blind spot

* fix tz

* fix..

* ff

* radarLatFactor

* fix.. bsd

* Revert "fix.. bsd"

This reverts commit 1d0d143447.

* fix.. bsd side..

* test

* fix.. e2e conditions

* Revert "test"

This reverts commit 0ce791dbd6.

* TR16

* fix cut-in detect threshold 3.4 -> 2.6

* fix.. jerk_l limit 5->10

* fix..

* fix.. gm

* fix.. OPTIMA_H mass

* fix.. radar..

* fix radar..

* fix..

* Radar...

* fix..

* fix..

* fix..

* fix.. radartrack 3

* fix..

* fix..

* fix..

* merge..

* fix.. canfd

* fix..

* fix..

* fix..

* fix.. radard

* new cut_in

* Revert "new cut_in"

This reverts commit b9b6e9b333.

* fix..

* new cut_in detect...

* fix.. disp..

* fix..

* fix..

* fix.. center radar..

* fix.. radar y_sane..

* fix..

* fix..

* hkg jerk 10 -> 5

* fix..

* fix..

* fix.. radar dbc..

* fix..

* fix.. jLead filter..

* test new radar interface..

* fix..

* fix..

* test time...

* Revert "test time..."

This reverts commit 63e9187736.

* fix radar..

* fix..

* FireHose model..

* tinygrad

* Update interface.py

* fix..

* fix.. nff toyota corolla_tss2

* fix..

* fix..

* fix.. radar

* fix..

* fix.. radar, y_gate

* fix.. radar..

* fix.. for clone..

* scc radar enable at low speed..

* fix.. settings..

* fix.

* fix..

* fix.. radarTimeStep.

* TR16 model again..

* RELEASE.md

* fix cut-in detection...

* fix.. registration timeout 15 sec..

* fix..

* fix.. radar processing.

* fix..

* fix..

* fix..

* fix..

* fix..

* fix..
2025-09-05 15:43:10 +09:00

67 lines
2.2 KiB
Python

import numpy as np
import math
import random
np.set_printoptions(suppress=True)
from copy import deepcopy
from tinygrad.helpers import getenv, colored
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.codegen.opt.search import bufs_from_lin, actions, get_kernel_actions
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.pretrain_valuenet import ValueNet
# Nonzero VALUE (read via getenv) selects the value net; otherwise the policy net is used.
VALUE = getenv("VALUE")

def _load_net():
  """Construct the net selected by VALUE and load its weights.

  Returns a ValueNet loaded from /tmp/valuenet.safetensors when VALUE is set,
  otherwise a PolicyNet loaded from /tmp/policynet.safetensors.
  """
  if VALUE:
    net, weights = ValueNet(), "/tmp/valuenet.safetensors"
  else:
    net, weights = PolicyNet(), "/tmp/policynet.safetensors"
  load_state_dict(net, safe_load(weights))
  return net

def _choose_action(net, lin):
  """Ask `net` which optimization to apply next to `lin`.

  Returns (act, pred_time). act == 0 means "stop"; any other value selects
  actions[act-1].

  Value mode: every candidate kernel from get_kernel_actions(lin) is scored in
  one batch and the candidate with the lowest prediction is chosen; pred_time
  is exp() of that lowest score (presumably the net predicts log-runtime in
  seconds, since it is printed next to measured times -- TODO confirm).

  Policy mode: the single highest-scoring action is chosen and pred_time is
  NaN, because the policy net makes no runtime prediction.
  """
  if VALUE:
    candidates = get_kernel_actions(lin)
    acts = list(candidates.keys())
    feats = [lin_to_feats(k) for k in candidates.values()]
    preds = net(Tensor(feats)).numpy()  # materialize once, reuse for min and argmin
    return acts[preds.argmin()], math.exp(preds.min())
  # exp() of the policy output; assumes it emits log-probabilities -- TODO confirm
  probs = net(Tensor([lin_to_feats(lin)])).exp().numpy()
  return probs.argmax(), float('nan')

if __name__ == "__main__":
  net = _load_net()
  ast_strs = load_worlds()

  # real randomness: reseed from OS entropy (or the clock) so each run sees a new episode order
  random.seed()
  random.shuffle(ast_strs)

  wins = 0
  for ep_num, ast_str in enumerate(ast_strs):
    # win rate over the ep_num episodes finished so far; max() avoids /0 on the first one
    print("\nEPISODE", ep_num, f"win {wins*100/max(1,ep_num):.2f}%")
    lin = ast_str_to_lin(ast_str)
    rawbufs = bufs_from_lin(lin)

    # Baseline: a copy of the kernel with the hand-coded heuristic's opts applied one by one
    # (same apply_opt call the search loop below uses).
    linhc = deepcopy(lin)
    for opt in hand_coded_optimizations(linhc):
      linhc.apply_opt(opt)
    tmhc = time_linearizer(linhc, rawbufs)
    print(f"{tmhc*1e6:10.2f} HC ", linhc.colored_shape())

    # Stays inf if the net stops (act == 0) or fails before any opt is applied,
    # so such an episode is reported and counted as a loss.
    tm = float('inf')
    while True:
      act, pred_time = _choose_action(net, lin)
      if act == 0: break
      try:
        lin.apply_opt(actions[act-1])
      except Exception:
        # the chosen opt could not be applied to this kernel; end the episode
        print("FAILED")
        break
      tm = time_linearizer(lin, rawbufs)
      print(f"{tm*1e6:10.2f} {pred_time*1e6:10.2f}", lin.colored_shape())

    print(f"{colored('BEAT', 'green') if tm < tmhc else colored('lost', 'red')} hand coded {tmhc/tm:5.2f}x")
    wins += int(tm < tmhc)