From d2f8eeed2e4125b0b748eed3ba451eb985230915 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 26 Aug 2024 21:40:03 +0800
Subject: [PATCH] make [compare_schedule] the default [run_process_replay] (#6273)

* make [compare_schedule] the default

* capture ctx

* logging

* set capture to false
---
 test/external/process_replay/diff_schedule.py      | 6 ++++--
 test/external/process_replay/process_replay.py     | 4 +++-
 test/external/process_replay/test_diff_schedule.py | 8 +++++++-
 3 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/test/external/process_replay/diff_schedule.py b/test/external/process_replay/diff_schedule.py
index f7cf3b78..7a511c48 100644
--- a/test/external/process_replay/diff_schedule.py
+++ b/test/external/process_replay/diff_schedule.py
@@ -4,11 +4,13 @@ from collections import defaultdict
 from typing import DefaultDict, Dict, List, Set, Tuple
 from test.external.process_replay.utils import print_diff
 from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem
-from tinygrad.helpers import CI, DEBUG, Context, colored, diskcache_put, fetch, getenv
+from tinygrad.helpers import CI, DEBUG, Context, ContextVar, colored, diskcache_put, fetch, getenv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
 from tinygrad.ops import UOp
 
+CAPTURING_PROCESS_REPLAY = ContextVar("CAPTURING_PROCESS_REPLAY", getenv("RUN_PROCESS_REPLAY"))
+
 def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List[LBScheduleItem]], in_degree:DefaultDict[LBScheduleItem, int]):
   # copy the reference module
   ref_schedule = getenv("REF_COMMIT_HASH", "master")
@@ -35,7 +37,7 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
     if (cache_key:=tuple(asts)) in seen_diffs: continue
     seen_diffs.add(cache_key)
     changed += 1
-    if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), list(asts.values())))
+    if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), list(asts.values())))
     if not CI: print_si_diff(si[0], si[1])
   if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
   return changed
diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py
index 2603f86b..f43a9c98 100755
--- a/test/external/process_replay/process_replay.py
+++ b/test/external/process_replay/process_replay.py
@@ -15,7 +15,7 @@ MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
 RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
 TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
 ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
-COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")))
+COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
 SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
 if REF == "master": SKIP_PROCESS_REPLAY = True
 early_stop = multiprocessing.Event()
@@ -113,6 +113,7 @@ def process_replay():
       logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
       exit(0)
     if has_diff:
+      logging.info("***** schedule diff")
       row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
       conn.commit()
       cur.close()
@@ -125,6 +126,7 @@ def process_replay():
         if ASSERT_DIFF: raise Exception("kernel process replay detected changes")
 
   # *** kernel diff
+  logging.info("***** kernel diff")
   conn = db_connection()
   cur = conn.cursor()
   try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
diff --git a/test/external/process_replay/test_diff_schedule.py b/test/external/process_replay/test_diff_schedule.py
index c6332ec1..3df2e2a9 100644
--- a/test/external/process_replay/test_diff_schedule.py
+++ b/test/external/process_replay/test_diff_schedule.py
@@ -1,12 +1,18 @@
 from typing import cast
 import unittest
-from test.external.process_replay.diff_schedule import diff_schedule
+from test.external.process_replay.diff_schedule import CAPTURING_PROCESS_REPLAY, diff_schedule
 from tinygrad import Tensor, nn
 from tinygrad.helpers import Context
 from tinygrad.engine.schedule import _graph_schedule
 from tinygrad.lazy import LazyBuffer
 
 class TestDiffSchedule(unittest.TestCase):
+  def setUp(self):
+    self.old_value = CAPTURING_PROCESS_REPLAY.value
+    CAPTURING_PROCESS_REPLAY.value = 0
+  def tearDown(self):
+    CAPTURING_PROCESS_REPLAY.value = self.old_value
+
   def test_diff_arange(self):
     # diff a single arange kernel
     X = Tensor.randn(10, 10).realize()