mirror of https://github.com/commaai/tinygrad.git
new viz unittests, isolate the ctx bug (#7069)
* start new test_viz * test_rewrite_twice * test_rewrite_with_ctx * add back some of the old tests * lints
This commit is contained in:
parent
9f00eacde5
commit
52d8afde2b
157
test/test_viz.py
157
test/test_viz.py
|
@ -1,78 +1,76 @@
|
|||
from typing import List
|
||||
from typing import Dict, List, Optional
|
||||
import unittest
|
||||
import itertools
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.helpers import Context, getenv
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.viz.serve import GraphRewriteMetadata, get_metadata, _uop_to_json
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UPat, UOps, UOp, graph_rewrite, contexts, track_rewrites
|
||||
from tinygrad.dtype import PtrDType, dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \
|
||||
graph_rewrite, contexts, track_rewrites
|
||||
from tinygrad.viz.serve import _replace_uop, get_details, get_metadata
|
||||
|
||||
def group_rewrites(kernels:List[GraphRewriteMetadata]): return {k:list(v) for k,v in itertools.groupby(kernels, lambda x:x.loc)}
|
||||
@track_rewrites
|
||||
def rewrite(sink:UOp, pm:PatternMatcher, ctx=None): return graph_rewrite(sink, pm, ctx)
|
||||
|
||||
def helper_test_viz(sink:UOp, pm:PatternMatcher, ctx=None) -> List[UOp]:
|
||||
rewrite(sink, pm, ctx)
|
||||
assert len(contexts) == 1
|
||||
assert len(contexts[0][1]) == 1
|
||||
ctx = contexts[0][1][0]
|
||||
uops = [ctx.sink]
|
||||
replaces: Dict[UOp, UOp] = {}
|
||||
for u0,u1,_ in ctx.rewrites:
|
||||
replaces[u0] = u1
|
||||
new_sink = _replace_uop(uops[-1], {**replaces})
|
||||
uops.append(new_sink)
|
||||
return uops[1:]
|
||||
|
||||
class TestViz(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
def setUp(self):
|
||||
contexts.clear()
|
||||
self.prev_val = TRACK_MATCH_STATS.value
|
||||
self.tms = TRACK_MATCH_STATS.value
|
||||
TRACK_MATCH_STATS.value = 2
|
||||
def tearDown(self) -> None:
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, contexts
|
||||
if not getenv("VIZ"): contexts.clear()
|
||||
TRACK_MATCH_STATS.value = self.prev_val
|
||||
def tearDown(self): TRACK_MATCH_STATS.value = self.tms
|
||||
|
||||
def assert_valid_ctx(self):
|
||||
from tinygrad.ops import contexts
|
||||
assert len(contexts) != 0
|
||||
return get_metadata(contexts)
|
||||
def test_viz_simple(self):
|
||||
pm = PatternMatcher([
|
||||
(UPat.var("x")*1, lambda x:x),
|
||||
])
|
||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
|
||||
uops = helper_test_viz(a*1, pm)
|
||||
self.assertEqual(len(uops), 1)
|
||||
self.assertEqual(uops[0], a)
|
||||
|
||||
def assert_valid_graph(self, t):
|
||||
s = t.schedule()
|
||||
list(lower_schedule(s))
|
||||
self.assert_valid_ctx()
|
||||
def test_rewrite_twice(self):
|
||||
pm = PatternMatcher([
|
||||
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
|
||||
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
|
||||
])
|
||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
|
||||
uops = helper_test_viz(a+a, pm)
|
||||
self.assertEqual(len(uops), 2)
|
||||
self.assertEqual(uops[0], a*2)
|
||||
self.assertEqual(uops[1], graph_rewrite(a+a, pm))
|
||||
|
||||
def test_ctx_diff(self):
|
||||
a = Tensor.ones(4, 1).contiguous().realize()
|
||||
out = a + a.reshape(1, 4)
|
||||
self.assert_valid_graph(out)
|
||||
|
||||
def test_ctx_groups(self):
|
||||
schedule1 = Tensor.zeros(4, 1).contiguous().exp().schedule()
|
||||
schedule2 = Tensor.zeros(4, 1).contiguous().exp().schedule()
|
||||
list(lower_schedule(schedule1))
|
||||
list(lower_schedule(schedule2))
|
||||
ret = self.assert_valid_ctx()
|
||||
assert len(ret) == 3
|
||||
assert all(len([x for _,_,x in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
|
||||
assert all(len([x for _,_,x in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
|
||||
|
||||
def test_gemm_diff(self):
|
||||
x = Tensor.empty(64, 64).realize()
|
||||
y = Tensor.empty(64, 64).realize()
|
||||
out = x.matmul(y)
|
||||
self.assert_valid_graph(out)
|
||||
|
||||
def test_track_no_ctx(self):
|
||||
@track_rewrites
|
||||
def simplify_and_verify(u:UOp):
|
||||
simplify = TrackedPatternMatcher([(UPat.var("x")*1, lambda x:x)])
|
||||
verify = TrackedPatternMatcher([(UPat(UOps.CONST), lambda:True)])
|
||||
verify.rewrite(graph_rewrite(u, simplify))
|
||||
u = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))*1
|
||||
simplify_and_verify(u)
|
||||
ret = self.assert_valid_ctx()
|
||||
self.assertEqual(len(ret), 1)
|
||||
key, ctx, metadata = ret[0][0]
|
||||
self.assertIs(key, u)
|
||||
self.assertIs(ctx.sink, u)
|
||||
self.assertEqual(len(metadata.upats), 1)
|
||||
@unittest.expectedFailure
|
||||
def test_rewrite_with_ctx(self):
|
||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
|
||||
b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1), UOp.const(dtypes.int, 0)))
|
||||
def store_load(visited:Dict[UOp, None], x:UOp) -> Optional[UOp]:
|
||||
if x in visited: return None
|
||||
visited[x] = None
|
||||
return UOp.store(*x.src, x)
|
||||
pm = PatternMatcher([
|
||||
(UPat(UOps.LOAD, name="x"), store_load),
|
||||
])
|
||||
uops = helper_test_viz(a+b, pm, {})
|
||||
self.assertEqual(len(uops), 2)
|
||||
self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))
|
||||
|
||||
def test_track_rewrites(self):
|
||||
simple = TrackedPatternMatcher([(UPat.var("x")*1, lambda x:x)])
|
||||
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
|
||||
@track_rewrites
|
||||
def do_rewrite(key:str, x:UOp): return graph_rewrite(x, simple)
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
|
||||
do_rewrite("uop_0", ld*1)
|
||||
do_rewrite("uop_1", ld*2)
|
||||
ret = self.assert_valid_ctx()
|
||||
ret = get_metadata(contexts)
|
||||
self.assertEqual(len(ret), 1)
|
||||
key, _, m = ret[0][0]
|
||||
self.assertEqual(key, "uop_0")
|
||||
|
@ -81,43 +79,26 @@ class TestViz(unittest.TestCase):
|
|||
self.assertEqual(key, "uop_1")
|
||||
self.assertEqual(len(m.upats), 0)
|
||||
|
||||
def test_track_with_exception(self):
|
||||
simple = TrackedPatternMatcher([(UPat.var("x")*1, lambda x:x)])
|
||||
def test_track_rewrites_with_exception(self):
|
||||
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
|
||||
@track_rewrites
|
||||
def do_rewrite(key:str, x:UOp):
|
||||
x = graph_rewrite(x, simple) # NOTE: viz tracks this
|
||||
raise Exception("test")
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
|
||||
with self.assertRaises(Exception): do_rewrite("uop_0", ld*1)
|
||||
ret = self.assert_valid_ctx()
|
||||
ret = get_metadata(contexts)
|
||||
self.assertEqual(len(ret), 1)
|
||||
|
||||
def test_dedup_ast(self):
|
||||
a = Tensor.empty(4, 4).contiguous().realize()+2
|
||||
b = Tensor.empty(4, 4).contiguous().realize()+2
|
||||
Tensor.schedule(a, b)
|
||||
kernels = self.assert_valid_ctx()
|
||||
self.assertEqual(len(kernels), 1)
|
||||
rewrites = [x[2] for x in kernels[0]]
|
||||
assert all(len(v) == 1 for k,v in group_rewrites(rewrites).items() if "schedule.py" in k)
|
||||
|
||||
@unittest.skip("broken")
|
||||
def test_no_dedup_different_opts(self):
|
||||
a = Tensor.empty(4, 4)+Tensor.empty(4, 4)
|
||||
s = a.schedule()
|
||||
with Context(NOOPT=1): list(lower_schedule(s.copy()))
|
||||
with Context(NOOPT=0): list(lower_schedule(s.copy()))
|
||||
kernels = self.assert_valid_ctx()[1:]
|
||||
self.assertEqual(len(kernels), 2)
|
||||
rewrites = [x[2] for x in kernels[0]]
|
||||
assert all(len(v) == 1 for _,v in group_rewrites(rewrites).items())
|
||||
|
||||
def test_fold_const_nodes(self):
|
||||
a = Tensor.empty(4, 4)+2
|
||||
sink = a.schedule()[-1].ast
|
||||
ret = _uop_to_json(sink)
|
||||
assert not any(v[0].startswith("CONST") for v in ret.values())
|
||||
assert len([x for x in ret.values() if "CONST" in x[0]]) == 1
|
||||
def test_fold_const(self):
|
||||
pm = PatternMatcher([
|
||||
(UPat.var("x")*1, lambda x:x),
|
||||
])
|
||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
|
||||
rewrite(a, pm)
|
||||
graph = get_details(*get_metadata(contexts)[0][0]).graphs[-1]
|
||||
assert not any(v[0].startswith("CONST") for v in graph.values())
|
||||
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue