mirror of https://github.com/commaai/tinygrad.git
scheduler process replay with [compare_schedule] (#5997)
This commit is contained in:
parent
24c7c41ce0
commit
a833f1a735
|
@ -120,19 +120,21 @@ def process_replay():
|
|||
if COMPARE_SCHEDULE:
|
||||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
try: row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
|
||||
try: has_diff = cur.execute(f"select name from sqlite_master where type='table' and name='schedule_diff_{VERSION}'").fetchone()
|
||||
except sqlite3.OperationalError:
|
||||
logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
|
||||
exit(0)
|
||||
conn.commit()
|
||||
cur.close()
|
||||
processes = []
|
||||
changed = multiprocessing.Manager().Value('b', False)
|
||||
for i in tqdm(range(0, row_count, PAGE_SIZE)):
|
||||
processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,)))
|
||||
p.start()
|
||||
for p in processes: p.join()
|
||||
if row_count != 0 and ASSERT_DIFF: raise Exception("scheduler process replay detected changes")
|
||||
if has_diff:
|
||||
row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
|
||||
conn.commit()
|
||||
cur.close()
|
||||
processes = []
|
||||
changed = multiprocessing.Manager().Value('b', False)
|
||||
for i in tqdm(range(0, row_count, PAGE_SIZE)):
|
||||
processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,)))
|
||||
p.start()
|
||||
for p in processes: p.join()
|
||||
if ASSERT_DIFF: raise Exception("scheduler process replay detected changes")
|
||||
|
||||
# *** kernel diff
|
||||
conn = db_connection()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import sys, pickle, atexit
|
||||
import sys, pickle, atexit, importlib
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
|
||||
|
@ -379,8 +379,9 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
|
|||
if seen is None: seen = set()
|
||||
graph, in_degree = _graph_schedule(outs, seen)
|
||||
if getenv("RUN_PROCESS_REPLAY"):
|
||||
from test.external.process_replay.diff_schedule import process_replay
|
||||
process_replay(outs, graph, in_degree)
|
||||
try: importlib.import_module("test.external.process_replay.diff_schedule").process_replay(outs, graph, in_degree)
|
||||
except ImportError: print("can't access test.external.process_replay.diff_schedule, hint: process relpay needs PYTHONPATH=.")
|
||||
|
||||
queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
|
||||
schedule: List[ScheduleItem] = []
|
||||
var_vals: Dict[Variable, int] = {}
|
||||
|
|
Loading…
Reference in New Issue