more process replay cleanups (#6013)

* more process replay cleanups

* comma benchmark missing
qazal 2024-08-10 22:29:10 +08:00 committed by GitHub
parent 266afad8ed
commit 0e62076cf5
2 changed files with 9 additions and 20 deletions
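
What the diff does: the unused download_artifact helper (and its io/zipfile imports) is dropped, _get_times is renamed to get_step_times, the dead REF_TABLE_NAME, TEMP_DIR, and changed variables are removed, the comma (Qualcomm) benchmark is added to the job-name map so its speed diff gets reported, and a COMPRAE_SCHEDULE typo is fixed in the test script. For reference, the per-step speed comparison the script prints is a plain percent change per step; a minimal standalone illustration with made-up step times (only tabulate and the diff formula below come from the script itself):

from tabulate import tabulate

# hypothetical step times in seconds: master (ref) vs. this branch (compare)
ref_tms = {"run onnx": 120.0, "run gpt2": 45.0}
compare_tms = {"run onnx": 110.0, "run gpt2": 46.5}

# same formula as process_replay.py: a positive diff means the compare branch is faster
diff = [[k, f"{v}s", f"{compare_tms[k]}s", f"{(((v-compare_tms[k])/v)*100):7.2f}%"] for k,v in ref_tms.items() if v>0]
print(tabulate(diff, headers=["job", "master", "compare", "diff"]))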

test/external/process_replay/process_replay.py

@@ -1,12 +1,12 @@
 #!/usr/bin/env python3
 # compare kernels created by HEAD against master
-import difflib, pickle, multiprocessing, os, logging, sqlite3, requests, io, zipfile
+import difflib, pickle, multiprocessing, os, logging, sqlite3, requests
 from tabulate import tabulate
 from datetime import datetime
 from typing import Dict, List, cast
 from test.external.process_replay.utils import print_diff
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, temp, tqdm
+from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm
 # *** process replay settings
 PAGE_SIZE = 100
@@ -14,12 +14,10 @@ REF = os.getenv("GITHUB_REF_NAME", "")
 MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
 RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
 TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
-REF_TABLE_NAME = f"process_replay_master_{VERSION}"
 ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
 if REF == "master": ASSERT_DIFF = False
 COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")))
 SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
-TEMP_DIR = temp("process_replay")
 early_stop = multiprocessing.Event()
 logging.basicConfig(level=logging.INFO, format='%(message)s')
 # *** github settings
@@ -79,16 +77,7 @@ def print_ast_diff(offset:int):
     logging.info(asts[0])
   else: print_diff(asts[0], asts[1])
 
-def download_artifact(run_id:str, name:str, dest:str):
-  res = requests.get(f"{BASE_URL}/actions/runs/{run_id}/artifacts?name={name}", headers=GH_HEADERS)
-  assert res.status_code == 200, f"download failed {res.status_code} {res.json()}"
-  download_url = res.json()["artifacts"][0]["archive_download_url"]
-  res = requests.get(download_url, headers=GH_HEADERS)
-  assert res.status_code == 200, f"download failed {res.status_code}"
-  with io.BytesIO(res.content) as zip_content:
-    with zipfile.ZipFile(zip_content, "r") as zip_ref: zip_ref.extractall(dest)
-
-def _get_times(data) -> Dict[str, float]:
+def get_step_times(data) -> Dict[str, float]:
   tms: Dict[str, float] = {}
   for step in data["steps"][4:]:
     # last task
@@ -102,15 +91,16 @@ def process_replay():
   # *** speed diff (for benchmarks)
   if REF == "update_benchmark":
     name = {"testmacbenchmark": "Mac", "testnvidiabenchmark": "tinybox green", "testmorenvidiabenchmark": "tinybox green Training",
-            "testamdbenchmark": "tinybox red", "testmoreamdbenchmark": "tinybox red Training"}[os.environ["GITHUB_JOB"]]
+            "testamdbenchmark": "tinybox red", "testmoreamdbenchmark": "tinybox red Training",
+            "testqualcommbenchmark": "comma Benchmark"}[os.environ["GITHUB_JOB"]]
     compare_jobs = requests.get(f"{BASE_URL}/actions/runs/{RUN_ID}/jobs", headers=GH_HEADERS).json()["jobs"]
     compare_job = next(j for j in compare_jobs if j["name"] == f"{name} Benchmark")
     ref_runs = requests.get(f"{BASE_URL}/actions/workflows/benchmark.yml/runs?per_page=1&branch=master&status=success", headers=GH_HEADERS).json()
     ref_jobs = requests.get(f"{BASE_URL}/actions/runs/{ref_runs['workflow_runs'][0]['id']}/jobs").json()["jobs"]
     ref_job = next(j for j in ref_jobs if j["name"] == f"{name} Benchmark")
     logging.info(f"comparing speed for {compare_job['id']} against {ref_job['id']}")
-    compare_tms = _get_times(compare_job)
-    ref_tms = _get_times(ref_job)
+    compare_tms = get_step_times(compare_job)
+    ref_tms = get_step_times(ref_job)
     diff = [[k, f"{v}s", f"{compare_tms[k]}s", f"{(((v-compare_tms[k])/v)*100):7.2f}%"] for k,v in ref_tms.items() if v>0]
     logging.info(tabulate(diff, headers=["job", "master", "compare", "diff"]))
@@ -127,7 +117,6 @@ def process_replay():
   conn.commit()
   cur.close()
   processes = []
-  changed = multiprocessing.Manager().Value('b', False)
   for i in tqdm(range(0, row_count, PAGE_SIZE)):
     processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,)))
     p.start()
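
For context, the renamed get_step_times computes per-step wall time from a GitHub Actions job object; the hunk above truncates its body. A self-contained sketch of the idea, assuming the ISO-8601 started_at/completed_at fields the GitHub API returns (an illustration, not the exact tinygrad implementation):

from datetime import datetime
from typing import Dict

def get_step_times(data) -> Dict[str, float]:
  tms: Dict[str, float] = {}
  # skip the first four steps (checkout/setup), matching data["steps"][4:] above
  for step in data["steps"][4:]:
    if step.get("started_at") is None or step.get("completed_at") is None: continue
    start = datetime.fromisoformat(step["started_at"].replace("Z", "+00:00"))
    end = datetime.fromisoformat(step["completed_at"].replace("Z", "+00:00"))
    tms[step["name"]] = (end - start).total_seconds()  # duration in seconds
  return tms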

View File

@@ -2,11 +2,11 @@
 # should assert
 sed -i 's/temp/temp1/g' ./tinygrad/codegen/lowerer.py
-COMPRAE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
+COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
 if [[ $? -eq 0 ]]; then
   echo "PROCESS REPLAY IS WRONG."
   exit 1
 fi
 
 # should NOT assert
 git stash > /dev/null
-COMPRAE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
+COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
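
This script is a self-test for the replay harness: the sed injects a rename into lowerer.py, so process replay is expected to exit non-zero; after git stash restores the tree, the same run is expected to pass. The typo fix matters because COMPRAE_SCHEDULE was silently ignored as an unknown variable. To run process replay by hand with the same switches (env var names as defined in process_replay.py; assumes the kernel cache was populated by a prior test run):

COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py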