From 13846930cd43b1cfd8f7bb2967529f08c08cb6d6 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Sun, 13 Oct 2024 11:18:23 +0300
Subject: [PATCH] hotfix: extract_dataset.py (#7029)

---
 extra/optimization/extract_dataset.py          | 18 +++++++-----------
 test/external/process_replay/process_replay.py |  6 +++---
 2 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/extra/optimization/extract_dataset.py b/extra/optimization/extract_dataset.py
index 196c19d6..e1eccca0 100755
--- a/extra/optimization/extract_dataset.py
+++ b/extra/optimization/extract_dataset.py
@@ -1,22 +1,18 @@
 #!/usr/bin/env python3
 # extract asts from process replay artifacts
-import os, pickle
-from tinygrad.helpers import db_connection, getenv, VERSION
+import os
+from tinygrad.helpers import db_connection, VERSION
 from test.external.process_replay.process_replay import _pmap
 
 PAGE_SIZE = 100
-RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
-TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
+TABLE_NAME = f"kernel_process_replay_{VERSION}"
 LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
 
-def extract_ast(offset:int) -> bool:
-  logops = open(LOGOPS, "a")
-  conn = db_connection()
-  for row in conn.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)).fetchall():
-    logops.write(str(pickle.loads(row[0])[0]).replace("\n", "").replace(" ", "")+"\n")
-  return False
+def extract_ast(*args) -> bool:
+  open(LOGOPS, "a").write(str(args[0]).replace("\n", "").replace(" ", "")+"\n")
+  return args[-1]
 
 if __name__ == "__main__":
   conn = db_connection()
   row_count = conn.execute(f"SELECT COUNT(*) FROM '{TABLE_NAME}'").fetchone()[0]
-  _pmap(row_count, extract_ast)
+  _pmap("kernel", extract_ast)
diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py
index b23ce4a7..85c753db 100755
--- a/test/external/process_replay/process_replay.py
+++ b/test/external/process_replay/process_replay.py
@@ -29,8 +29,8 @@ if REF == "master": SKIP_PROCESS_REPLAY = True
 
 # *** recreators
 
-def recreate_sched(sink:UOp, ctx) -> UOp: return full_ast_rewrite(sink, ctx)
-def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str, ctx:ProcessReplayContext) -> str:
+def recreate_sched(sink:UOp, ctx, _) -> UOp: return full_ast_rewrite(sink, ctx)
+def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str, ctx:ProcessReplayContext, _) -> str:
   with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}):
     k = Kernel(ast, opts=opts)
     for opt in applied_opts: k.apply_opt(opt)
@@ -53,7 +53,7 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
       if ASSERT_DIFF: return True
       continue
     # try recreate
-    try: good = fxn(*args[:-1])
+    try: good = fxn(*args)
     except Exception as e:
       logging.warning(f"FAILED TO RECREATE KERNEL {e}")
       for x in args[:-1]: logging.info(x)