mirror of https://github.com/commaai/tinygrad.git
diff LazyBuffer schedules in process replay (#5996)
* start diff printing
* this should be 2
* add to process_replay.py
* enable schedule capture
* arange diff is process replay
parent d269bc95fa
commit 24c7c41ce0
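The change can be exercised straight from Python; the sketch below is based on the test this commit adds (names come from the hunks further down, and it assumes a tinygrad checkout at this revision):

from typing import cast
from test.external.process_replay.diff_schedule import diff_schedule
from tinygrad import Tensor
from tinygrad.engine.schedule import _graph_schedule
from tinygrad.helpers import Context
from tinygrad.lazy import LazyBuffer

X = Tensor.randn(10, 10).realize()
idxs = Tensor([0, 2]).realize()
xt = cast(LazyBuffer, X[idxs].lazydata)
with Context(FUSE_ARANGE=0): ref = _graph_schedule([xt], set())      # reference (graph, in_degree)
with Context(FUSE_ARANGE=1): compare = _graph_schedule([xt], set())  # schedule with arange fusion
print(diff_schedule([ref, compare]))  # number of kernels that changed

diff_schedule returns the number of kernels that changed between the two graphs, which is what the process_replay.py changes below assert on.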
@@ -4,6 +4,7 @@ env:
   DOWNLOAD_CACHE_VERSION: '5'
   RUN_PROCESS_REPLAY: 1
   GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+  PYTHONPATH: .

 on:
   push:

@@ -55,3 +55,4 @@ weights
 comgr_*
 *.pkl
 site/
+master_schedule.py

@@ -1,13 +1,23 @@
 # create a diff of two schedule graphs
-import difflib #, ocdiff
+import shutil, importlib, uuid, os
 from collections import defaultdict
 from typing import DefaultDict, List, Set, Tuple
+from test.external.process_replay.utils import print_diff
 from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem
-from tinygrad.helpers import Context, colored
+from tinygrad.helpers import DEBUG, Context, colored, dedup, diskcache_put, fetch, getenv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import LazyOp
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item

+def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List[LBScheduleItem]], in_degree:DefaultDict[LBScheduleItem, int]):
+  # copy the reference module
+  fp = __file__.replace("diff_schedule", "master_schedule")
+  if not os.path.isfile(fp): shutil.copyfile(fetch("https://raw.githubusercontent.com/tinygrad/tinygrad/master/tinygrad/engine/schedule.py"), fp)
+  # create the reference graph
+  ref_graph, ref_in_degree = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs, set())
+  # compare
+  diff_schedule([(ref_graph, ref_in_degree), (graph, in_degree)])
+
 def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]]) -> int:
   si_for_buf: DefaultDict[LazyBuffer, List[ScheduleItem]] = defaultdict(list)
   for _,in_degree in s:

@@ -15,20 +25,29 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
       for buf in lsi.outputs:
         si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
   changed = 0
-  seen_diff: Set[Tuple[LazyOp, LazyOp]] = set()
+  seen_diffs: Set[Tuple[LazyOp, ...]] = set()
   for buf, si in si_for_buf.items():
-    asts = [x.ast for x in si]
-    if len(set(asts)) == 1: continue
-    if (asts[0], asts[1]) in seen_diff: continue
-    seen_diff.add((asts[0], asts[1]))
+    asts = tuple(dedup([x.ast for x in si]))
+    # kernels didn't change
+    if len(si) > 1 and len(asts) == 1: continue
+    if asts in seen_diffs: continue
+    seen_diffs.add(asts)
     changed += 1
-    #print(ocdiff.console_diff(render(ast[0]), render(ast[1])))
-    ei0 = lower_schedule_item(si[0])
-    ei1 = lower_schedule_item(si[1])
-    assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner)
-    diff = list(difflib.unified_diff(ei0.prg.p.src.splitlines(), ei1.prg.p.src.splitlines()))
-    unified_diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in diff)
-    print(unified_diff)
+    if getenv("RUN_PROCESS_REPLAY"): diskcache_put("schedule_diff", str(uuid.uuid4()), (str(buf), asts))
+    if len(asts) == 1:
+      print(f"{buf} folded in the second schedule")
+    else: print_si_diff(si[0], si[1])
+  if DEBUG >= 1: print(f"*** process replay: {changed} unique kernel{'s' if changed>1 else ''} changed")
+  return changed
+
+def print_si_diff(si0:ScheduleItem, si1:ScheduleItem):
+  ei0 = lower_schedule_item(si0)
+  ei1 = lower_schedule_item(si1)
+  assert isinstance(ei0.prg, CompiledRunner) and isinstance(ei1.prg, CompiledRunner)
+  print_diff(si0.ast, si1.ast)
+  print_diff(ei0.prg.p.src, ei1.prg.p.src)
+  # TODO: create new Buffers for process replay
+  if getenv("TIMING"):
     with Context(DEBUG=2):
       tm0 = ei0.run(wait=True)
       tm1 = ei1.run(wait=True)

@@ -36,5 +55,3 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
       tm_diff = ((tm0 - tm1) / tm0) * 100
       if tm_diff > 0: print(colored(f"{tm_diff:.2f}% faster", "green"))
       else: print(colored(f"{tm_diff:,.2f}% slower", "red"))
-  print(f"{changed} unique kernel{'s' if changed>1 else ''} changed")
-  return changed

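The bookkeeping above reduces to: group ScheduleItems by output buffer, dedup their ASTs, and count each unique difference once. A standalone toy version with plain strings instead of LazyOp ASTs (all names made up for illustration):

# Toy sketch of diff_schedule's counting; (output_buffer, ast) pairs from two schedules,
# with hypothetical strings standing in for LazyOp ASTs.
from collections import defaultdict
from typing import DefaultDict, List, Set, Tuple

def count_changed(items: List[Tuple[str, str]]) -> int:
  asts_for_buf: DefaultDict[str, List[str]] = defaultdict(list)
  for buf, ast in items: asts_for_buf[buf].append(ast)
  changed = 0
  seen_diffs: Set[Tuple[str, ...]] = set()
  for buf, asts_list in asts_for_buf.items():
    asts = tuple(dict.fromkeys(asts_list))               # dedup, keep order (like helpers.dedup)
    if len(asts_list) > 1 and len(asts) == 1: continue   # same kernel in both schedules
    if asts in seen_diffs: continue                      # this exact diff was already counted
    seen_diffs.add(asts)
    changed += 1                                         # folded buffer or changed kernel
  return changed

print(count_changed([("buf0", "ADD a b"), ("buf0", "ADD a b"),   # unchanged, not counted
                     ("buf1", "LOAD a"), ("buf1", "LOAD b")]))   # changed, counted once: prints 1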
@@ -5,9 +5,7 @@ from tabulate import tabulate
 from datetime import datetime
 from typing import Dict, List, cast
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.device import Device
 from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, temp, tqdm
-from tinygrad.ops import LazyOp

 # *** process replay settings
 PAGE_SIZE = 100

@@ -27,7 +25,7 @@ logging.basicConfig(level=logging.INFO, format='%(message)s')
 BASE_URL = f"https://api.github.com/repos/{os.getenv('GITHUB_REPOSITORY', 'tinygrad/tinygrad')}"
 GH_HEADERS = {"Authorization": f"Bearer {os.getenv('GH_TOKEN', '')}", "Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}

-def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed):
+def diff_kernel(offset:int, kernel_changed):
   if early_stop.is_set(): return
   conn = db_connection()
   cur = conn.cursor()

@@ -51,11 +49,6 @@ def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed)
       kernel_changed.value = True
       if ASSERT_DIFF: raise e
      continue
-    # try compare
-    if COMPARE_SCHEDULE and ast not in ref_schedule:
-      with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}):
-        print(opts.render(name, Kernel(ast, opts=opts).linearize().uops))
-      continue
    try: assert compare_src == good_src
    except AssertionError as e:
      changed += 1

@@ -74,13 +67,19 @@ def diff_kernel(offset:int, ref_schedule:List[LazyOp], kernel_changed)
   conn.commit()
   cur.close()

-def get_ref_schedule(offset:int, ref_schedule):
-  conn = sqlite3.connect("/tmp/process_replay/process_replay.db")
+def print_ast_diff(offset:int):
+  conn = db_connection()
   cur = conn.cursor()
-  cur.execute(f"SELECT val FROM '{REF_TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
-  for row in cur.fetchall(): ref_schedule.append(pickle.loads(row[0])[0])
-  conn.commit()
-  cur.close()
+  cur.execute(f"SELECT val FROM 'schedule_diff_{VERSION}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
+  for row in cur.fetchall():
+    buf, asts = pickle.loads(row[0])
+    if len(asts) == 1:
+      print(f"{buf} was folded")
+      print(asts[0])
+    else:
+      diff = list(difflib.unified_diff(str(asts[0]).splitlines(), str(asts[1]).splitlines()))
+      unified_diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in diff)
+      print(unified_diff)

 def download_artifact(run_id:str, name:str, dest:str):
   res = requests.get(f"{BASE_URL}/actions/runs/{run_id}/artifacts?name={name}", headers=GH_HEADERS)

@@ -118,23 +117,26 @@ def process_replay():
     logging.info(tabulate(diff, headers=["job", "master", "compare", "diff"]))

   # *** schedule diff
-  ref_schedule = multiprocessing.Manager().list()
   if COMPARE_SCHEDULE:
-    logging.info("fetching process replay reference")
-    # TODO: make this run_id dynamic
-    download_artifact("10093148840", f"process_replay_{Device.DEFAULT.lower()}.db", f"{TEMP_DIR}/schedule")
-    ref_conn = sqlite3.connect(f"{TEMP_DIR}/schedule/process_replay.db")
-    row_count = ref_conn.execute(f"select count(*) from '{REF_TABLE_NAME}'").fetchone()[0]
+    conn = db_connection()
+    cur = conn.cursor()
+    try: row_count = cur.execute(f"select count(*) from 'schedule_diff_{VERSION}'").fetchone()[0]
+    except sqlite3.OperationalError:
+      logging.warning(f"schedule_diff_{VERSION} isn't accessible in master, did DB_VERSION change?")
+      exit(0)
+    conn.commit()
+    cur.close()
     processes = []
+    changed = multiprocessing.Manager().Value('b', False)
     for i in tqdm(range(0, row_count, PAGE_SIZE)):
-      processes.append(p:=multiprocessing.Process(target=get_ref_schedule, args=(i, ref_schedule)))
+      processes.append(p:=multiprocessing.Process(target=print_ast_diff, args=(i,)))
       p.start()
     for p in processes: p.join()
-    ref_conn.close()
-  conn = db_connection()
-  cur = conn.cursor()
+    if row_count != 0 and ASSERT_DIFF: raise Exception("scheduler process replay detected changes")

   # *** kernel diff
+  conn = db_connection()
+  cur = conn.cursor()
   try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
   except sqlite3.OperationalError:
     logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")

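Both print_ast_diff and diff_kernel walk a sqlite table PAGE_SIZE rows at a time with LIMIT/OFFSET, one worker per page. The paging pattern on its own, using a throwaway in-memory database instead of tinygrad's db_connection (illustrative sketch):

# Minimal sketch of the LIMIT/OFFSET paging used above; plain sqlite3, no tinygrad.
import sqlite3

PAGE_SIZE = 100
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE schedule_diff (val TEXT)")
conn.executemany("INSERT INTO schedule_diff VALUES (?)", [(f"row {i}",) for i in range(250)])

def print_page(offset: int) -> None:
  cur = conn.cursor()
  cur.execute("SELECT val FROM schedule_diff LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
  for (val,) in cur.fetchall(): print(val)   # the real code unpickles a (buf, asts) tuple here
  cur.close()

row_count = conn.execute("SELECT count(*) FROM schedule_diff").fetchone()[0]
for i in range(0, row_count, PAGE_SIZE): print_page(i)   # process_replay.py starts a Process per page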
@@ -144,10 +146,10 @@ def process_replay():
   processes = []
   changed = multiprocessing.Manager().Value('b', False)
   for i in tqdm(range(0, row_count, PAGE_SIZE)):
-    processes.append(p:=multiprocessing.Process(target=diff_kernel, args=(i, ref_schedule, changed)))
+    processes.append(p:=multiprocessing.Process(target=diff_kernel, args=(i, changed)))
     p.start()
   for p in processes: p.join()
-  if changed.value and ASSERT_DIFF: raise Exception("process replay detected changes")
+  if changed.value and ASSERT_DIFF: raise Exception("kernel process replay detected changes")

 if __name__ == "__main__":
   if SKIP_PROCESS_REPLAY:

@@ -2,3 +2,4 @@
 from tinygrad.helpers import db_connection, VERSION, getenv
 cur = db_connection()
 cur.execute(f"drop table if exists process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}")
+cur.execute(f"drop table if exists schedule_diff_{VERSION}")

@@ -1,36 +1,28 @@
+from typing import cast
 import unittest
 from test.external.process_replay.diff_schedule import diff_schedule
 from tinygrad import Tensor
 from tinygrad.helpers import Context
-from tinygrad.engine.schedule import SCHEDULES
+from tinygrad.engine.schedule import _graph_schedule
+from tinygrad.lazy import LazyBuffer

 class TestDiffSchedule(unittest.TestCase):
   def test_diff_arange(self):
     # diff a single arange kernel
     X = Tensor.randn(10, 10).realize()
     idxs = Tensor([0, 2]).realize()
-    xt = X[idxs]
-    with Context(ARANGE_DIFF=1): xt.schedule()
-    self.assertEqual(len(SCHEDULES), 2)
-    changed = diff_schedule(SCHEDULES)
-    self.assertEqual(changed, 1)
-    SCHEDULES.clear()
+    xt = cast(LazyBuffer, X[idxs].lazydata)
+    with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([xt], set())
+    with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([xt], set())
+    # 1 arange LazyBuffer folds, 1 arange child's kernel changes
+    changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
+    self.assertEqual(changed, 2)

     # no diff
-    a = Tensor([1])+Tensor([2])
-    with Context(ARANGE_DIFF=1): a.schedule()
-    self.assertEqual(len(SCHEDULES), 2)
-    changed = diff_schedule(SCHEDULES)
-    self.assertEqual(changed, 0)
-    SCHEDULES.clear()
-
-    # no diff with two schedule creation calls
-    a = Tensor([1])+Tensor([2])
-    with Context(ARANGE_DIFF=1): a.schedule()
-    b = Tensor([3])+Tensor([4])
-    with Context(ARANGE_DIFF=1): b.schedule()
-    self.assertEqual(len(SCHEDULES), 4)
-    changed = diff_schedule(SCHEDULES)
+    a = cast(LazyBuffer, (Tensor([1])+Tensor([2])).lazydata)
+    with Context(FUSE_ARANGE=0): ref_graph, ref_in_degree = _graph_schedule([a], set())
+    with Context(FUSE_ARANGE=1): compare_graph, compare_in_degree = _graph_schedule([a], set())
+    changed = diff_schedule([(ref_graph, ref_in_degree), (compare_graph, compare_in_degree)])
     self.assertEqual(changed, 0)

 if __name__ == '__main__':

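The (graph, in_degree) pairs the test feeds to diff_schedule are the raw output of _graph_schedule: an adjacency list from each LBScheduleItem to its dependents plus indegree counts, which create_schedule_with_vars drains with a deque (last hunk below). The same walk on a toy graph with made-up string nodes:

# Toy Kahn-style walk over a (graph, in_degree) pair like _graph_schedule returns;
# string nodes stand in for LBScheduleItems.
from collections import defaultdict, deque

graph = defaultdict(list, {"load_x": ["matmul"], "load_w": ["matmul"], "matmul": ["relu"]})
in_degree = {"load_x": 0, "load_w": 0, "matmul": 2, "relu": 1}

queue = deque(node for node, deg in in_degree.items() if deg == 0)
order = []
while queue:
  node = queue.popleft()
  order.append(node)
  for child in graph[node]:
    in_degree[child] -= 1
    if in_degree[child] == 0: queue.append(child)

print(order)   # ['load_x', 'load_w', 'matmul', 'relu']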
@@ -2,11 +2,11 @@

 # should assert
 sed -i 's/temp/temp1/g' ./tinygrad/codegen/lowerer.py
-ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
+COMPRAE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
 if [[ $? -eq 0 ]]; then
   echo "PROCESS REPLAY IS WRONG."
   exit 1
 fi
 # should NOT assert
 git stash > /dev/null
-ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null
+COMPRAE_SCHEDULE=0 ASSERT_PROCESS_REPLAY=1 python3 test/external/process_replay/process_replay.py &> /dev/null

@@ -0,0 +1,11 @@
+import difflib, logging
+from tinygrad.helpers import colored, getenv
+
+def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
+  if unified:
+    lines = list(difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines()))
+    diff = "\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in lines)
+  else:
+    import ocdiff
+    diff = ocdiff.console_diff(str(s0), str(s1))
+  logging.info(diff)

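print_diff either colorizes a difflib unified diff or falls back to ocdiff. A self-contained approximation using raw ANSI escapes in place of tinygrad.helpers.colored (illustrative only):

# Sketch of the unified branch of print_diff, with raw ANSI escapes instead of colored().
import difflib

def ansi(line: str) -> str:
  if line.startswith("-"): return f"\033[31m{line}\033[0m"   # removed lines in red
  if line.startswith("+"): return f"\033[32m{line}\033[0m"   # added lines in green
  return line

def show_diff(s0, s1) -> None:
  lines = difflib.unified_diff(str(s0).splitlines(), str(s1).splitlines())
  print("\n".join(ansi(line) for line in lines))

show_diff("a = x + 1\nb = a * 2", "a = x + 1\nb = a * 4")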
@@ -4,7 +4,7 @@ from dataclasses import dataclass, field
 from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
 from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
-from tinygrad.helpers import ARANGE_DIFF, GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, Context, \
+from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, \
   GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ConstType, ImageDType, dtypes

@@ -367,9 +367,6 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \

   if SAVE_SCHEDULE:
     def _save():
-      if ARANGE_DIFF:
-        from test.external.process_replay.diff_schedule import diff_schedule
-        return diff_schedule(SCHEDULES)
       print(f"saving {len(SCHEDULES)} schedule graphs to", fp:=getenv("SAVE_SCHEDULE_PATH", "schedule.pkl"))
       with open(fp, "wb") as f: pickle.dump(SCHEDULES, f)
     if len(SCHEDULES) == 0: atexit.register(_save)

@@ -380,10 +377,10 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \

 def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
   if seen is None: seen = set()
-  if ARANGE_DIFF:
-    with Context(FUSE_ARANGE=0, SAVE_SCHEDULE=1): _graph_schedule(outs, set())
-    with Context(FUSE_ARANGE=1, SAVE_SCHEDULE=1): graph, in_degree = _graph_schedule(outs, seen)
-  else: graph, in_degree = _graph_schedule(outs, seen)
+  graph, in_degree = _graph_schedule(outs, seen)
+  if getenv("RUN_PROCESS_REPLAY"):
+    from test.external.process_replay.diff_schedule import process_replay
+    process_replay(outs, graph, in_degree)
   queue = deque(lsi for lsi,deg in in_degree.items() if deg == 0)
   schedule: List[ScheduleItem] = []
   var_vals: Dict[Variable, int] = {}

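With this hook, any schedule creation can be diffed against master by setting RUN_PROCESS_REPLAY before tinygrad runs. A sketch of that path (it fetches master's schedule.py over the network, and assumes a checkout at this revision):

# Sketch only: exercise the env-gated hook added above; set the variable before tinygrad reads it.
import os
os.environ["RUN_PROCESS_REPLAY"] = "1"

from tinygrad import Tensor

X = Tensor.randn(10, 10).realize()
idxs = Tensor([0, 2]).realize()
X[idxs].schedule()   # create_schedule_with_vars -> process_replay(outs, graph, in_degree)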