diff LazyBuffer schedules in process replay (#5996)

* start diff printing

* this should be 2

* add to process_replay.py

* enable schedule capture

* arange diff is process replay
This commit is contained in:
qazal 2024-08-09 19:16:43 +08:00 committed by GitHub
parent d269bc95fa
commit 24c7c41ce0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 95 additions and 73 deletions

View File

@ -4,6 +4,7 @@ env:
DOWNLOAD_CACHE_VERSION: '5'
RUN_PROCESS_REPLAY: 1
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PYTHONPATH: .
on:
push:

1
.gitignore vendored
View File

@ -55,3 +55,4 @@ weights
comgr_*
*.pkl
site/
master_schedule.py

View File

@ -1,13 +1,23 @@
# create a diff of two schedule graphs
import difflib #, ocdiff
import shutil, importlib, uuid, os
from collections import defaultdict
from typing import DefaultDict, List, Set, Tuple
from test.external.process_replay.utils import print_diff
from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem
from tinygrad.helpers import Context, colored
from tinygrad.helpers import DEBUG, Context, colored, dedup, diskcache_put, fetch, getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import LazyOp
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List[LBScheduleItem]], in_degree:DefaultDict[LBScheduleItem, int]):
  """Diff the passed schedule graph against the graph master's scheduler builds for the same outs.

  master's schedule.py is fetched and copied to master_schedule.py (the path is this file's path with
  "diff_schedule" swapped for "master_schedule") unless that file already exists. Its _graph_schedule is
  then run on outs with a fresh empty set, and the reference pair plus (graph, in_degree) go to diff_schedule.
  """
  # the reference module is only copied when it's missing, an existing copy is reused as is
  master_path = __file__.replace("diff_schedule", "master_schedule")
  if not os.path.isfile(master_path):
    fetched = fetch("https://raw.githubusercontent.com/tinygrad/tinygrad/master/tinygrad/engine/schedule.py")
    shutil.copyfile(fetched, master_path)
  # build the reference graph with master's scheduler
  master_module = importlib.import_module("test.external.process_replay.master_schedule")
  ref_graph, ref_in_degree = master_module._graph_schedule(outs, set())
  # reference goes first, the graph under test second
  pairs = [(ref_graph, ref_in_degree), (graph, in_degree)]
  diff_schedule(pairs)
def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]]) -> int:
si_for_buf: DefaultDict[LazyBuffer, List[ScheduleItem]] = defaultdict(list)
for _,in_degree in s:
@ -15,20 +25,29 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
for buf in lsi.outputs:
si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
changed = 0
seen_diff: Set[Tuple[LazyOp, LazyOp]] = set()
seen_diffs: Set[Tuple[LazyOp, ...]] = set()
for buf, si in si_for_buf.items():
asts = [x.ast for x in si]
if len(set(asts)) == 1: continue
if (asts[0], asts[1]) in seen_diff: continue
seen_diff.add((asts[0], asts[1]))
asts = tuple(dedup([x.ast for x in si]))
# kernels didn't change
if len(si) > 1 and len(asts) == 1: continue
if asts in seen_diffs: continue
seen_diffs.add(asts)
changed += 1
#print(ocdiff.console_diff(render(ast[0]), render(ast[1])))
ei0 = lower_schedule_item(si[0])
ei1 = lower_schedule_item(si[1])
assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner)
diff = list(difflib.unified_diff(ei0.prg.p.src.splitlines(), ei1.prg.p.src.splitlines()))
unified_diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in diff)
print(unified_diff)
if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), asts))
if len(asts) == 1:
print(f"{buf} folded in the second schedule")
else: print_si_diff(si[0], si[1])
if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
return changed
def print_si_diff(si0:ScheduleItem, si1:ScheduleItem):
ei0 = lower_schedule_item(si0)
ei1 = lower_schedule_item(si1)
assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner)
print_diff(si0.ast, si1.ast)
print_diff(ei0.prg.p.src, ei1.prg.p.src)
# TODO: create new Buffers for process replay
if getenv("TIMING"):
with Context(DEBUG=2):
tm0 = ei0.run(wait=True)
tm1 = ei1.run(wait=True)
@ -36,5 +55,3 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
tm_diff = ((tm0 - tm1) / tm0) * 100
if tm_diff > 0: print(colored(f"{tm_diff:.2f}% faster", "green"))
else: print(colored(f"{tm_diff:,.2f}% slower", "red"))
print(f"{changed} unique kernel{'s' if changed>1 else ''} changed")
return changed

View File

