scheduler process replay with [compare_schedule] (#5997)

This commit is contained in:
qazal 2024-08-09 21:58:22 +08:00 committed by GitHub
parent 24c7c41ce0
commit a833f1a735
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 13 deletions

View File

@ -120,19 +120,21 @@ def process_replay():
if COMPARE_SCHEDULE:
conn = db_connection()
cur = conn.cursor()
try: row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
try: has_diff = cur.execute(f"select name from sqlite_master where type='table' and name='schedule_diff_{VERSION}'").fetchone()
except sqlite3.OperationalError:
logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
exit(0)
conn.commit()
cur.close()
processes = []
changed = multiprocessing.Manager().Value('b', False)
for i in tqdm(range(0, row_count, PAGE_SIZE)):
processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,)))
p.start()
for p in processes: p.join()
if row_count != 0 and ASSERT_DIFF: raise Exception("scheduler process replay detected changes")
if has_diff:
row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
conn.commit()
cur.close()
processes = []
changed = multiprocessing.Manager().Value('b', False)
for i in tqdm(range(0, row_count, PAGE_SIZE)):
processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,)))
p.start()
for p in processes: p.join()
if ASSERT_DIFF: raise Exception("scheduler process replay detected changes")
# *** kernel diff
conn = db_connection()

View File

@ -1,4 +1,4 @@
import sys, pickle, atexit
import sys, pickle, atexit, importlib
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
@ -379,8 +379,9 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
if seen is None: seen = set()
graph, in_degree = _graph_schedule(outs, seen)
if getenv("RUN_PROCESS_REPLAY"):
from test.external.process_replay.diff_schedule import process_replay
process_replay(outs, graph, in_degree)
try: importlib.import_module("test.external.process_replay.diff_schedule").process_replay(outs, graph, in_degree)
except ImportError: print("can't access test.external.process_replay.diff_schedule, hint: process relpay needs PYTHONPATH=.")
queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
schedule: List[ScheduleItem] = []
var_vals: Dict[Variable, int] = {}