mirror of https://github.com/commaai/tinygrad.git
hotfix: extract_dataset.py (#7029)
This commit is contained in:
parent
942a17109a
commit
13846930cd
|
@ -1,22 +1,18 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# extract asts from process replay artifacts
|
# extract asts from process replay artifacts
|
||||||
import os, pickle
|
import os
|
||||||
from tinygrad.helpers import db_connection, getenv, VERSION
|
from tinygrad.helpers import db_connection, VERSION
|
||||||
from test.external.process_replay.process_replay import _pmap
|
from test.external.process_replay.process_replay import _pmap
|
||||||
|
|
||||||
PAGE_SIZE = 100
|
PAGE_SIZE = 100
|
||||||
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
|
TABLE_NAME = f"kernel_process_replay_{VERSION}"
|
||||||
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
|
|
||||||
LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
|
LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
|
||||||
|
|
||||||
def extract_ast(offset:int) -> bool:
|
def extract_ast(*args) -> bool:
|
||||||
logops = open(LOGOPS, "a")
|
open(LOGOPS, "a").write(str(args[0]).replace("\n", "").replace(" ", "")+"\n")
|
||||||
conn = db_connection()
|
return args[-1]
|
||||||
for row in conn.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)).fetchall():
|
|
||||||
logops.write(str(pickle.loads(row[0])[0]).replace("\n", "").replace(" ", "")+"\n")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
conn = db_connection()
|
conn = db_connection()
|
||||||
row_count = conn.execute(f"SELECT COUNT(*) FROM '{TABLE_NAME}'").fetchone()[0]
|
row_count = conn.execute(f"SELECT COUNT(*) FROM '{TABLE_NAME}'").fetchone()[0]
|
||||||
_pmap(row_count, extract_ast)
|
_pmap("kernel", extract_ast)
|
||||||
|
|
|
@ -29,8 +29,8 @@ if REF == "master": SKIP_PROCESS_REPLAY = True
|
||||||
|
|
||||||
# *** recreators
|
# *** recreators
|
||||||
|
|
||||||
def recreate_sched(sink:UOp, ctx) -> UOp: return full_ast_rewrite(sink, ctx)
|
def recreate_sched(sink:UOp, ctx, _) -> UOp: return full_ast_rewrite(sink, ctx)
|
||||||
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str, ctx:ProcessReplayContext) -> str:
|
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str, ctx:ProcessReplayContext, _) -> str:
|
||||||
with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}):
|
with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}):
|
||||||
k = Kernel(ast, opts=opts)
|
k = Kernel(ast, opts=opts)
|
||||||
for opt in applied_opts: k.apply_opt(opt)
|
for opt in applied_opts: k.apply_opt(opt)
|
||||||
|
@ -53,7 +53,7 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
|
||||||
if ASSERT_DIFF: return True
|
if ASSERT_DIFF: return True
|
||||||
continue
|
continue
|
||||||
# try recreate
|
# try recreate
|
||||||
try: good = fxn(*args[:-1])
|
try: good = fxn(*args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"FAILED TO RECREATE KERNEL {e}")
|
logging.warning(f"FAILED TO RECREATE KERNEL {e}")
|
||||||
for x in args[:-1]: logging.info(x)
|
for x in args[:-1]: logging.info(x)
|
||||||
|
|
Loading…
Reference in New Issue