mirror of https://github.com/commaai/tinygrad.git
make [compare_schedule] the default [run_process_replay] (#6273)
* make [compare_schedule] the default
* capture ctx
* logging
* set capture to false
This commit is contained in:
parent
067aeaeb2f
commit
d2f8eeed2e
|
@ -4,11 +4,13 @@ from collections import defaultdict
|
||||||
from typing import DefaultDict, Dict, List, Set, Tuple
|
from typing import DefaultDict, Dict, List, Set, Tuple
|
||||||
from test.external.process_replay.utils import print_diff
|
from test.external.process_replay.utils import print_diff
|
||||||
from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem
|
from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem
|
||||||
from tinygrad.helpers import CI, DEBUG, Context, colored, diskcache_put, fetch, getenv
|
from tinygrad.helpers import CI, DEBUG, Context, ContextVar, colored, diskcache_put, fetch, getenv
|
||||||
from tinygrad.lazy import LazyBuffer
|
from tinygrad.lazy import LazyBuffer
|
||||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
|
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
|
||||||
from tinygrad.ops import UOp
|
from tinygrad.ops import UOp
|
||||||
|
|
||||||
|
CAPTURING_PROCESS_REPLAY = ContextVar("CAPTURING_PROCESS_REPLAY", getenv("RUN_PROCESS_REPLAY"))
|
||||||
|
|
||||||
def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List[LBScheduleItem]], in_degree:DefaultDict[LBScheduleItem, int]):
|
def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List[LBScheduleItem]], in_degree:DefaultDict[LBScheduleItem, int]):
|
||||||
# copy the reference module
|
# copy the reference module
|
||||||
ref_schedule = getenv("REF_COMMIT_HASH", "master")
|
ref_schedule = getenv("REF_COMMIT_HASH", "master")
|
||||||
|
@ -35,7 +37,7 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
|
||||||
if (cache_key:=tuple(asts)) in seen_diffs: continue
|
if (cache_key:=tuple(asts)) in seen_diffs: continue
|
||||||
seen_diffs.add(cache_key)
|
seen_diffs.add(cache_key)
|
||||||
changed += 1
|
changed += 1
|
||||||
if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), list(asts.values())))
|
if CAPTURING_PROCESS_REPLAY: diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), list(asts.values())))
|
||||||
if not CI: print_si_diff(si[0], si[1])
|
if not CI: print_si_diff(si[0], si[1])
|
||||||
if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
|
if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
|
||||||
return changed
|
return changed
|
||||||
|
|
|
@ -15,7 +15,7 @@ MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
|
||||||
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
|
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
|
||||||
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
|
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
|
||||||
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
|
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
|
||||||
COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")))
|
COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", int((k:="[compare_schedule]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
|
||||||
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
|
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
|
||||||
if REF == "master": SKIP_PROCESS_REPLAY = True
|
if REF == "master": SKIP_PROCESS_REPLAY = True
|
||||||
early_stop = multiprocessing.Event()
|
early_stop = multiprocessing.Event()
|
||||||
|
@ -113,6 +113,7 @@ def process_replay():
|
||||||
logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
|
logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
|
||||||
exit(0)
|
exit(0)
|
||||||
if has_diff:
|
if has_diff:
|
||||||
|
logging.info("***** schedule diff")
|
||||||
row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
|
row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
|
||||||
conn.commit()
|
conn.commit()
|
||||||
cur.close()
|
cur.close()
|
||||||
|
@ -125,6 +126,7 @@ def process_replay():
|
||||||
if ASSERT_DIFF: raise Exception("kernel process replay detected changes")
|
if ASSERT_DIFF: raise Exception("kernel process replay detected changes")
|
||||||
|
|
||||||
# *** kernel diff
|
# *** kernel diff
|
||||||
|
logging.info("***** kernel diff")
|
||||||
conn = db_connection()
|
conn = db_connection()
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
|
try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
|
||||||
|
|
|
@ -1,12 +1,18 @@
|
||||||
from typing import cast
|
from typing import cast
|
||||||
import unittest
|
import unittest
|
||||||
from test.external.process_replay.diff_schedule import diff_schedule
|
from test.external.process_replay.diff_schedule import CAPTURING_PROCESS_REPLAY, diff_schedule
|
||||||
from tinygrad import Tensor, nn
|
from tinygrad import Tensor, nn
|
||||||
from tinygrad.helpers import Context
|
from tinygrad.helpers import Context
|
||||||
from tinygrad.engine.schedule import _graph_schedule
|
from tinygrad.engine.schedule import _graph_schedule
|
||||||
from tinygrad.lazy import LazyBuffer
|
from tinygrad.lazy import LazyBuffer
|
||||||
|
|
||||||
class TestDiffSchedule(unittest.TestCase):
|
class TestDiffSchedule(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.old_value = CAPTURING_PROCESS_REPLAY.value
|
||||||
|
CAPTURING_PROCESS_REPLAY.value = 0
|
||||||
|
def tearDown(self):
|
||||||
|
CAPTURING_PROCESS_REPLAY.value = self.old_value
|
||||||
|
|
||||||
def test_diff_arange(self):
|
def test_diff_arange(self):
|
||||||
# diff a single arange kernel
|
# diff a single arange kernel
|
||||||
X = Tensor.randn(10, 10).realize()
|
X = Tensor.randn(10, 10).realize()
|
||||||
|
|
Loading…
Reference in New Issue