mirror of https://github.com/commaai/tinygrad.git
add world dataset (#2045)
This commit is contained in:
parent
0c3b6f13a8
commit
0ba629c7b9
Binary file not shown.
|
@ -10,11 +10,14 @@ inf, nan = float('inf'), float('nan')
|
|||
from tinygrad.codegen.linearizer import Linearizer
|
||||
def ast_str_to_lin(ast_str):
  """Build a Linearizer from the Python-expression text of a kernel AST.

  The text is evaluated with eval() using the names visible in this module,
  and the resulting object is passed straight to Linearizer.

  NOTE(review): eval() executes arbitrary code, so ast_str must only come
  from a trusted source -- confirm for any new caller.
  """
  parsed = eval(ast_str)
  return Linearizer(parsed)
|
||||
|
||||
# load worlds
|
||||
# load worlds, a dataset of about 12k kernels
|
||||
import gzip
|
||||
from pathlib import Path
|
||||
import random
|
||||
from tinygrad.helpers import dedup
|
||||
def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
|
||||
ast_strs = dedup(open("/tmp/sops").read().strip().split("\n"))
|
||||
fn = Path(__file__).parent.parent / "datasets/sops.gz"
|
||||
ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n"))
|
||||
if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x]
|
||||
if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
|
||||
if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import math, random
|
||||
from tinygrad.tensor import Tensor
|
||||
|
@ -9,7 +10,7 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
|
|||
|
||||
if __name__ == "__main__":
|
||||
net = PolicyNet()
|
||||
load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
|
||||
if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
|
||||
optim = Adam(get_parameters(net))
|
||||
|
||||
ast_strs = load_worlds()
|
||||
|
@ -40,7 +41,7 @@ if __name__ == "__main__":
|
|||
rews.append(((last_tm-tm)/base_tm))
|
||||
last_tm = tm
|
||||
except Exception:
|
||||
rews.append(-1.0)
|
||||
rews.append(-0.5)
|
||||
break
|
||||
#print(f"{tm*1e6:10.2f}", lin.colored_shape())
|
||||
|
||||
|
@ -49,7 +50,7 @@ if __name__ == "__main__":
|
|||
print(f"***** EPISODE {len(rews)} steps, {sum(rews):5.2f} reward, {base_tm*1e6:12.2f} -> {tm*1e6:12.2f} : {lin.colored_shape()}")
|
||||
all_feats += feats
|
||||
all_acts += acts
|
||||
all_rews += rews
|
||||
all_rews += np.cumsum(rews).tolist()
|
||||
|
||||
BS = 32
|
||||
if len(all_feats) >= BS:
|
||||
|
|
Loading…
Reference in New Issue