process replay rewrite (#6284)

* process replay rewrite

p2

* start some unittests + exceptions and exits

* shebang

* remove extra kernel init
qazal 2024-08-29 20:08:27 +08:00 committed by GitHub
parent 7de4eac8f7
commit dd4e5f1c8d
3 changed files with 122 additions and 91 deletions
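At a glance: the monolithic process_replay() is split into per-artifact differs (diff_schedule, diff_kernel) driven by a shared _run_differ pool runner, the GitHub benchmark speed-diff code moves out of this script, unpickle and linearize failures are reported separately, and the new main loop logs a failure in one differ and moves on unless ASSERT_DIFF is set. One idiom worth noting in the "# user config" block below: os.getenv(name, k) falls back to the needle k itself, so the opt-in check defaults to true when the CI env vars are absent. A minimal sketch of that idiom (illustration only, not part of the commit):

import os

def tag_requested(tag: str) -> bool:
  # with COMMIT_MESSAGE/PR_TITLE unset, os.getenv falls back to the tag itself,
  # so the membership test is vacuously true and replay defaults to on
  return tag in os.getenv("COMMIT_MESSAGE", tag) or tag in os.getenv("PR_TITLE", tag)

print(tag_requested("[run_process_replay]"))  # True unless CI sets both vars without the tag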

test/external/process_replay/process_replay.py

@@ -1,28 +1,43 @@
 #!/usr/bin/env python3
 # compare kernels created by HEAD against master
-import difflib, pickle, multiprocessing, os, logging, sqlite3, requests
-from tabulate import tabulate
-from datetime import datetime
-from typing import Dict, List, cast
-from test.external.process_replay.utils import print_diff
+import os, multiprocessing, logging, pickle, sqlite3
+from typing import Callable, List, cast
+from tinygrad.helpers import VERSION, Context, ContextVar, db_connection, getenv, tqdm
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm
+from test.external.process_replay.utils import print_diff
 
 # *** process replay settings
+
+# internal
 PAGE_SIZE = 100
 REF = os.getenv("GITHUB_REF_NAME", "")
 MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
 RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
 TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
-ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
-COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
-SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
-if REF == "master": SKIP_PROCESS_REPLAY = True
 early_stop = multiprocessing.Event()
 logging.basicConfig(level=logging.INFO, format='%(message)s')
-# *** github settings
-BASE_URL = f"https://api.github.com/repos/{os.getenv('GITHUB_REPOSITORY', 'tinygrad/tinygrad')}"
-GH_HEADERS = {"Authorization": f"Bearer {os.getenv('GH_TOKEN', '')}", "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
+os.environ["RUN_PROCESS_REPLAY"] = "0"
+
+# user config
+ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
+SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
+if REF == "master": SKIP_PROCESS_REPLAY = True
+
+# *** differs
+
+def diff_schedule(offset:int) -> bool:
+  conn = db_connection()
+  cur = conn.cursor()
+  cur.execute(f"SELECT val FROM 'schedule_diff_{VERSION}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
+  changed = 0
+  for row in cur.fetchall():
+    changed += 1
+    buf, asts = pickle.loads(row[0])
+    if len(asts) == 1:
+      logging.info(f"{buf} was folded")
+      logging.info(asts[0])
+    else: print_diff(asts[0], asts[1])
+  return bool(changed)
 
 def diff_kernel(offset:int) -> bool:
   if early_stop.is_set(): return True
@@ -31,31 +46,33 @@ def diff_kernel(offset:int) -> bool:
   cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
   changed = 0
   for row in cur.fetchall():
-    ast, applied_opts = None, None
-    # try unpickle and linearize
+    # try unpickle
+    try: ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0])
+    except Exception as e:
+      logging.warning(f"FAILED TO UNPICKLE OBJECTS {e}")
+      if ASSERT_DIFF: return True
+      continue
+    # try linearize
     try:
-      ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0])
       with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}):
         k = Kernel(ast, opts=opts)
         for opt in applied_opts: k.apply_opt(opt)
         # NOTE: replay with the captured renderer, not the one in master
         good_src = k.opts.render(name, cast(List,k.to_program().uops))
     except Exception as e:
-      logging.warning("FAILED TO RECREATE KERNEL")
+      logging.warning(f"FAILED TO RECREATE KERNEL {e}")
       logging.info(ast)
       logging.info(applied_opts)
-      logging.info(e)
       if ASSERT_DIFF: return True
       continue
+    # diff kernels
     try: assert compare_src == good_src
     except AssertionError:
       changed += 1
       logging.info("PROCESS REPLAY DETECTED CHANGE")
       logging.info(ast)
       logging.info(applied_opts)
-      diff = list(difflib.unified_diff(good_src.splitlines(), compare_src.splitlines()))
-      for line in diff:
-        logging.info(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
+      print_diff(good_src, compare_src)
       if ASSERT_DIFF: return True
   if changed > MAX_DIFF_PCT:
     logging.warning(f"detected changes in over {MAX_DIFF_PCT}% of kernels. skipping further diff generation.")
@@ -65,89 +82,57 @@ def diff_kernel(offset:int) -> bool:
   cur.close()
   return bool(changed)
 
-def print_ast_diff(offset:int):
+# *** differ runners with multiprocessing
+
+def _run_differ(row_count:int, differ:Callable[[int], bool]) -> None:
+  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool:
+    inputs = list(range(0, row_count, PAGE_SIZE))
+    changed: List[bool] = list(tqdm(pool.imap_unordered(differ, inputs), total=len(inputs)))
+    pool.close()
+    pool.join()
+    pool.terminate()
+  if any(changed) and ASSERT_DIFF: raise AssertionError("process replay detected changes")
+
+def process_replay_schedule() -> None:
   conn = db_connection()
   cur = conn.cursor()
-  cur.execute(f"SELECT val FROM 'schedule_diff_{VERSION}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
-  for row in cur.fetchall():
-    buf, asts = pickle.loads(row[0])
-    if len(asts) == 1:
-      logging.info(f"{buf} was folded")
-      logging.info(asts[0])
-    else: print_diff(asts[0], asts[1])
+  try: has_diff = cur.execute(f"select name from sqlite_master where type='table' and name='schedule_diff_{VERSION}'").fetchone()
+  except sqlite3.OperationalError:
+    logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
+    return
+  if has_diff:
+    row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
+    if row_count != 0: logging.info("***** schedule diff")
+    conn.commit()
+    cur.close()
+    _run_differ(row_count, diff_schedule)
 
-def get_step_times(data) -> Dict[str, float]:
-  tms: Dict[str, float] = {}
-  for step in data["steps"][4:]:
-    # last task
-    if step["name"] == "Run actions/upload-artifact@v4": break
-    fmt = "%Y-%m-%dT%H:%M:%SZ"
-    tm = datetime.strptime(step["completed_at"], fmt) - datetime.strptime(step["started_at"], fmt)
-    tms[step["name"]] = tm.total_seconds()
-  return tms
-
-def process_replay():
-  # *** speed diff (for benchmarks)
-  # TODO: fix this for testqualcommbenchmark
-  if REF == "update_benchmark" and os.environ["GITHUB_JOB"] != "testqualcommbenchmark":
-    name = {"testmacbenchmark": "Mac", "testnvidiabenchmark": "tinybox green", "testmorenvidiabenchmark": "tinybox green Training",
-            "testamdbenchmark": "tinybox red", "testmoreamdbenchmark": "tinybox red Training",
-            "testqualcommbenchmark": "comma"}[os.environ["GITHUB_JOB"]]
-    compare_jobs = requests.get(f"{BASE_URL}/actions/runs/{RUN_ID}/jobs", headers=GH_HEADERS).json()["jobs"]
-    compare_job = next(j for j in compare_jobs if j["name"] == f"{name} Benchmark")
-    ref_runs = requests.get(f"{BASE_URL}/actions/workflows/benchmark.yml/runs?per_page=1&branch=master&status=success", headers=GH_HEADERS).json()
-    ref_jobs = requests.get(f"{BASE_URL}/actions/runs/{ref_runs['workflow_runs'][0]['id']}/jobs").json()["jobs"]
-    ref_job = next(j for j in ref_jobs if j["name"] == f"{name} Benchmark")
-    logging.info(f"comparing speed for {compare_job['id']} against {ref_job['id']}")
-    compare_tms = get_step_times(compare_job)
-    ref_tms = get_step_times(ref_job)
-    diff = [[k, f"{v}s", f"{compare_tms[k]}s", f"{(((v-compare_tms[k])/v)*100):7.2f}%"] for k,v in ref_tms.items() if v>0]
-    logging.info(tabulate(diff, headers=["job", "master", "compare", "diff"]))
-  # *** schedule diff
-  if COMPARE_SCHEDULE:
-    conn = db_connection()
-    cur = conn.cursor()
-    try: has_diff = cur.execute(f"select name from sqlite_master where type='table' and name='schedule_diff_{VERSION}'").fetchone()
-    except sqlite3.OperationalError:
-      logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
-      exit(0)
-    if has_diff:
-      logging.info("***** schedule diff")
-      row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
-      conn.commit()
-      cur.close()
-      with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool:
-        inputs = list(range(0, row_count, PAGE_SIZE))
-        list(tqdm(pool.imap_unordered(print_ast_diff, inputs), total=len(inputs)))
-        pool.close()
-        pool.join()
-        pool.terminate()
-      if ASSERT_DIFF: raise Exception("kernel process replay detected changes")
-  # *** kernel diff
-  logging.info("***** kernel diff")
+def process_replay_kernel() -> None:
   conn = db_connection()
   cur = conn.cursor()
   try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
   except sqlite3.OperationalError:
     logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
-    exit(0)
+    return None
   conn.commit()
   cur.close()
-  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool:
-    inputs = list(range(0, row_count, PAGE_SIZE))
-    changed = list(tqdm(pool.imap_unordered(diff_kernel, inputs), total=len(inputs)))
-    pool.close()
-    pool.join()
-    pool.terminate()
-  if any(changed) and ASSERT_DIFF: raise Exception("kernel process replay detected changes")
+  _run_differ(row_count, diff_kernel)
+
+# *** main loop
 
 if __name__ == "__main__":
   if SKIP_PROCESS_REPLAY:
     logging.info("skipping process replay.")
     exit(0)
-  try: process_replay()
+  logging.info("***** schedule diff")
+  try: process_replay_schedule()
   except Exception as e:
-    # TODO: catch specific Exception
     if ASSERT_DIFF: raise e
+    logging.error(f"schedule diff err {e}")
+  logging.info("***** kernel diff")
+  try: process_replay_kernel()
+  except Exception as e:
+    if ASSERT_DIFF: raise e
+    logging.error(f"kernel diff err {e}")
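The _run_differ above fans pages of captured rows out to a spawn pool, one task per PAGE_SIZE offset, and raises only after the pool drains. A self-contained sketch of that pattern, with a stub standing in for diff_kernel/diff_schedule and a made-up row count:

import multiprocessing

PAGE_SIZE = 100

def stub_differ(offset: int) -> bool:
  # a real differ SELECTs PAGE_SIZE rows starting at this offset and returns True on any change
  return False

if __name__ == "__main__":
  row_count = 1000  # made up; the script reads this from the replay table
  inputs = list(range(0, row_count, PAGE_SIZE))  # offsets 0, 100, ..., 900
  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=16) as pool:
    changed = list(pool.imap_unordered(stub_differ, inputs))
  if any(changed): raise AssertionError("process replay detected changes")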

new file: process replay unit tests (path not shown)

@@ -0,0 +1,46 @@
+import unittest
+
+from test.external.process_replay.process_replay import TABLE_NAME, diff_kernel
+from tinygrad.codegen.kernel import Kernel
+from tinygrad.helpers import to_function_name, db_connection, diskcache_put, VERSION
+from tinygrad.ops import UOp
+from tinygrad.renderer.cstyle import ClangRenderer
+from tinygrad.tensor import Tensor
+
+def helper_append_replay(ast:UOp, name:str, src:str) -> int:
+  diskcache_put(TABLE_NAME.replace(f"_{VERSION}", ""), "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, {}))
+  conn = db_connection()
+  row_count = conn.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
+  return row_count
+
+class TestProcessReplay(unittest.TestCase):
+  def tearDown(self):
+    conn = db_connection()
+    cur = conn.cursor()
+    cur.execute(f"DELETE FROM '{TABLE_NAME}' WHERE key LIKE 'test_%'")
+    conn.commit()
+    cur.close()
+
+  def test_simple_diff(self):
+    out = Tensor([1, 2, 3])+1
+    ast = out.schedule()[-1].ast
+    test_src = """
+void test(int* restrict a, const int* restrict b) {
+  for (int ridx0 = 0; ridx0 < 3; ridx0++) {
+    int val0 = b[ridx0];
+    a[ridx0] = (val0+1);
+  }
+}
+"""
+    offset = helper_append_replay(ast, "test", test_src)
+    assert diff_kernel(offset-1)
+
+  def test_identical_run(self):
+    out = Tensor([1, 2, 3])+1
+    ast = out.schedule()[-1].ast
+    test_prg = Kernel(ast, ClangRenderer()).to_program()
+    offset = helper_append_replay(ast, test_prg.name, test_prg.src)
+    assert not diff_kernel(offset)
+
+if __name__ == "__main__":
+  unittest.main()
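The offset arithmetic in these tests leans on the differs' LIMIT/OFFSET paging: helper_append_replay returns the post-insert row count, so diff_kernel(offset-1) scans a page that begins at the just-appended row. A small sqlite sketch of that contract (the table shape is a stand-in, not diskcache's actual schema):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (key TEXT, val BLOB)")
conn.executemany("INSERT INTO t VALUES (?, ?)", [(f"k{i}", b"v") for i in range(5)])

row_count = conn.execute("SELECT count(*) FROM t").fetchone()[0]
# OFFSET row_count-1 starts the page at the last (newest) row
rows = conn.execute("SELECT val FROM t LIMIT ? OFFSET ?", (100, row_count - 1)).fetchall()
assert len(rows) == 1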

tinygrad/codegen/kernel.py

@@ -754,7 +754,7 @@ class Kernel:
     if getenv("RUN_PROCESS_REPLAY"):
       table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}"
-      diskcache_put(table_name, id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
+      diskcache_put(table_name, str(id(self)), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
     # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
     # TODO: these max and min don't work on symbolic, and results are very wrong.
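This diskcache_put call is the capture side of the pipeline: the six fields written here are exactly what diff_kernel unpickles per row, and the one functional change, id(self) → str(id(self)), keeps the cache key a string like the 'test_1' key the new tests write. A rough round-trip sketch with stand-in values (real rows hold a UOp AST, a Renderer, and Opt objects):

import pickle

# stand-ins for (self.ast, self.opts, self.applied_opts, name, src, ctx)
captured = ("ast", "opts", ["applied_opt"], "kernel_name", "void kernel_name(...) {}", {"DEBUG": 2})
row = pickle.dumps(captured)

# replay side: diff_kernel unpacks the same six fields from each row
ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row)
assert name == "kernel_name" and compare_src.startswith("void")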