diff LazyBuffer schedules in process replay (#5996)

* start diff printing

* this should be 2

* add to process_replay.py

* enable schedule capture

* arange diff is process replay
This commit is contained in:
qazal 2024-08-09 19:16:43 +08:00 committed by GitHub
parent d269bc95fa
commit 24c7c41ce0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 95 additions and 73 deletions

View File

@ -4,6 +4,7 @@ env:
DOWNLOAD_CACHE_VERSION: '5' DOWNLOAD_CACHE_VERSION: '5'
RUN_PROCESS_REPLAY: 1 RUN_PROCESS_REPLAY: 1
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PYTHONPATH: .
on: on:
push: push:

1
.gitignore vendored
View File

@ -55,3 +55,4 @@ weights
comgr_* comgr_*
*.pkl *.pkl
site/ site/
master_schedule.py

View File

@ -1,13 +1,23 @@
# create a diff of two schedule graphs # create a diff of two schedule graphs
import difflib #, ocdiff import shutil, importlib, uuid, os
from collections import defaultdict from collections import defaultdict
from typing import DefaultDict, List, Set, Tuple from typing import DefaultDict, List, Set, Tuple
from test.external.process_replay.utils import print_diff
from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem
from tinygrad.helpers import Context, colored from tinygrad.helpers import DEBUG, Context, colored, dedup, diskcache_put, fetch, getenv
from tinygrad.lazy import LazyBuffer from tinygrad.lazy import LazyBuffer
from tinygrad.ops import LazyOp from tinygrad.ops import LazyOp
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List[LBScheduleItem]], in_degree:DefaultDict[LBScheduleItem, int]):
  """Diff the schedule graph of this checkout against one built by master's scheduler.

  A copy of master's `tinygrad/engine/schedule.py` is placed next to this file as
  `master_schedule.py` (only when that file does not exist yet), imported, and its
  `_graph_schedule` is run on the same `outs` with a fresh `seen` set. The resulting
  reference graph and the caller-supplied `graph`/`in_degree` are handed to
  `diff_schedule`, reference first. The value returned by `diff_schedule` is discarded.
  """
  # the reference module lives beside this file, with "diff_schedule" swapped for "master_schedule" in the path
  master_path = __file__.replace("diff_schedule", "master_schedule")
  if not os.path.isfile(master_path):
    # NOTE(review): fetch() is assumed to return a local path to the downloaded file - confirm in tinygrad.helpers
    master_src = fetch("https://raw.githubusercontent.com/tinygrad/tinygrad/master/tinygrad/engine/schedule.py")
    shutil.copyfile(master_src, master_path)
  # build the reference graph with master's scheduler
  master_module = importlib.import_module("test.external.process_replay.master_schedule")
  ref_graph, ref_in_degree = master_module._graph_schedule(outs, set())
  # compare: reference schedule first, current schedule second
  diff_schedule([(ref_graph, ref_in_degree), (graph, in_degree)])
def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]]) -> int: def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]]) -> int:
si_for_buf: DefaultDict[LazyBuffer, List[ScheduleItem]] = defaultdict(list) si_for_buf: DefaultDict[LazyBuffer, List[ScheduleItem]] = defaultdict(list)
for _,in_degree in s: for _,in_degree in s:
@ -15,20 +25,29 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
for buf in lsi.outputs: for buf in lsi.outputs:
si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata)) si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
changed = 0 changed = 0
seen_diff: Set[Tuple[LazyOp, LazyOp]] = set() seen_diffs: Set[Tuple[LazyOp, ...]] = set()
for buf, si in si_for_buf.items(): for buf, si in si_for_buf.items():
asts = [x.ast for x in si] asts = tuple(dedup([x.ast for x in si]))
if len(set(asts)) == 1: continue # kernels didn't change
if (asts[0], asts[1]) in seen_diff: continue if len(si) > 1 and len(asts) == 1: continue
seen_diff.add((asts[0], asts[1])) if asts in seen_diffs: continue
seen_diffs.add(asts)
changed += 1 changed += 1
#print(ocdiff.console_diff(render(ast[0]), render(ast[1]))) if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), asts))
ei0 = lower_schedule_item(si[0]) if len(asts) == 1:
ei1 = lower_schedule_item(si[1]) print(f"{buf} folded in the second schedule")
assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner) else: print_si_diff(si[0], si[1])
diff = list(difflib.unified_diff(ei0.prg.p.src.splitlines(), ei1.prg.p.src.splitlines())) if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
unified_diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in diff) return changed
print(unified_diff)
def print_si_diff(si0:ScheduleItem, si1:ScheduleItem):
ei0 = lower_schedule_item(si0)
ei1 = lower_schedule_item(si1)
assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner)
print_diff(si0.ast, si1.ast)
print_diff(ei0.prg.p.src, ei1.prg.p.src)
# TODO: create new Buffers for process replay
if getenv("TIMING"):
with Context(DEBUG=2): with Context(DEBUG=2):
tm0 = ei0.run(wait=True) tm0 = ei0.run(wait=True)
tm1 = ei1.run(wait=True) tm1 = ei1.run(wait=True)
@ -36,5 +55,3 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
tm_diff = ((tm0 - tm1) / tm0) * 100 tm_diff = ((tm0 - tm1) / tm0) * 100
if tm_diff > 0: print(colored(f"{tm_diff:.2f}% faster", "green")) if tm_diff > 0: print(colored(f"{tm_diff:.2f}% faster", "green"))
else: print(colored(f"{tm_diff:,.2f}% slower", "red")) else: print(colored(f"{tm_diff:,.2f}% slower", "red"))
print(f"{changed} unique kernel{'s' if changed>1 else ''} changed")
return changed

