make [compare_schedule] the default [run_process_replay] (#6273)

* make [compare_schedule] the default

* capture ctx

* logging

* set capture to false
This commit is contained in:
qazal 2024-08-26 21:40:03 +08:00 committed by GitHub
parent 067aeaeb2f
commit d2f8eeed2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 4 deletions

View File

@ -4,11 +4,13 @@ from collections import defaultdict
from typing import DefaultDict, Dict, List, Set, Tuple
from test.external.process_replay.utils import print_diff
from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem
from tinygrad.helpers import CI, DEBUG, Context, colored, diskcache_put, fetch, getenv
from tinygrad.helpers import CI, DEBUG, Context, ContextVar, colored, diskcache_put, fetch, getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
from tinygrad.ops import UOp
CAPTURING_PROCESS_REPLAY = ContextVar("CAPTURING_PROCESS_REPLAY", getenv("RUN_PROCESS_REPLAY"))
def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List[LBScheduleItem]], in_degree:DefaultDict[LBScheduleItem, int]):
# copy the reference module
ref_schedule = getenv("REF_COMMIT_HASH", "master")
@ -35,7 +37,7 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
if (cache_key:=tuple(asts)) in seen_diffs: continue
seen_diffs.add(cache_key)
changed += 1
if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), list(asts.values())))
if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), list(asts.values())))
if not CI: print_si_diff(si[0], si[1])
if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
return changed

View File

@ -15,7 +15,7 @@ MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")))
COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
if REF == "master": SKIP_PROCESS_REPLAY = True
early_stop = multiprocessing.Event()
@ -113,6 +113,7 @@ def process_replay():
logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
exit(0)
if has_diff:
logging.info("***** schedule diff")
row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
conn.commit()
cur.close()
@ -125,6 +126,7 @@ def process_replay():
if ASSERT_DIFF: raise Exception("kernel process replay detected changes")
# *** kernel diff
logging.info("***** kernel diff")
conn = db_connection()
cur = conn.cursor()
try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]

View File

@ -1,12 +1,18 @@
from typing import cast
import unittest
from test.external.process_replay.diff_schedule import diff_schedule
from test.external.process_replay.diff_schedule import CAPTURING_PROCESS_REPLAY, diff_schedule
from tinygrad import Tensor, nn
from tinygrad.helpers import Context
from tinygrad.engine.schedule import _graph_schedule
from tinygrad.lazy import LazyBuffer
class TestDiffSchedule(unittest.TestCase):
def setUp(self):
self.old_value = CAPTURING_PROCESS_REPLAY.value
CAPTURING_PROCESS_REPLAY.value = 0
def tearDown(self):
CAPTURING_PROCESS_REPLAY.value = self.old_value
def test_diff_arange(self):
# diff a single arange kernel
X = Tensor.randn(10, 10).realize()