diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index b0f35738..61908bd8 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -120,19 +120,21 @@ def process_replay(): if COMPARE_SCHEDULE: conn = db_connection() cur = conn.cursor() - try: row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0] + try: has_diff = cur.execute(f"select name from sqlite_master where type='table' and name='schedule_diff_{VERSION}'").fetchone() except sqlite3.OperationalError: logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?") exit(0) - conn.commit() - cur.close() - processes = [] - changed = multiprocessing.Manager().Value('b', False) - for i in tqdm(range(0, row_count, PAGE_SIZE)): - processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,))) - p.start() - for p in processes: p.join() - if row_count != 0 and ASSERT_DIFF: raise Exception("scheduler process replay detected changes") + if has_diff: + row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0] + conn.commit() + cur.close() + processes = [] + changed = multiprocessing.Manager().Value('b', False) + for i in tqdm(range(0, row_count, PAGE_SIZE)): + processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,))) + p.start() + for p in processes: p.join() + if ASSERT_DIFF: raise Exception("scheduler process replay detected changes") # *** kernel diff conn = db_connection() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index fc098371..7e6723f3 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,4 +1,4 @@ -import sys, pickle, atexit +import sys, pickle, atexit, importlib from collections import defaultdict, deque from dataclasses import dataclass, field from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args @@ -379,8 +379,9 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe if seen is None: seen = set() graph, in_degree = _graph_schedule(outs, seen) if getenv("RUN_PROCESS_REPLAY"): - from test.external.process_replay.diff_schedule import process_replay - process_replay(outs, graph, in_degree) + try: importlib.import_module("test.external.process_replay.diff_schedule").process_replay(outs, graph, in_degree) + except ImportError: print("can't access test.external.process_replay.diff_schedule, hint: process relpay needs PYTHONPATH=.") + queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0) schedule: List[ScheduleItem] = [] var_vals: Dict[Variable, int] = {}