View File

@ -5,9 +5,7 @@ from tabulate import tabulate
from datetime import datetime from datetime import datetime
from typing import Dict, List, cast from typing import Dict, List, cast
from tinygrad.codegen.kernel import Kernel from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Device
from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, temp, tqdm from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, temp, tqdm
from tinygrad.ops import LazyOp
# *** process replay settings # *** process replay settings
PAGE_SIZE = 100 PAGE_SIZE = 100
@ -27,7 +25,7 @@ logging.basicConfig(level=logging.INFO, format='%(message)s')
BASE_URL = f"https://api.github.com/repos/{os.getenv('GITHUB_REPOSITORY', 'tinygrad/tinygrad')}" BASE_URL = f"https://api.github.com/repos/{os.getenv('GITHUB_REPOSITORY', 'tinygrad/tinygrad')}"
GH_HEADERS = {"Authorization": f"Bearer {os.getenv('GH_TOKEN', '')}", "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"} GH_HEADERS = {"Authorization": f"Bearer {os.getenv('GH_TOKEN', '')}", "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed): def diff_kernel(offset:int, kernel_changed):
if early_stop.is_set(): return if early_stop.is_set(): return
conn = db_connection() conn = db_connection()
cur = conn.cursor() cur = conn.cursor()
@ -51,11 +49,6 @@ def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed):
kernel_changed.value = True kernel_changed.value = True
if ASSERT_DIFF: raise e if ASSERT_DIFF: raise e
continue continue
# try compare
if COMPARE_SCHEDULE and ast not in ref_schedule:
with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}):
print(opts.render(name, Kernel(ast, opts=opts).linearize().uops))
continue
try: assert compare_src == good_src try: assert compare_src == good_src
except AssertionError as e: except AssertionError as e:
changed += 1 changed += 1
@ -74,13 +67,19 @@ def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed):
conn.commit() conn.commit()
cur.close() cur.close()
def get_ref_schedule(offset:int, ref_schedule): def print_ast_diff(offset:int):
conn = sqlite3.connect("/tmp/process_replay/process_replay.db") conn = db_connection()
cur = conn.cursor() cur = conn.cursor()
cur.execute(f"SELECT val FROM '{REF_TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)) cur.execute(f"SELECT val FROM 'schedule_diff_{VERSION}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
for row in cur.fetchall(): ref_schedule.append(pickle.loads(row[0])[0]) for row in cur.fetchall():
conn.commit() buf, asts = pickle.loads(row[0])
cur.close() if len(asts) == 1:
print(f"{buf} was folded")
print(asts[0])
else:
diff = list(difflib.unified_diff(str(asts[0]).splitlines(), str(asts[1]).splitlines()))
unified_diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in diff)
print(unified_diff)
def download_artifact(run_id:str, name:str, dest:str): def download_artifact(run_id:str, name:str, dest:str):
res = requests.get(f"{BASE_URL}/actions/runs/{run_id}/artifacts?name={name}", headers=GH_HEADERS) res = requests.get(f"{BASE_URL}/actions/runs/{run_id}/artifacts?name={name}", headers=GH_HEADERS)
@ -118,23 +117,26 @@ def process_replay():
logging.info(tabulate(diff, headers=["job", "master", "compare", "diff"])) logging.info(tabulate(diff, headers=["job", "master", "compare", "diff"]))
# *** schedule diff # *** schedule diff
ref_schedule = multiprocessing.Manager().list()
if COMPARE_SCHEDULE: if COMPARE_SCHEDULE:
logging.info("fetching process replay reference") conn = db_connection()
# TODO: make this run_id dynamic cur = conn.cursor()
download_artifact("10093148840", f"process_replay_{Device.DEFAULT.lower()}.db", f"{TEMP_DIR}/schedule") try: row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
ref_conn = sqlite3.connect(f"{TEMP_DIR}/schedule/process_replay.db") except sqlite3.OperationalError:
row_count = ref_conn.execute(f"select count(*) from '{REF_TABLE_NAME}'").fetchone()[0] logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
exit(0)
conn.commit()
cur.close()
processes = [] processes = []
changed = multiprocessing.Manager().Value('b', False)
for i in tqdm(range(0, row_count, PAGE_SIZE)): for i in tqdm(range(0, row_count, PAGE_SIZE)):
processes.append(p:=multiprocessing.Process(target=get_ref_schedule, args=(i, ref_schedule))) processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,)))
p.start() p.start()
for p in processes: p.join() for p in processes: p.join()
ref_conn.close() if row_count != 0 and ASSERT_DIFF: raise Exception("scheduler process replay detected changes")
conn = db_connection()
cur = conn.cursor()
# *** kernel diff # *** kernel diff
conn = db_connection()
cur = conn.cursor()
try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0] try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
except sqlite3.OperationalError: except sqlite3.OperationalError:
logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?") logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
@ -144,10 +146,10 @@ def process_replay():
processes = [] processes = []
changed = multiprocessing.Manager().Value('b', False) changed = multiprocessing.Manager().Value('b', False)
for i in tqdm(range(0, row_count, PAGE_SIZE)): for i in tqdm(range(0, row_count, PAGE_SIZE)):
processes.append(p:=multiprocessing.Process(target=diff_kernel, args=(i, ref_schedule, changed))) processes.append(p:=multiprocessing.Process(target=diff_kernel, args=(i, changed)))
p.start() p.start()
for p in processes: p.join() for p in processes: p.join()
if changed.value and ASSERT_DIFF: raise Exception("process replay detected changes") if changed.value and ASSERT_DIFF: raise Exception("kernel process replay detected changes")
if __name__ == "__main__": if __name__ == "__main__":
if SKIP_PROCESS_REPLAY: if SKIP_PROCESS_REPLAY:

