retire replay_schedule (#5563)

This commit is contained in:
qazal 2024-07-19 04:07:02 +08:00 committed by GitHub
parent 50aba32ea8
commit e7a057c20f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 34 deletions

View File

@ -1,32 +0,0 @@
#!/usr/bin/env python3
import subprocess, pickle, shlex, sys, os
from typing import Dict, List, Tuple
from tinygrad.helpers import colored
from tinygrad.ops import LazyOp
def _run(name:str, cmd:List[str], env:Dict[str, str]) -> List[Tuple[LazyOp, ...]]:
commit = subprocess.check_output(["git", "rev-parse", name], encoding="utf-8").strip()
subprocess.run(["git", "checkout", commit], check=True)
subprocess.run(cmd, env={**env, "SAVE_SCHEDULE_PATH": f"{commit}.pkl"})
return pickle.load(open(f"./{commit}.pkl", "rb"))
def _get_cmd():
parts, env = shlex.split(sys.argv[1]), {**os.environ, "SAVE_SCHEDULE": "1", "CAPTURE_AST": "1"}
env.update({k: v for p in parts if "=" in p for k, v in [p.split("=")]})
return [p for p in parts if "=" not in p], env
if __name__ == "__main__":
cmd, env = _get_cmd()
feat = _run("HEAD", cmd, env)
master = _run("master", cmd, env)
assert len(master) == len(feat)
for m, f in zip(master, feat):
try: assert m == f
except AssertionError as e:
print(colored("FAILED FOR AST: ", "red"))
print("expected:")
for op in m: print(op)
print("got:")
for op in f: print(op)
raise e

View File

@ -336,10 +336,10 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
if len(SCHEDULES) == 0: atexit.register(_save)
SCHEDULES.extend((ps[1] for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
SCHEDULES.append((graph, prescheduled))
if SAVE_SCHEDULE.value == len(SCHEDULES): exit(0)
# confirm everything was scheduled correctly
if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
if any(degree != 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
return schedule, var_vals