mirror of https://github.com/commaai/tinygrad.git
parent
8d6cecb25c
commit
0c3b6f13a8
|
@ -13,9 +13,11 @@ def ast_str_to_lin(ast_str): return Linearizer(eval(ast_str))
|
|||
# load worlds
|
||||
import random
|
||||
from tinygrad.helpers import dedup
|
||||
def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
  """Load the recorded kernel AST strings from /tmp/sops in a reproducible shuffled order.

  The file is read as one AST string per line and duplicates are dropped via
  `dedup`. Each flag enables one substring filter on the remaining lines:

  filter_reduce:     keep only lines that contain "ReduceOps".
  filter_noimage:    drop lines that contain "dtypes.image".
  filter_novariable: drop lines that contain "Variable".

  Returns the filtered list, shuffled after seeding `random` with 1337 so the
  order is identical from run to run.
  """
  # Context manager so the file handle is closed deterministically.
  with open("/tmp/sops") as f:
    ast_strs = dedup(f.read().strip().split("\n"))
  # Filters are applied independently; a disabled flag leaves the list untouched.
  if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x]
  if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
  if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]
  # NOTE: this reseeds the global `random` module state as a side effect.
  random.seed(1337)
  random.shuffle(ast_strs)
  return ast_strs
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
from copy import deepcopy
|
||||
from tinygrad.nn import Linear
|
||||
from tinygrad.tensor import Tensor
|
||||
|
@ -18,10 +19,10 @@ class PolicyNet:
|
|||
return self.l3(x).log_softmax()
|
||||
|
||||
if __name__ == "__main__":
|
||||
ast_strs = load_worlds()
|
||||
ast_strs = load_worlds(False, False, filter_novariable=True)
|
||||
|
||||
net = PolicyNet()
|
||||
load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
|
||||
if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
|
||||
optim = Adam(get_parameters(net))
|
||||
|
||||
X,Y = [], []
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
import numpy as np
|
||||
import math, random
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
|
||||
from tinygrad.codegen.search import actions, bufs_from_lin, time_linearizer
|
||||
from tinygrad.nn.optim import Adam
|
||||
from extra.optimization.pretrain import PolicyNet
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
|
||||
|
||||
if __name__ == "__main__":
  import os  # local import: only this entry point needs it

  # Policy network and its optimizer. Resume from a saved checkpoint only when
  # one exists, so a first run starts from fresh weights instead of crashing.
  net = PolicyNet()
  if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
  optim = Adam(get_parameters(net))

  ast_strs = load_worlds()

  BS = 32  # number of recorded steps consumed by each optimizer update

  # Steps from finished episodes are queued here until a full batch is available.
  all_feats, all_acts, all_rews = [], [], []
  while 1:
    # Rollout phase: sampling only, so gradients and training mode are off.
    Tensor.no_grad, Tensor.training = True, False

    # select a world
    lin = ast_str_to_lin(random.choice(ast_strs))
    rawbufs = bufs_from_lin(lin)
    tm = last_tm = base_tm = time_linearizer(lin, rawbufs)[0]

    # take actions
    feats, acts, rews = [], [], []
    while 1:
      feat = lin_to_feats(lin)
      feats.append(feat)
      # The net's output is exponentiated to get the sampling distribution.
      probs = net(Tensor([feat])).exp()[0].numpy()
      act = np.random.choice(len(probs), p=probs)
      acts.append(act)
      if act == 0:
        # Action 0 ends the episode and earns no reward.
        rews.append(0)
        break
      try:
        # Index 0 is reserved for "stop", hence the shift into `actions`.
        lin.apply_opt(actions[act-1])
        tm = time_linearizer(lin, rawbufs)[0]
        if math.isinf(tm): raise Exception("failed")
        # Reward is the time saved by this step, relative to the baseline time.
        rews.append(((last_tm-tm)/base_tm))
        last_tm = tm
      except Exception:
        # Any failure while applying or timing the action is penalised and ends the episode.
        rews.append(-1.0)
        break
      #print(f"{tm*1e6:10.2f}", lin.colored_shape())

    assert len(feats) == len(acts) and len(acts) == len(rews)
    #print(rews)
    print(f"***** EPISODE {len(rews)} steps, {sum(rews):5.2f} reward, {base_tm*1e6:12.2f} -> {tm*1e6:12.2f} : {lin.colored_shape()}")
    all_feats += feats
    all_acts += acts
    all_rews += rews

    if len(all_feats) >= BS:
      # Training phase: one update on the oldest BS queued steps.
      Tensor.no_grad, Tensor.training = False, True
      x = Tensor(all_feats[:BS])
      # Each row is zero except at the taken action, which holds that step's reward.
      mask = np.zeros((BS, len(actions)+1), dtype=np.float32)
      mask[range(BS), all_acts[:BS]] = all_rews[:BS]
      loss = -(net(x) * Tensor(mask)).mean()
      optim.zero_grad()
      loss.backward()
      optim.step()
      # Drop the consumed steps; any remainder waits for the next batch.
      all_feats = all_feats[BS:]
      all_acts = all_acts[BS:]
      all_rews = all_rews[BS:]
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from __future__ import annotations
|
||||
from typing import Tuple, List, cast, Optional
|
||||
from dataclasses import dataclass
|
||||
import itertools, math, os
|
||||
|
@ -8,9 +9,11 @@ from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
|||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from enum import Enum, auto
|
||||
|
||||
class OptOps(Enum):
  """Kinds of optimization a kernel can have applied, ordered by declaration.

  Members compare with `<` by their auto-assigned values, which lets
  containers holding an OptOps be sorted.
  """
  UPCAST = auto()
  LOCAL = auto()
  GROUP = auto()
  GROUPTOP = auto()

  def __lt__(self, x: "OptOps"):
    # Defer to Python's fallback for foreign operands instead of failing on `.value`.
    if not isinstance(x, OptOps): return NotImplemented
    return self.value < x.value
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, order=True)
|
||||
class Opt:
|
||||
op: OptOps
|
||||
axis: int
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue