diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4b18e2f6..b15738f3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,6 +4,7 @@ env: DOWNLOAD_CACHE_VERSION: '5' RUN_PROCESS_REPLAY: 1 GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PYTHONPATH: . on: push: diff --git a/.gitignore b/.gitignore index 766eb808..0e10e658 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,4 @@ weights comgr_* *.pkl site/ +master_schedule.py diff --git a/test/external/process_replay/diff_schedule.py b/test/external/process_replay/diff_schedule.py index 8ec05a2b..2e2ba7ea 100644 --- a/test/external/process_replay/diff_schedule.py +++ b/test/external/process_replay/diff_schedule.py @@ -1,13 +1,23 @@ # create a diff of two schedule graphs -import difflib #, ocdiff +import shutil, importlib, uuid, os from collections import defaultdict from typing import DefaultDict, List, Set, Tuple +from test.external.process_replay.utils import print_diff from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem -from tinygrad.helpers import Context, colored +from tinygrad.helpers import DEBUG, Context, colored, dedup, diskcache_put, fetch, getenv from tinygrad.lazy import LazyBuffer from tinygrad.ops import LazyOp from tinygrad.engine.realize import CompiledRunner, lower_schedule_item +def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List[LBScheduleItem]], in_degree:DefaultDict[LBScheduleItem, int]): + # copy the reference module + fp = __file__.replace("diff_schedule", "master_schedule") + if not os.path.isfile(fp): shutil.copyfile(fetch("https://raw.githubusercontent.com/tinygrad/tinygrad/master/tinygrad/engine/schedule.py"), fp) + # create the reference graph + ref_graph, ref_in_degree = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs, set()) + # compare + diff_schedule([(ref_graph, ref_in_degree), (graph, in_degree)]) + def 
diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]]) -> int: si_for_buf: DefaultDict[LazyBuffer, List[ScheduleItem]] = defaultdict(list) for _,in_degree in s: @@ -15,20 +25,29 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]] for buf in lsi.outputs: si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata)) changed = 0 - seen_diff: Set[Tuple[LazyOp, LazyOp]] = set() + seen_diffs: Set[Tuple[LazyOp, ...]] = set() for buf, si in si_for_buf.items(): - asts = [x.ast for x in si] - if len(set(asts)) == 1: continue - if (asts[0], asts[1]) in seen_diff: continue - seen_diff.add((asts[0], asts[1])) + asts = tuple(dedup([x.ast for x in si])) + # kernels didn't change + if len(si) > 1 and len(asts) == 1: continue + if asts in seen_diffs: continue + seen_diffs.add(asts) changed += 1 - #print(ocdiff.console_diff(render(ast[0]), render(ast[1]))) - ei0 = lower_schedule_item(si[0]) - ei1 = lower_schedule_item(si[1]) - assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner) - diff = list(difflib.unified_diff(ei0.prg.p.src.splitlines(), ei1.prg.p.src.splitlines())) - unified_diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in diff) - print(unified_diff) + if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), asts)) + if len(asts) == 1: + print(f"{buf} folded in the second schedule") + else: print_si_diff(si[0], si[1]) + if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed") + return changed + +def print_si_diff(si0:ScheduleItem, si1:ScheduleItem): + ei0 = lower_schedule_item(si0) + ei1 = lower_schedule_item(si1) + assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner) + print_diff(si0.ast, si1.ast) + 
print_diff(ei0.prg.p.src, ei1.prg.p.src) + # TODO: create new Buffers for process replay + if getenv("TIMING"): with Context(DEBUG=2): tm0 = ei0.run(wait=True) tm1 = ei1.run(wait=True) @@ -36,5 +55,3 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]] tm_diff = ((tm0 - tm1) / tm0) * 100 if tm_diff > 0: print(colored(f"{tm_diff:.2f}% faster", "green")) else: print(colored(f"{tm_diff:,.2f}% slower", "red")) - print(f"{changed} unique kernel{'s' if changed>1 else ''} changed") - return changed diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 01cf0100..b0f35738 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -5,9 +5,7 @@ from tabulate import tabulate from datetime import datetime from typing import Dict, List, cast from tinygrad.codegen.kernel import Kernel -from tinygrad.device import Device from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, temp, tqdm -from tinygrad.ops import LazyOp # *** process replay settings PAGE_SIZE = 100 @@ -27,7 +25,7 @@ logging.basicConfig(level=logging.INFO, format='%(message)s') BASE_URL = f"https://api.github.com/repos/{os.getenv('GITHUB_REPOSITORY', 'tinygrad/tinygrad')}" GH_HEADERS = {"Authorization": f"Bearer {os.getenv('GH_TOKEN', '')}", "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"} -def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed): +def diff_kernel(offset:int, kernel_changed): if early_stop.is_set(): return conn = db_connection() cur = conn.cursor() @@ -51,11 +49,6 @@ def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed): kernel_changed.value = True if ASSERT_DIFF: raise e continue - # try compare - if COMPARE_SCHEDULE and ast not in ref_schedule: - with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}): - 
print(opts.render(name, Kernel(ast, opts=opts).linearize().uops)) - continue try: assert compare_src == good_src except AssertionError as e: changed += 1 @@ -74,13 +67,19 @@ def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed): conn.commit() cur.close() -def get_ref_schedule(offset:int, ref_schedule): - conn = sqlite3.connect("/tmp/process_replay/process_replay.db") +def print_ast_diff(offset:int): + conn = db_connection() cur = conn.cursor() - cur.execute(f"SELECT val FROM '{REF_TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)) - for row in cur.fetchall(): ref_schedule.append(pickle.loads(row[0])[0]) - conn.commit() - cur.close() + cur.execute(f"SELECT val FROM 'schedule_diff_{VERSION}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)) + for row in cur.fetchall(): + buf, asts = pickle.loads(row[0]) + if len(asts) == 1: + print(f"{buf} was folded") + print(asts[0]) + else: + diff = list(difflib.unified_diff(str(asts[0]).splitlines(), str(asts[1]).splitlines())) + unified_diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in diff) + print(unified_diff) def download_artifact(run_id:str, name:str, dest:str): res = requests.get(f"{BASE_URL}/actions/runs/{run_id}/artifacts?name={name}", headers=GH_HEADERS) @@ -118,23 +117,26 @@ def process_replay(): logging.info(tabulate(diff, headers=["job", "master", "compare", "diff"])) # *** schedule diff - ref_schedule = multiprocessing.Manager().list() if COMPARE_SCHEDULE: - logging.info("fetching process replay reference") - # TODO: make this run_id dynamic - download_artifact("10093148840", f"process_replay_{Device.DEFAULT.lower()}.db", f"{TEMP_DIR}/schedule") - ref_conn = sqlite3.connect(f"{TEMP_DIR}/schedule/process_replay.db") - row_count = ref_conn.execute(f"select count(*) from '{REF_TABLE_NAME}'").fetchone()[0] + conn = db_connection() + cur = conn.cursor() + try: row_count = cur.execute(f"select count(*) from 
'schedule_diff_{VERSION}'").fetchone()[0] + except sqlite3.OperationalError: + logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?") + exit(0) + conn.commit() + cur.close() processes = [] + changed = multiprocessing.Manager().Value('b', False) for i in tqdm(range(0, row_count, PAGE_SIZE)): - processes.append(p:=multiprocessing.Process(target=get_ref_schedule, args=(i, ref_schedule))) + processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,))) p.start() for p in processes: p.join() - ref_conn.close() - conn = db_connection() - cur = conn.cursor() + if row_count != 0 and ASSERT_DIFF: raise Exception("scheduler process replay detected changes") # *** kernel diff + conn = db_connection() + cur = conn.cursor() try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0] except sqlite3.OperationalError: logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?") @@ -144,10 +146,10 @@ def process_replay(): processes = [] changed = multiprocessing.Manager().Value('b', False) for i in tqdm(range(0, row_count, PAGE_SIZE)): - processes.append(p:=multiprocessing.Process(target=diff_kernel, args=(i, ref_schedule, changed))) + processes.append(p:=multiprocessing.Process(target=diff_kernel, args=(i, changed))) p.start() for p in processes: p.join() - if changed.value and ASSERT_DIFF: raise Exception("process replay detected changes") + if changed.value and ASSERT_DIFF: raise Exception("kernel process replay detected changes") if __name__ == "__main__": if SKIP_PROCESS_REPLAY: diff --git a/test/external/process_replay/reset.py b/test/external/process_replay/reset.py index 0ce2724c..e07b207a 100755 --- a/test/external/process_replay/reset.py +++ b/test/external/process_replay/reset.py @@ -2,3 +2,4 @@ from tinygrad.helpers import db_connection, VERSION, getenv cur = db_connection() cur.execute(f"drop table if exists process_replay_{getenv('GITHUB_RUN_ID', 
'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}") +cur.execute(f"drop table if exists schedule_diff_{VERSION}") diff --git a/test/external/process_replay/test_diff_schedule.py b/test/external/process_replay/test_diff_schedule.py index 08f1791c..6f476f6e 100644 --- a/test/external/process_replay/test_diff_schedule.py +++ b/test/external/process_replay/test_diff_schedule.py @@ -1,36 +1,28 @@ +from typing import cast import unittest from test.external.process_replay.diff_schedule import diff_schedule from tinygrad import Tensor from tinygrad.helpers import Context -from tinygrad.engine.schedule import SCHEDULES +from tinygrad.engine.schedule import _graph_schedule +from tinygrad.lazy import LazyBuffer class TestDiffSchedule(unittest.TestCase): def test_diff_arange(self): # diff a single arange kernel X = Tensor.randn(10, 10).realize() idxs = Tensor([0, 2]).realize() - xt = X[idxs] - with Context(ARANGE_DIFF=1): xt.schedule() - self.assertEqual(len(SCHEDULES), 2) - changed = diff_schedule(SCHEDULES) - self.assertEqual(changed, 1) - SCHEDULES.clear() + xt = cast(LazyBuffer, X[idxs].lazydata) + with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt], set()) + with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt], set()) + # 1 arange LazyBuffer folds, 1 arange child's kernel changes + changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)]) + self.assertEqual(changed, 2) # no diff - a = Tensor([1])+Tensor([2]) - with Context(ARANGE_DIFF=1): a.schedule() - self.assertEqual(len(SCHEDULES), 2) - changed = diff_schedule(SCHEDULES) - self.assertEqual(changed, 0) - SCHEDULES.clear() - - # no diff with two schedule creation calls - a = Tensor([1])+Tensor([2]) - with Context(ARANGE_DIFF=1): a.schedule() - b = Tensor([3])+Tensor([4]) - with Context(ARANGE_DIFF=1): b.schedule() - self.assertEqual(len(SCHEDULES), 4) - changed = diff_schedule(SCHEDULES) + a = cast(LazyBuffer, 
(Tensor([1])+Tensor([2])).lazydata) + with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a], set()) + with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a], set()) + changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)]) self.assertEqual(changed, 0) if __name__ == '__main__': diff --git a/test/external/process_replay/test_process_replay.sh b/test/external/process_replay/test_process_replay.sh index 79a2093a..adb6a111 100755 --- a/test/external/process_replay/test_process_replay.sh +++ b/test/external/process_replay/test_process_replay.sh @@ -2,11 +2,11 @@ # should assert sed -i 's/temp/temp1/g' ./tinygrad/codegen/lowerer.py -ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null +COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null if [[ $? -eq 0 ]]; then echo "PROCESS REPLAY IS WRONG." exit 1 fi # should NOT assert git stash > /dev/null -ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null +COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null diff --git a/test/external/process_replay/utils.py b/test/external/process_replay/utils.py new file mode 100644 index 00000000..78920d2e --- /dev/null +++ b/test/external/process_replay/utils.py @@ -0,0 +1,11 @@ +import difflib, logging +from tinygrad.helpers import colored, getenv + +def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)): + if unified: + lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines())) + diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in lines) + else: + import ocdiff + diff = ocdiff.console_diff(str(s0), str(s1)) + logging.info(diff) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index a0fa0297..fc098371 
100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer -from tinygrad.helpers import ARANGE_DIFF, GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, Context, \ +from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \ GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata from tinygrad.shape.symbolic import Variable, sint from tinygrad.dtype import ConstType, ImageDType, dtypes @@ -367,9 +367,6 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \ if SAVE_SCHEDULE: def _save(): - if ARANGE_DIFF: - from test.external.process_replay.diff_schedule import diff_schedule - return diff_schedule(SCHEDULES) print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl")) with open(fp, "wb") as f: pickle.dump(SCHEDULES, f) if len(SCHEDULES) == 0: atexit.register(_save) @@ -380,10 +377,10 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: if seen is None: seen = set() - if ARANGE_DIFF: - with Context(FUSE_ARANGE=0, SAVE_SCHEDULE=1): _graph_schedule(outs, set()) - with Context(FUSE_ARANGE=1, SAVE_SCHEDULE=1): graph, in_degree = _graph_schedule(outs, seen) - else: graph, in_degree = _graph_schedule(outs, seen) + graph, in_degree = _graph_schedule(outs, seen) + if getenv("RUN_PROCESS_REPLAY"): + from test.external.process_replay.diff_schedule import process_replay + process_replay(outs, graph, in_degree) queue = deque(lsi for 
lsi,deg in in_degree.items() if deg == 0) schedule: List[ScheduleItem] = [] var_vals: Dict[Variable, int] = {}