@ -5,9 +5,7 @@ from tabulate import tabulate
from datetime import datetime
from typing import Dict, List, cast
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Device
from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, temp, tqdm
from tinygrad.ops import LazyOp
# *** process replay settings
PAGE_SIZE = 100
@ -27,7 +25,7 @@ logging.basicConfig(level=logging.INFO, format='%(message)s')
BASE_URL = f"https://api.github.com/repos/{os.getenv('GITHUB_REPOSITORY', 'tinygrad/tinygrad')}"
GH_HEADERS = {"Authorization": f"Bearer {os.getenv('GH_TOKEN', '')}", "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed):
def diff_kernel(offset:int, kernel_changed):
if early_stop.is_set(): return
conn = db_connection()
cur = conn.cursor()
@ -51,11 +49,6 @@ def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed):
kernel_changed.value = True
if ASSERT_DIFF: raise e
continue
# try compare
if COMPARE_SCHEDULE and ast not in ref_schedule:
with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}):
print(opts.render(name, Kernel(ast, opts=opts).linearize().uops))
continue
try: assert compare_src == good_src
except AssertionError as e:
changed += 1
@ -74,13 +67,19 @@ def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed):
conn.commit()
cur.close()
def get_ref_schedule(offset:int, ref_schedule):
conn = sqlite3.connect("/tmp/process_replay/process_replay.db")
def print_ast_diff(offset:int):
conn = db_connection()
cur = conn.cursor()
cur.execute(f"SELECT val FROM '{REF_TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
for row in cur.fetchall(): ref_schedule.append(pickle.loads(row[0])[0])
conn.commit()
cur.close()
cur.execute(f"SELECT val FROM 'schedule_diff_{VERSION}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
for row in cur.fetchall():
buf, asts = pickle.loads(row[0])
if len(asts) == 1:
print(f"{buf} was folded")
print(asts[0])
else:
diff = list(difflib.unified_diff(str(asts[0]).splitlines(), str(asts[1]).splitlines()))
unified_diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in diff)
print(unified_diff)
def download_artifact(run_id:str, name:str, dest:str):
res = requests.get(f"{BASE_URL}/actions/runs/{run_id}/artifacts?name={name}", headers=GH_HEADERS)
@ -118,23 +117,26 @@ def process_replay():
logging.info(tabulate(diff, headers=["job", "master", "compare", "diff"]))
# *** schedule diff
ref_schedule = multiprocessing.Manager().list()
if COMPARE_SCHEDULE:
logging.info("fetching process replay reference")
# TODO: make this run_id dynamic
download_artifact("10093148840", f"process_replay_{Device.DEFAULT.lower()}.db", f"{TEMP_DIR}/schedule")
ref_conn = sqlite3.connect(f"{TEMP_DIR}/schedule/process_replay.db")
row_count = ref_conn.execute(f"select count(*) from '{REF_TABLE_NAME}'").fetchone()[0]
conn = db_connection()
cur = conn.cursor()
try: row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
except sqlite3.OperationalError:
logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
exit(0)
conn.commit()
cur.close()
processes = []
changed = multiprocessing.Manager().Value('b', False)
for i in tqdm(range(0, row_count, PAGE_SIZE)):
processes.append(p:=multiprocessing.Process(target=get_ref_schedule, args=(i, ref_schedule)))
processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,)))
p.start()
for p in processes: p.join()
ref_conn.close()
conn = db_connection()
cur = conn.cursor()
if row_count != 0 and ASSERT_DIFF: raise Exception("scheduler process replay detected changes")
# *** kernel diff
conn = db_connection()
cur = conn.cursor()
try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
except sqlite3.OperationalError:
logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
@ -144,10 +146,10 @@ def process_replay():
processes = []
changed = multiprocessing.Manager().Value('b', False)
for i in tqdm(range(0, row_count, PAGE_SIZE)):
processes.append(p:=multiprocessing.Process(target=diff_kernel, args=(i, ref_schedule, changed)))
processes.append(p:=multiprocessing.Process(target=diff_kernel, args=(i, changed)))
p.start()
for p in processes: p.join()
if changed.value and ASSERT_DIFF: raise Exception("process replay detected changes")
if changed.value and ASSERT_DIFF: raise Exception("kernel process replay detected changes")
if __name__ == "__main__":
if SKIP_PROCESS_REPLAY:

View File

@ -2,3 +2,4 @@
from tinygrad.helpers import db_connection, VERSION, getenv
# NOTE: despite the name, cur holds whatever db_connection() returns
cur = db_connection()
# drop this run's process replay table, the name is built from GITHUB_RUN_ID (default 'HEAD'), GITHUB_RUN_ATTEMPT and VERSION
cur.execute(f"drop table if exists process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}")
# drop the schedule diff table for this VERSION
cur.execute(f"drop table if exists schedule_diff_{VERSION}")