View File

@ -2,3 +2,4 @@
from tinygrad.helpers import db_connection, VERSION, getenv from tinygrad.helpers import db_connection, VERSION, getenv
cur = db_connection() cur = db_connection()
cur.execute(f"drop table if exists process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}") cur.execute(f"drop table if exists process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}")
cur.execute(f"drop table if exists schedule_diff_{VERSION}")

View File

@ -1,36 +1,28 @@
from typing import cast
import unittest import unittest
from test.external.process_replay.diff_schedule import diff_schedule from test.external.process_replay.diff_schedule import diff_schedule
from tinygrad import Tensor from tinygrad import Tensor
from tinygrad.helpers import Context from tinygrad.helpers import Context
from tinygrad.engine.schedule import SCHEDULES from tinygrad.engine.schedule import _graph_schedule
from tinygrad.lazy import LazyBuffer
class TestDiffSchedule(unittest.TestCase): class TestDiffSchedule(unittest.TestCase):
def test_diff_arange(self): def test_diff_arange(self):
# diff a single arange kernel # diff a single arange kernel
X = Tensor.randn(10, 10).realize() X = Tensor.randn(10, 10).realize()
idxs = Tensor([0, 2]).realize() idxs = Tensor([0, 2]).realize()
xt = X[idxs] xt = cast(LazyBuffer, X[idxs].lazydata)
with Context(ARANGE_DIFF=1): xt.schedule() with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt], set())
self.assertEqual(len(SCHEDULES), 2) with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt], set())
changed = diff_schedule(SCHEDULES) # 1 arange LazyBuffer folds, 1 arange child's kernel changes
self.assertEqual(changed, 1) changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
SCHEDULES.clear() self.assertEqual(changed, 2)
# no diff # no diff
a = Tensor([1])+Tensor([2]) a = cast(LazyBuffer, (Tensor([1])+Tensor([2])).lazydata)
with Context(ARANGE_DIFF=1): a.schedule() with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a], set())
self.assertEqual(len(SCHEDULES), 2) with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a], set())
changed = diff_schedule(SCHEDULES) changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
self.assertEqual(changed, 0)
SCHEDULES.clear()
# no diff with two schedule creation calls
a = Tensor([1])+Tensor([2])
with Context(ARANGE_DIFF=1): a.schedule()
b = Tensor([3])+Tensor([4])
with Context(ARANGE_DIFF=1): b.schedule()
self.assertEqual(len(SCHEDULES), 4)
changed = diff_schedule(SCHEDULES)
self.assertEqual(changed, 0) self.assertEqual(changed, 0)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -2,11 +2,11 @@
# should assert # should assert
sed -i 's/temp/temp1/g' ./tinygrad/codegen/lowerer.py sed -i 's/temp/temp1/g' ./tinygrad/codegen/lowerer.py
ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null COMPRAE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
if [[ $? -eq 0 ]]; then if [[ $? -eq 0 ]]; then
echo "PROCESS REPLAY IS WRONG." echo "PROCESS REPLAY IS WRONG."
exit 1 exit 1
fi fi
# should NOT assert # should NOT assert
git stash > /dev/null git stash > /dev/null
ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null COMPRAE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null

