more process replay cleanups (#6013)

* more process replay cleanups

* comma benchmark missing
This commit is contained in:
qazal 2024-08-10 22:29:10 +08:00 committed by GitHub
parent 266afad8ed
commit 0e62076cf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 20 deletions

View File

@ -1,12 +1,12 @@
#!/usr/bin/env python3
# compare kernels created by HEAD against master
import difflib, pickle, multiprocessing, os, logging, sqlite3, requests, io, zipfile
import difflib, pickle, multiprocessing, os, logging, sqlite3, requests
from tabulate import tabulate
from datetime import datetime
from typing import Dict, List, cast
from test.external.process_replay.utils import print_diff
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, temp, tqdm
from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm
# *** process replay settings
PAGE_SIZE = 100
@ -14,12 +14,10 @@ REF = os.getenv("GITHUB_REF_NAME", "")
# PROCESS_REPLAY_MAX_DIFF_PCT env var, falling back to 20 when unset
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
# GitHub Actions run id; falls back to the literal "HEAD" when GITHUB_RUN_ID is unset
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
# table name for this run: combines the run id, the run attempt and VERSION
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
# table name for the master reference, keyed by VERSION only
REF_TABLE_NAME = f"process_replay_master_{VERSION}"
# ASSERT_PROCESS_REPLAY overrides; otherwise the default comes from the "[run_process_replay]" tag in COMMIT_MESSAGE or PR_TITLE.
# NOTE: both lookups fall back to the tag itself when the variable is unset, so the default is 1 unless both variables are set
# and neither contains the tag.
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
# never assert when the ref is master, regardless of the env vars above
if REF == "master": ASSERT_DIFF = False
# COMPARE_SCHEDULE overrides; otherwise 1 only when "[compare_schedule]" appears in COMMIT_MESSAGE or PR_TITLE (unset variables count as empty)
COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")))
# True when "[skip_process_replay]" appears in COMMIT_MESSAGE or PR_TITLE (unset variables count as empty)
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
TEMP_DIR = temp("process_replay")
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format='%(message)s')
# *** github settings
@ -79,16 +77,7 @@ def print_ast_diff(offset:int):
logging.info(asts[0])
else: print_diff(asts[0], asts[1])
def download_artifact(run_id:str, name:str, dest:str):
  """Download a GitHub Actions artifact and unzip it into `dest`.

  Looks up the artifacts of workflow run `run_id` filtered by `name`, downloads the archive of the
  first match and extracts every member of that zip into the `dest` directory.

  Raises:
    AssertionError: if either HTTP request does not return 200, or if the run has no artifact called `name`.
  """
  res = requests.get(f"{BASE_URL}/actions/runs/{run_id}/artifacts?name={name}", headers=GH_HEADERS)
  assert res.status_code == 200, f"download failed {res.status_code} {res.json()}"
  artifacts = res.json()["artifacts"]
  # a successful listing can still be empty; fail with a readable message instead of a bare IndexError
  assert len(artifacts) > 0, f"download failed, no artifact named {name} in run {run_id}"
  download_url = artifacts[0]["archive_download_url"]
  res = requests.get(download_url, headers=GH_HEADERS)
  assert res.status_code == 200, f"download failed {res.status_code}"
  # the archive is unpacked straight from memory, nothing is written besides the extracted members
  with io.BytesIO(res.content) as zip_content:
    with zipfile.ZipFile(zip_content, "r") as zip_ref: zip_ref.extractall(dest)
def _get_times(data) -> Dict[str, float]:
def get_step_times(data) -> Dict[str, float]:
tms: Dict[str, float] = {}
for step in data["steps"][4:]:
# last task
@ -102,15 +91,16 @@ def process_replay():
# *** speed diff (for benchmarks)
if REF == "update_benchmark":
name = {"testmacbenchmark": "Mac", "testnvidiabenchmark": "tinybox green", "testmorenvidiabenchmark": "tinybox green Training",
"testamdbenchmark": "tinybox red", "testmoreamdbenchmark": "tinybox red Training"}[os.environ["GITHUB_JOB"]]
"testamdbenchmark": "tinybox red", "testmoreamdbenchmark": "tinybox red Training",
"testqualcommbenchmark": "comma Benchmark"}[os.environ["GITHUB_JOB"]]
compare_jobs = requests.get(f"{BASE_URL}/actions/runs/{RUN_ID}/jobs", headers=GH_HEADERS).json()["jobs"]
compare_job = next(j for j in compare_jobs if j["name"] == f"{name} Benchmark")
ref_runs = requests.get(f"{BASE_URL}/actions/workflows/benchmark.yml/runs?per_page=1&branch=master&status=success", headers=GH_HEADERS).json()
ref_jobs = requests.get(f"{BASE_URL}/actions/runs/{ref_runs['workflow_runs'][0]['id']}/jobs").json()["jobs"]
ref_job = next(j for j in ref_jobs if j["name"] == f"{name} Benchmark")
logging.info(f"comparing speed for {compare_job['id']} against {ref_job['id']}")
compare_tms = _get_times(compare_job)
ref_tms = _get_times(ref_job)
compare_tms = get_step_times(compare_job)
ref_tms = get_step_times(ref_job)
diff = [[k, f"{v}s", f"{compare_tms[k]}s", f"{(((v-compare_tms[k])/v)*100):7.2f}%"] for k,v in ref_tms.items() if v>0]
logging.info(tabulate(diff, headers=["job", "master", "compare", "diff"]))
@ -127,7 +117,6 @@ def process_replay():
conn.commit()
cur.close()
processes = []
changed = multiprocessing.Manager().Value('b', False)
for i in tqdm(range(0, row_count, PAGE_SIZE)):
processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,)))
p.start()

View File

@ -2,11 +2,11 @@
# should assert
sed -i 's/temp/temp1/g' ./tinygrad/codegen/lowerer.py
COMPRAE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
if [[ $? -eq 0 ]]; then
echo "PROCESS REPLAY IS WRONG."
exit 1
fi
# should NOT assert
git stash > /dev/null
COMPRAE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null