View File

@ -1,36 +1,28 @@
from typing import cast
import unittest
from test.external.process_replay.diff_schedule import diff_schedule
from tinygrad import Tensor
from tinygrad.helpers import Context
from tinygrad.engine.schedule import SCHEDULES
from tinygrad.engine.schedule import _graph_schedule
from tinygrad.lazy import LazyBuffer
class TestDiffSchedule(unittest.TestCase):
def test_diff_arange(self):
# diff a single arange kernel
X = Tensor.randn(10, 10).realize()
idxs = Tensor([0, 2]).realize()
xt = X[idxs]
with Context(ARANGE_DIFF=1): xt.schedule()
self.assertEqual(len(SCHEDULES), 2)
changed = diff_schedule(SCHEDULES)
self.assertEqual(changed, 1)
SCHEDULES.clear()
xt = cast(LazyBuffer, X[idxs].lazydata)
with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt], set())
with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt], set())
# 1 arange LazyBuffer folds, 1 arange child's kernel changes
changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
self.assertEqual(changed, 2)
# no diff
a = Tensor([1])+Tensor([2])
with Context(ARANGE_DIFF=1): a.schedule()
self.assertEqual(len(SCHEDULES), 2)
changed = diff_schedule(SCHEDULES)
self.assertEqual(changed, 0)
SCHEDULES.clear()
# no diff with two schedule creation calls
a = Tensor([1])+Tensor([2])
with Context(ARANGE_DIFF=1): a.schedule()
b = Tensor([3])+Tensor([4])
with Context(ARANGE_DIFF=1): b.schedule()
self.assertEqual(len(SCHEDULES), 4)
changed = diff_schedule(SCHEDULES)
a = cast(LazyBuffer, (Tensor([1])+Tensor([2])).lazydata)
with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a], set())
with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a], set())
changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
self.assertEqual(changed, 0)
if __name__ == '__main__':

View File

@ -2,11 +2,11 @@
# should assert
sed -i 's/temp/temp1/g' ./tinygrad/codegen/lowerer.py
ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
# the variable process_replay.py checks is COMPARE_SCHEDULE; a misspelled name is silently ignored
# NOTE(review): presumably =0 skips the schedule diff so only the kernel diff decides the exit code -- confirm
COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
if [[ $? -eq 0 ]]; then
echo "PROCESS REPLAY IS WRONG."
exit 1
fi
# should NOT assert
git stash > /dev/null
ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
COMPARE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null

11
test/external/process_replay/utils.py vendored Normal file
View File

@ -0,0 +1,11 @@
import difflib, logging
from tinygrad.helpers import colored, getenv
def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
  """Log a diff of str(s0) against str(s1) with logging.info.

  unified: truthy -> difflib unified diff, lines starting with "-" are colored red and lines starting with "+" green
           falsy  -> ocdiff.console_diff output, ocdiff is only imported on this path
  The default comes from getenv("UNIFIED_DIFF",1), evaluated once when the function is defined.
  """
  if not unified:
    import ocdiff
    logging.info(ocdiff.console_diff(str(s0), str(s1)))
    return
  def paint(line):
    # removals red, additions green, everything else is passed to colored with no color
    if line.startswith("-"): return colored(line, "red")
    if line.startswith("+"): return colored(line, "green")
    return colored(line, None)
  delta = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines()))
  logging.info("\n".join(paint(line) for line in delta))

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass, field
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import ARANGE_DIFF, GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, Context, \
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.dtype import ConstType, ImageDType, dtypes
@ -367,9 +367,6 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
if SAVE_SCHEDULE:
def _save():
if ARANGE_DIFF:
from test.external.process_replay.diff_schedule import diff_schedule
return diff_schedule(SCHEDULES)
print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
if len(SCHEDULES) == 0: atexit.register(_save)
@ -380,10 +377,10 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
if seen is None: seen = set()
if ARANGE_DIFF:
with Context(FUSE_ARANGE=0, SAVE_SCHEDULE=1): _graph_schedule(outs, set())
with Context(FUSE_ARANGE=1, SAVE_SCHEDULE=1): graph, in_degree = _graph_schedule(outs, seen)
else: graph, in_degree = _graph_schedule(outs, seen)
graph, in_degree = _graph_schedule(outs, seen)
if getenv("RUN_PROCESS_REPLAY"):
from test.external.process_replay.diff_schedule import process_replay
process_replay(outs, graph, in_degree)
queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
schedule: List[ScheduleItem] = []
var_vals: Dict[Variable, int] = {}