add world dataset (#2045)

George Hotz 2023-10-11 15:54:30 -07:00 committed by GitHub
parent 0c3b6f13a8
commit 0ba629c7b9
3 changed files with 9 additions and 5 deletions

extra/datasets/sops.gz (new binary file, contents not shown)
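Note: a quick way to sanity-check the new dataset is to count its entries. This is a sketch, not part of the commit; it assumes sops.gz is a gzip of newline-separated kernel AST strings (which is how the load_worlds() change below reads it) and that it is run from the repo root:

import gzip
from pathlib import Path

fn = Path("extra/datasets/sops.gz")
ast_strs = gzip.open(fn).read().decode('utf-8').strip().split("\n")
print(len(ast_strs))  # about 12k kernel ASTs, per the new comment in the loader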


@@ -10,11 +10,14 @@ inf, nan = float('inf'), float('nan')
 from tinygrad.codegen.linearizer import Linearizer
 def ast_str_to_lin(ast_str): return Linearizer(eval(ast_str))
-# load worlds
+# load worlds, a dataset of about 12k kernels
+import gzip
+from pathlib import Path
 import random
 from tinygrad.helpers import dedup
 def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
-  ast_strs = dedup(open("/tmp/sops").read().strip().split("\n"))
+  fn = Path(__file__).parent.parent / "datasets/sops.gz"
+  ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n"))
   if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x]
   if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
   if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]
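Usage note: load_worlds() returns the dataset as a list of AST strings, and ast_str_to_lin() turns one of them into a Linearizer. A minimal sketch of how the two fit together (the random sampling here is just for illustration):

import random
from extra.optimization.helpers import load_worlds, ast_str_to_lin

ast_strs = load_worlds()                       # kernel ASTs from extra/datasets/sops.gz, filtered by the default flags
lin = ast_str_to_lin(random.choice(ast_strs))  # pick one kernel and build its Linearizer
print(lin.colored_shape())                     # same helper used by the training script below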


@@ -1,3 +1,4 @@
+import os
 import numpy as np
 import math, random
 from tinygrad.tensor import Tensor
@@ -9,7 +10,7 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
 if __name__ == "__main__":
   net = PolicyNet()
-  load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
+  if os.path.isfile("/tmp/policynet.safetensors"): load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
   optim = Adam(get_parameters(net))
   ast_strs = load_worlds()
@@ -40,7 +41,7 @@ if __name__ == "__main__":
         rews.append(((last_tm-tm)/base_tm))
         last_tm = tm
       except Exception:
-        rews.append(-1.0)
+        rews.append(-0.5)
         break
       #print(f"{tm*1e6:10.2f}", lin.colored_shape())
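For context on the values in this hunk: each successful step's reward is the time saved by the action, normalized by the baseline kernel time, and a step that raises an exception now ends the episode with a flat -0.5 penalty instead of -1.0. A worked example with made-up timings:

# hypothetical numbers, in seconds
base_tm = 100e-6                   # baseline kernel time for the episode
last_tm = 100e-6                   # time before this action
tm = 80e-6                         # time after this action
reward = (last_tm - tm) / base_tm  # 0.2, i.e. the action recovered 20% of the baseline time
# a failing action appends -0.5 and breaks out of the episode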
@@ -49,7 +50,7 @@ if __name__ == "__main__":
     print(f"***** EPISODE {len(rews)} steps, {sum(rews):5.2f} reward, {base_tm*1e6:12.2f} -> {tm*1e6:12.2f} : {lin.colored_shape()}")
     all_feats += feats
     all_acts += acts
-    all_rews += rews
+    all_rews += np.cumsum(rews).tolist()
     BS = 32
     if len(all_feats) >= BS:
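The last change stores the running total of the episode's rewards at each step instead of the raw per-step reward. A small illustration of what np.cumsum produces here, with made-up rewards:

import numpy as np
rews = [0.2, -0.5]               # per-step rewards from one short episode
print(np.cumsum(rews).tolist())  # roughly [0.2, -0.3]: reward accumulated up to each step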