11
test/external/process_replay/utils.py vendored Normal file
View File

@ -0,0 +1,11 @@
import difflib, logging
from tinygrad.helpers import colored, getenv
def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
  """Log the difference between the string forms of `s0` and `s1` at INFO level.

  When `unified` is truthy, a `difflib.unified_diff` of the two texts is logged with
  lines starting with "-" coloured red and lines starting with "+" coloured green.
  Otherwise `ocdiff.console_diff` renders the diff; `ocdiff` is imported only on that path.
  The default for `unified` is read from the `UNIFIED_DIFF` setting once, when the
  function is defined.
  """
  if not unified:
    import ocdiff
    logging.info(ocdiff.console_diff(str(s0), str(s1)))
    return

  def _tint(line):
    # removals are red, additions are green, everything else is left uncoloured
    if line.startswith("-"): return "red"
    if line.startswith("+"): return "green"
    return None

  diff_lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines()))
  rendered = [colored(line, _tint(line)) for line in diff_lines]
  logging.info("\n".join(rendered))

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass, field
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import ARANGE_DIFF, GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, Context, \ from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
from tinygrad.shape.symbolic import Variable, sint from tinygrad.shape.symbolic import Variable, sint
from tinygrad.dtype import ConstType, ImageDType, dtypes from tinygrad.dtype import ConstType, ImageDType, dtypes
@ -367,9 +367,6 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
if SAVE_SCHEDULE: if SAVE_SCHEDULE:
def _save(): def _save():
if ARANGE_DIFF:
from test.external.process_replay.diff_schedule import diff_schedule
return diff_schedule(SCHEDULES)
print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl")) print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
with open(fp, "wb") as f: pickle.dump(SCHEDULES, f) with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
if len(SCHEDULES) == 0: atexit.register(_save) if len(SCHEDULES) == 0: atexit.register(_save)
@ -380,10 +377,10 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
if seen is None: seen = set() if seen is None: seen = set()
if ARANGE_DIFF: graph, in_degree = _graph_schedule(outs, seen)
with Context(FUSE_ARANGE=0, SAVE_SCHEDULE=1): _graph_schedule(outs, set()) if getenv("RUN_PROCESS_REPLAY"):
with Context(FUSE_ARANGE=1, SAVE_SCHEDULE=1): graph, in_degree = _graph_schedule(outs, seen) from test.external.process_replay.diff_schedule import process_replay
else: graph, in_degree = _graph_schedule(outs, seen) process_replay(outs, graph, in_degree)
queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0) queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
schedule: List[ScheduleItem] = [] schedule: List[ScheduleItem] = []
var_vals: Dict[Variable, int] = {} var_vals: Dict[Variable, int] = {}