process replay rewrite (#6284)

* process replay rewrite

p2

* start some unittests + exceptions and exits

* shebang

* remove extra kernel init
qazal 2024-08-29 20:08:27 +08:00 committed by GitHub
parent 7de4eac8f7
commit dd4e5f1c8d
3 changed files with 122 additions and 91 deletions


@@ -1,28 +1,43 @@
#!/usr/bin/env python3
# compare kernels created by HEAD against master
import difflib, pickle, multiprocessing, os, logging, sqlite3, requests
from tabulate import tabulate
from datetime import datetime
from typing import Dict, List, cast
from test.external.process_replay.utils import print_diff
import os, multiprocessing, logging, pickle, sqlite3
from typing import Callable, List, cast
from tinygrad.helpers import VERSION, Context, ContextVar, db_connection, getenv, tqdm
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm
from test.external.process_replay.utils import print_diff
# *** process replay settings
# internal
PAGE_SIZE = 100
REF = os.getenv("GITHUB_REF_NAME", "")
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
if REF == "master": SKIP_PROCESS_REPLAY = True
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format='%(message)s')
# *** github settings
BASE_URL = f"https://api.github.com/repos/{os.getenv('GITHUB_REPOSITORY', 'tinygrad/tinygrad')}"
GH_HEADERS = {"Authorization": f"Bearer {os.getenv('GH_TOKEN', '')}", "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
os.environ["RUN_PROCESS_REPLAY"] = "0"
# user config
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
if REF == "master": SKIP_PROCESS_REPLAY = True
# *** differs
def diff_schedule(offset:int) -> bool:
  conn = db_connection()
  cur = conn.cursor()
  cur.execute(f"SELECT val FROM 'schedule_diff_{VERSION}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
  changed = 0
  for row in cur.fetchall():
    changed += 1
    buf, asts = pickle.loads(row[0])
    if len(asts) == 1:
      logging.info(f"{buf} was folded")
      logging.info(asts[0])
    else: print_diff(asts[0], asts[1])
  return bool(changed)
def diff_kernel(offset:int) -> bool:
  if early_stop.is_set(): return True
@@ -31,31 +46,33 @@ def diff_kernel(offset:int) -> bool:
cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
changed = 0
for row in cur.fetchall():
ast, applied_opts = None, None
# try unpickle and linearize
# try unpickle
try: ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0])
except Exception as e:
logging.warning(f"FAILED TO UNPICKLE OBJECTS {e}")
if ASSERT_DIFF: return True
continue
# try linearize
try:
ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0])
with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}):
k = Kernel(ast, opts=opts)
for opt in applied_opts: k.apply_opt(opt)
# NOTE: replay with the captured renderer, not the one in master
good_src = k.opts.render(name, cast(List,k.to_program().uops))
except Exception as e:
logging.warning("FAILED TO RECREATE KERNEL")
logging.warning(f"FAILED TO RECREATE KERNEL {e}")
logging.info(ast)
logging.info(applied_opts)
logging.info(e)
if ASSERT_DIFF: return True
continue
# diff kernels
try: assert compare_src == good_src
except AssertionError:
changed += 1
logging.info("PROCESS REPLAY DETECTED CHANGE")
logging.info(ast)
logging.info(applied_opts)
diff = list(difflib.unified_diff(good_src.splitlines(), compare_src.splitlines()))
for line in diff:
logging.info(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
print_diff(good_src, compare_src)
if ASSERT_DIFF: return True
if changed > MAX_DIFF_PCT:
logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.")
@@ -65,89 +82,57 @@ def diff_kernel(offset:int) -> bool:
  cur.close()
  return bool(changed)
def print_ast_diff(offset:int):
# *** differ runners with multiprocessing
def _run_differ(row_count:int, differ:Callable[[int], bool]) -> None:
  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool:
    inputs = list(range(0, row_count, PAGE_SIZE))
    changed: List[bool] = list(tqdm(pool.imap_unordered(differ, inputs), total=len(inputs)))
    pool.close()
    pool.join()
    pool.terminate()
  if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes")
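# note: the "spawn" start method gives each pool worker a fresh interpreter, and
# maxtasksperchild=16 recycles workers after 16 pages, presumably to bound memory
# growth while unpickling many captured kernels.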
def process_replay_schedule() -> None:
  conn = db_connection()
  cur = conn.cursor()
  cur.execute(f"SELECT val FROM 'schedule_diff_{VERSION}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
  for row in cur.fetchall():
    buf, asts = pickle.loads(row[0])
    if len(asts) == 1:
      logging.info(f"{buf} was folded")
      logging.info(asts[0])
    else: print_diff(asts[0], asts[1])
  try: has_diff = cur.execute(f"select name from sqlite_master where type='table' and name='schedule_diff_{VERSION}'").fetchone()
  except sqlite3.OperationalError:
    logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
    return
  if has_diff:
    row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
    if row_count != 0: logging.info("***** schedule diff")
  conn.commit()
  cur.close()
  _run_differ(row_count, diff_schedule)
def get_step_times(data) -> Dict[str, float]:
  tms: Dict[str, float] = {}
  for step in data["steps"][4:]:
    # last task
    if step["name"] == "Run actions/upload-artifact@v4": break
    fmt = "%Y-%m-%dT%H:%M:%SZ"
    tm = datetime.strptime(step["completed_at"], fmt) - datetime.strptime(step["started_at"], fmt)
    tms[step["name"]] = tm.total_seconds()
  return tms
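# note: in the speed table below, a positive diff means the compare step ran faster
# than master: e.g. master v=10.0s vs compare 8.0s gives ((10.0-8.0)/10.0)*100 = +20.00%.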
def process_replay():
  # *** speed diff (for benchmarks)
  # TODO: fix this for testqualcommbenchmark
  if REF == "update_benchmark" and os.environ["GITHUB_JOB"] != "testqualcommbenchmark":
    name = {"testmacbenchmark": "Mac", "testnvidiabenchmark": "tinybox green", "testmorenvidiabenchmark": "tinybox green Training",
            "testamdbenchmark": "tinybox red", "testmoreamdbenchmark": "tinybox red Training",
            "testqualcommbenchmark": "comma"}[os.environ["GITHUB_JOB"]]
    compare_jobs = requests.get(f"{BASE_URL}/actions/runs/{RUN_ID}/jobs", headers=GH_HEADERS).json()["jobs"]
    compare_job = next(j for j in compare_jobs if j["name"] == f"{name} Benchmark")
    ref_runs = requests.get(f"{BASE_URL}/actions/workflows/benchmark.yml/runs?per_page=1&branch=master&status=success", headers=GH_HEADERS).json()
    ref_jobs = requests.get(f"{BASE_URL}/actions/runs/{ref_runs['workflow_runs'][0]['id']}/jobs").json()["jobs"]
    ref_job = next(j for j in ref_jobs if j["name"] == f"{name} Benchmark")
    logging.info(f"comparing speed for {compare_job['id']} against {ref_job['id']}")
    compare_tms = get_step_times(compare_job)
    ref_tms = get_step_times(ref_job)
    diff = [[k, f"{v}s", f"{compare_tms[k]}s", f"{(((v-compare_tms[k])/v)*100):7.2f}%"] for k,v in ref_tms.items() if v>0]
    logging.info(tabulate(diff, headers=["job", "master", "compare", "diff"]))
  # *** schedule diff
  if COMPARE_SCHEDULE:
    conn = db_connection()
    cur = conn.cursor()
    try: has_diff = cur.execute(f"select name from sqlite_master where type='table' and name='schedule_diff_{VERSION}'").fetchone()
    except sqlite3.OperationalError:
      logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
      exit(0)
    if has_diff:
      logging.info("***** schedule diff")
      row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
    conn.commit()
    cur.close()
    with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool:
      inputs = list(range(0, row_count, PAGE_SIZE))
      list(tqdm(pool.imap_unordered(print_ast_diff, inputs), total=len(inputs)))
      pool.close()
      pool.join()
      pool.terminate()
    if ASSERT_DIFF: raise Exception("kernel process replay detected changes")
  # *** kernel diff
  logging.info("***** kernel diff")
def process_replay_kernel() -> None:
  conn = db_connection()
  cur = conn.cursor()
  try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
  except sqlite3.OperationalError:
    logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
    exit(0)
    return None
  conn.commit()
  cur.close()
  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool:
    inputs = list(range(0, row_count, PAGE_SIZE))
    changed = list(tqdm(pool.imap_unordered(diff_kernel, inputs), total=len(inputs)))
    pool.close()
    pool.join()
    pool.terminate()
  if any(changed) and ASSERT_DIFF: raise Exception("kernel process replay detected changes")
  _run_differ(row_count, diff_kernel)
# *** main loop
if __name__ == "__main__":
  if SKIP_PROCESS_REPLAY:
    logging.info("skipping process replay.")
    exit(0)
  try: process_replay()
  logging.info("***** schedule diff")
  try: process_replay_schedule()
  except Exception as e:
    # TODO: catch specific Exception
    if ASSERT_DIFF: raise e
    logging.error(f"schedule diff err {e}")
  logging.info("***** kernel diff")
  try: process_replay_kernel()
  except Exception as e:
    if ASSERT_DIFF: raise e
    logging.error(f"kernel diff err {e}")


@@ -0,0 +1,46 @@
import unittest
from test.external.process_replay.process_replay import TABLE_NAME, diff_kernel
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import to_function_name, db_connection, diskcache_put, VERSION
from tinygrad.ops import UOp
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.tensor import Tensor

def helper_append_replay(ast:UOp, name:str, src:str) -> int:
  diskcache_put(TABLE_NAME.replace(f"_{VERSION}", ""), "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, {}))
  conn = db_connection()
  row_count = conn.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
  return row_count

class TestProcessReplay(unittest.TestCase):
  def tearDown(self):
    conn = db_connection()
    cur = conn.cursor()
    cur.execute(f"DELETE FROM '{TABLE_NAME}' WHERE key LIKE 'test_%'")
    conn.commit()
    cur.close()

  def test_simple_diff(self):
    out = Tensor([1, 2, 3])+1
    ast = out.schedule()[-1].ast
    test_src = """
void test(int* restrict a, const int* restrict b) {
  for (int ridx0 = 0; ridx0 < 3; ridx0++) {
    int val0 = b[ridx0];
    a[ridx0] = (val0+1);
  }
}
"""
    offset = helper_append_replay(ast, "test", test_src)
    assert diff_kernel(offset-1)

  def test_identical_run(self):
    out = Tensor([1, 2, 3])+1
    ast = out.schedule()[-1].ast
    test_prg = Kernel(ast, ClangRenderer()).to_program()
    offset = helper_append_replay(ast, test_prg.name, test_prg.src)
    assert not diff_kernel(offset)

if __name__ == "__main__":
  unittest.main()
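One detail worth spelling out in these tests: diff_kernel pages the replay table with LIMIT PAGE_SIZE OFFSET n, and helper_append_replay returns the row count after inserting, so the freshly added row sits at offset row_count-1; that is why test_simple_diff passes offset-1. A toy model of the paging, with a plain list standing in for the sqlite table:

PAGE_SIZE = 100
table = ["row0", "row1"]       # pre-existing replay rows
table.append("test_row")       # what helper_append_replay effectively does
row_count = len(table)         # its return value (3)
page = table[row_count-1:row_count-1+PAGE_SIZE]  # LIMIT PAGE_SIZE OFFSET row_count-1
assert page == ["test_row"]    # diff_kernel(offset-1) pages straight to the new row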


@@ -754,7 +754,7 @@ class Kernel:
if getenv("RUN_PROCESS_REPLAY"):
table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}"
diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
diskcache_put(table_name, str(id(self)), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
# TODO: these max and min don't work on symbolic, and results are very wrong.
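The Kernel change is the capture side of the pipeline: when RUN_PROCESS_REPLAY is set, a 6-tuple of (ast, opts, applied_opts, name, src, ctx) is stashed in the per-run diskcache table, now under a stringified key. A sketch of the round trip with placeholder values ("E_3" and the strings are stand-ins, not real captures; the tuple layout is the one diff_kernel unpickles):

import pickle

# placeholder payload mirroring (ast, opts, applied_opts, name, src, ctx)
payload = ("<ast>", "<renderer opts>", [], "E_3", "void E_3(...) {}", {"DEBUG": 0})
key = str(id(payload))                 # stringified, like str(id(self)) above
blob = pickle.dumps(payload)           # what diskcache_put stores as the row value
ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(blob)
assert isinstance(key, str) and name == "E_3"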