more scheduler process replay tooling (#5706)

* more scheduler process replay tooling

* refactor to compare_schedule
qazal 2024-07-25 20:47:18 +08:00 committed by GitHub
parent 4e070a2c89
commit 489cda827a
1 changed file with 33 additions and 8 deletions


@@ -1,8 +1,9 @@
#!/usr/bin/env python3
# compare kernels created by HEAD against master
import difflib, pickle, multiprocessing, os, logging, sqlite3
import difflib, pickle, multiprocessing, os, logging, sqlite3, requests, io, zipfile
from typing import List
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Device
from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm
from tinygrad.ops import LazyOp
@@ -12,7 +13,9 @@ MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
TABLE_NAME = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{VERSION}"
REF_TABLE_NAME = f"process_replay_master_{VERSION}"
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "") or REF == "master"
if REF == "master": ASSERT_DIFF = False
COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format='%(message)s')
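
A note on the new COMPARE_SCHEDULE flag above: it resolves the same way as ASSERT_DIFF. When COMMIT_MESSAGE or PR_TITLE is unset (a local run), os.getenv falls back to the keyword itself, so the membership check is trivially true and the flag defaults to on; in CI it is on only when "[compare_schedule]" appears in the commit message or PR title. A minimal standalone sketch of that resolution, using a simplified stand-in for tinygrad.helpers.getenv:

import os

def getenv(key, default=0): return type(default)(os.getenv(key, default))  # simplified stand-in for tinygrad.helpers.getenv

# unset COMMIT_MESSAGE/PR_TITLE -> the default is the keyword itself, so the membership check is True -> 1
# in CI both are set, so the flag is 1 only when "[compare_schedule]" appears in either string
COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))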
@@ -39,7 +42,7 @@ def process_replay(offset:int, ref_schedule:List[LazyOp]):
      if ASSERT_DIFF: raise e
      continue
    # try compare
    if getenv("COMPARE_SCHEDULE") and ast not in ref_schedule:
    if COMPARE_SCHEDULE and ast not in ref_schedule:
      with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}):
        print(opts.render(name, Kernel(ast, opts=opts).linearize().uops))
      continue
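
The new branch above only fires when an AST produced by HEAD is missing from the reference schedule; it then renders that kernel under the context values captured when the kernel was created, keeping only keys that are registered ContextVars and dropping DEBUG so the replay log stays readable. A small sketch of that filter pattern in isolation; saved_ctx here is a hypothetical captured snapshot:

from tinygrad.helpers import Context, ContextVar

saved_ctx = {"BEAM": 2, "DEBUG": 4, "SOME_UNKNOWN_KEY": 1}  # hypothetical captured context
# keep only keys that are registered ContextVars and drop DEBUG so rendering stays quiet
with Context(**{k:v for k,v in saved_ctx.items() if k in ContextVar._cache and k != "DEBUG"}):
  pass  # linearize/render the kernel here, as process_replay does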
@@ -61,25 +64,47 @@ def process_replay(offset:int, ref_schedule:List[LazyOp]):
  cur.close()

def get_ref_schedule(offset:int, ref_schedule):
  conn = db_connection()
  conn = sqlite3.connect("/tmp/process_replay/process_replay.db")
  cur = conn.cursor()
  cur.execute(f"SELECT val FROM '{REF_TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
  for row in cur.fetchall(): ref_schedule.append(pickle.loads(row[0])[0])
  conn.commit()
  cur.close()
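
get_ref_schedule now reads straight from the extracted artifact database rather than the local replay database; each call pulls one page of pickled rows and appends only the first pickled element (the schedule AST) to the shared list. The same LIMIT/OFFSET paging can be sketched standalone (read_page is a made-up helper name; the db path and table layout come from the script above):

import pickle, sqlite3

PAGE_SIZE = 100
def read_page(db_path:str, table:str, offset:int) -> list:
  conn = sqlite3.connect(db_path)
  rows = conn.execute(f"SELECT val FROM '{table}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)).fetchall()
  conn.close()
  # each row stores a pickled tuple whose first element is the schedule AST
  return [pickle.loads(r[0])[0] for r in rows]

# e.g. read_page("/tmp/process_replay/process_replay.db", REF_TABLE_NAME, 0)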

if __name__ == "__main__":
  if SKIP_PROCESS_REPLAY:
    logging.info("skipping process replay.")
    exit(0)
  conn = db_connection()
  cur = conn.cursor()
  ref_schedule = multiprocessing.Manager().list()
  if getenv("COMPARE_SCHEDULE"):
    row_count = cur.execute(f"select count(*) from '{REF_TABLE_NAME}'").fetchone()[0]
  # *** download the reference schedule
  if COMPARE_SCHEDULE:
    logging.info("fetching process replay reference")
    # TODO: make this run_id dynamic
    run_id = "10093148840"
    name = f"process_replay_{Device.DEFAULT.lower()}.db" # TODO: onnx and openpilot is matrix.task
    url = f"https://api.github.com/repos/{os.getenv('GITHUB_REPOSITORY', 'tinygrad/tinygrad')}/actions/runs/{run_id}/artifacts?name={name}"
    headers = {
      "Authorization": f"Bearer {os.getenv('GITHUB_TOKEN')}",
      "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"
    }
    res = requests.get(url, headers=headers)
    assert res.status_code == 200, f"download failed {res.status_code} {res.json()}"
    download_url = res.json()["artifacts"][0]["archive_download_url"]
    res = requests.get(download_url, headers=headers)
    assert res.status_code == 200, f"download failed {res.status_code}"
    with io.BytesIO(res.content) as zip_content:
      with zipfile.ZipFile(zip_content, "r") as zip_ref: zip_ref.extractall("/tmp/process_replay/")
    ref_conn = sqlite3.connect("/tmp/process_replay/process_replay.db")
    row_count = ref_conn.execute(f"select count(*) from '{REF_TABLE_NAME}'").fetchone()[0]
    processes = []
    for i in tqdm(range(0, row_count, PAGE_SIZE)):
      processes.append(p:=multiprocessing.Process(target=get_ref_schedule, args=(i, ref_schedule)))
      p.start()
    for p in processes: p.join()
    ref_conn.close()
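
The download block is the standard two-step GitHub Actions artifact fetch: the run's artifacts listing (filtered by name) yields an archive_download_url, and that URL serves the artifact as a zip. A self-contained sketch of the same flow, with the repo, run id, artifact name, and output directory as placeholders:

import io, os, zipfile, requests

def fetch_artifact(repo:str, run_id:str, name:str, out_dir:str):
  headers = {"Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}", "Accept": "application/vnd.github+json",
             "X-GitHub-Api-Version": "2022-11-28"}
  # step 1: list the run's artifacts, filtered by name
  listing = requests.get(f"https://api.github.com/repos/{repo}/actions/runs/{run_id}/artifacts", params={"name": name}, headers=headers)
  listing.raise_for_status()
  artifact = listing.json()["artifacts"][0]
  # step 2: archive_download_url returns the artifact as a zip archive
  archive = requests.get(artifact["archive_download_url"], headers=headers)
  archive.raise_for_status()
  zipfile.ZipFile(io.BytesIO(archive.content)).extractall(out_dir)

# e.g. fetch_artifact("tinygrad/tinygrad", "10093148840", "process_replay_cpu.db", "/tmp/process_replay/")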
  conn = db_connection()
  cur = conn.cursor()
  # *** run the comparison
  try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
  except sqlite3.OperationalError:
    logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
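
The try/except at the end exists because TABLE_NAME embeds VERSION: if DB_VERSION changed between the branches, the table simply does not exist on the other side and the count query raises sqlite3.OperationalError rather than returning zero rows. A tiny standalone demonstration of that failure mode:

import sqlite3

conn = sqlite3.connect(":memory:")
try:
  conn.execute("SELECT count(*) FROM 'process_replay_missing_table'").fetchone()
except sqlite3.OperationalError as e:
  print(f"table isn't accessible: {e}")  # no such table: process_replay_missing_table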