mirror of https://github.com/commaai/tinygrad.git
context aware process replay [run_process_replay] (#5378)
* test tc as ctx var
* remove from opts
* process replay
* pop variable
* B -> Variable
* fix re-assign
* pop temp vars
* move TRANSCENDENTAL=2
parent 45e1b9d5e3
commit 004366b193
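In plain terms: when kernels are compiled in CI with RUN_PROCESS_REPLAY=1, the current value of every registered ContextVar (for example TRANSCENDENTAL) is now pickled next to the kernel, and the replay script restores that context on master before re-linearizing, so kernels built under non-default settings are compared like for like. A minimal sketch of that round trip, using only the helpers that appear in this diff; the payload here is a simplified stand-in for the real (ast, opts, applied_opts, name, src, ctx) tuple:

import pickle
from tinygrad.helpers import Context, ContextVar

# capture side (runs at compile time on the PR branch): snapshot every registered ContextVar
ctx = {k: v.value for k, v in ContextVar._cache.items()}
payload = pickle.dumps(("<kernel payload>", ctx))

# replay side (runs on master): restore only the vars this version of tinygrad still defines
_kernel, saved_ctx = pickle.loads(payload)
with Context(**{k: v for k, v in saved_ctx.items() if k in ContextVar._cache}):
  pass  # re-linearize the kernel here and diff its rendered source against the saved one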
@@ -348,12 +348,11 @@ jobs:
       run: PYTHONPATH="." METAL=1 CACHELEVEL=0 FUZZ_ALL_ACTIONS=1 DEPTH=2 FUZZ_N=48 FUZZ_MAX_SIZE=10000000 python test/external/fuzz_linearizer.py
     - name: Fuzz Test models schedule
       run: FUZZ_SCHEDULE=1 FUZZ_SCHEDULE_MAX_PATHS=5 python -m pytest test/models/test_train.py test/models/test_end2end.py
+    - name: Run TRANSCENDENTAL math
+      run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
     - name: Run process replay tests
       if: env.RUN_PROCESS_REPLAY == '1'
       run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git checkout origin/master && PYTHONPATH=. python3 process_replay.py
-    - name: Run TRANSCENDENTAL math
-      # put this after process replay since it has same ast
-      run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20

   # testwebgl:
   #   name: WebGL Tests
@@ -544,12 +543,11 @@ jobs:
         PYTHONPATH="." python examples/compile_efficientnet.py > recognize.c
         clang -O2 recognize.c -lm -o recognize
         cat test/models/efficientnet/Chicken.jpg | ./recognize | grep cock
+    - name: Run TRANSCENDENTAL math
+      run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
     - name: Run process replay tests
       if: env.RUN_PROCESS_REPLAY == '1'
       run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git checkout origin/master && PYTHONPATH=. python3 process_replay.py
-    - name: Run TRANSCENDENTAL math
-      # put this after process replay since it has same ast
-      run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20

   #testunicorn:
   #  name: ARM64 unicorn Test
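For context, the process replay CI step above copies the replay script from the PR branch, checks out origin/master, and runs the script against the kernels the PR just recorded. A rough local equivalent of that run: line (unix shell), assuming the test suite was already run on the branch with RUN_PROCESS_REPLAY=1 so the diskcache is populated; this wrapper is illustrative, not part of the repo:

import subprocess
for cmd in [
  "cp test/external/process_replay/process_replay.py ./process_replay.py",  # keep the branch's replay script
  "git fetch origin master",
  "git checkout origin/master",                                             # switch the source tree to master
  "PYTHONPATH=. python3 process_replay.py",                                 # rebuild each recorded kernel and diff its source
]:
  subprocess.run(cmd, shell=True, check=True)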
@@ -2,7 +2,7 @@
 # compare kernels created by HEAD against master
 import difflib, pickle
 from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.helpers import colored, db_connection, VERSION, getenv, tqdm
+from tinygrad.helpers import Context, ContextVar, colored, db_connection, VERSION, getenv, tqdm

 page_size = 100
 conn = db_connection()
@@ -11,16 +11,17 @@ row_count = cur.execute(f"select count(*) from 'process_replay_{VERSION}'").fetc
 for offset in tqdm(range(0, row_count, page_size)):
   cur.execute(f"SELECT val FROM 'process_replay_{VERSION}' LIMIT ? OFFSET ?", (page_size, offset))
   for row in cur.fetchall():
-    ast, opts, applied_opts, name, compare_src = pickle.loads(row[0])
-    k = Linearizer(*ast, opts=opts)
-    for opt in applied_opts: k.apply_opt(opt)
-    good_src = k.opts.render(name, k.linearize().uops)
-    try: assert compare_src == good_src
-    except AssertionError as e:
-      print("PROCESS REPLAY DETECTED CHANGE")
-      print(ast)
-      print(applied_opts)
-      diff = list(difflib.unified_diff(good_src.splitlines(), compare_src.splitlines()))
-      for line in diff:
-        print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
-      if getenv("ASSERT_PROCESS_REPLAY", 1): raise e
+    ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0])
+    with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache}):
+      k = Linearizer(*ast, opts=opts)
+      for opt in applied_opts: k.apply_opt(opt)
+      good_src = k.opts.render(name, k.linearize().uops)
+      try: assert compare_src == good_src
+      except AssertionError as e:
+        print("PROCESS REPLAY DETECTED CHANGE")
+        print(ast)
+        print(applied_opts)
+        diff = list(difflib.unified_diff(good_src.splitlines(), compare_src.splitlines()))
+        for line in diff:
+          print(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None))
+        if getenv("ASSERT_PROCESS_REPLAY", 1): raise e
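One detail worth noting in the new loop: the captured ctx dict is filtered with "if k in ContextVar._cache" before being handed to Context, which presumably keeps replay from failing on a ContextVar that exists on the PR branch but is unknown to master. A small standalone illustration of that guard (BRANCH_ONLY_FLAG is a made-up key standing in for a var master does not define):

from tinygrad.helpers import Context, ContextVar

saved_ctx = {"DEBUG": 2, "BRANCH_ONLY_FLAG": 1}  # snapshot as it might arrive from a newer branch
safe = {k: v for k, v in saved_ctx.items() if k in ContextVar._cache}  # drops the unknown key
with Context(**safe):
  print({k: ContextVar._cache[k].value for k in safe})  # DEBUG is 2 only while inside the context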
@@ -8,7 +8,7 @@ from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, get
 from tinygrad.codegen.uops import UOp, flops_mem, UOps
 from tinygrad.codegen.uopgraph import UOpGraph
 from tinygrad.renderer import Program
-from tinygrad.helpers import to_function_name, DEBUG, getenv, prod, diskcache_put
+from tinygrad.helpers import to_function_name, DEBUG, getenv, prod, diskcache_put, ContextVar

 # TODO: this needs to be replaced, there shouldn't be variables in the shapetracker
 def variable_to_uop(x, ctx=None) -> UOp:
@@ -136,7 +136,8 @@ class Lowerer(Kernel):
   def to_program(self) -> Program:
     self.linearize()
     src = self.opts.render(name:=to_function_name(self.name), self.uops)
-    if getenv("RUN_PROCESS_REPLAY"): diskcache_put("process_replay", id(self), (self.ast, self.opts, self.applied_opts, name, src))
+    if getenv("RUN_PROCESS_REPLAY"):
+      diskcache_put("process_replay", id(self), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
     info = get_lazyop_info(self.ast[0])
     ops, mem = flops_mem(self.uops.uops)
     run_count = prod((self.global_size or []) + (self.local_size or []))
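The diskcache_put("process_replay", ...) call above and the replay script's SELECT are two ends of the same sqlite table: the diskcache stores the "process_replay" table versioned as 'process_replay_{VERSION}', which is exactly what process_replay.py pages through. A quick, read-only way to peek at what a kernel record now contains, assuming the cache was populated by a RUN_PROCESS_REPLAY=1 run:

import pickle
from tinygrad.helpers import db_connection, VERSION

cur = db_connection().cursor()
cur.execute(f"SELECT val FROM 'process_replay_{VERSION}' LIMIT 1")
row = cur.fetchone()
if row is not None:
  ast, opts, applied_opts, name, src, ctx = pickle.loads(row[0])
  print(name, ctx)  # kernel name plus the ContextVar snapshot captured at compile time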