mirror of https://github.com/commaai/tinygrad.git
Simpler mops verify (#2325)
* rewrite the to_movement_ops check using symbolic
* tweak
This commit is contained in:
parent
ef67d7ff5d
commit
822d6e6f18
|
@ -8,7 +8,8 @@ inf, nan = float('inf'), float('nan')
|
|||
|
||||
# kernel unpacker
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
def ast_str_to_lin(ast_str):
  """Evaluate a kernel AST string and wrap the result in a Linearizer.

  NOTE: the string is passed to eval(), which executes arbitrary Python;
  only call this with trusted strings.
  """
  ast = eval(ast_str)
  return Linearizer(ast)
|
||||
def ast_str_to_ast(ast_str:str) -> LazyOp:
  """Turn the textual form of a kernel AST back into a Python object.

  The text is handed directly to eval(), so whatever names it mentions must
  be resolvable from this module's globals.

  NOTE: eval() executes arbitrary Python; only call this with trusted strings.
  """
  return eval(ast_str)
|
||||
def ast_str_to_lin(ast_str:str):
  """Build a Linearizer from the textual form of a kernel AST.

  The string is first converted with ast_str_to_ast, and the result is
  passed to the Linearizer constructor.
  """
  ast = ast_str_to_ast(ast_str)
  return Linearizer(ast)
|
||||
|
||||
# load worlds, a dataset of about 12k kernels
|
||||
import gzip
|
||||
|
|
|
@ -1,13 +1,10 @@
|
|||
import random
|
||||
from tqdm import tqdm
|
||||
from extra.optimization.helpers import load_worlds
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.ops import LazyOp, MovementOps, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
|
||||
from tinygrad.helpers import dtypes, prod
|
||||
import itertools
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_ast
|
||||
from tinygrad.ops import MovementOps, BufferOps, LazyOp
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.symbolic import Node, Variable
|
||||
inf, nan = float('inf'), float('nan')
|
||||
from tinygrad.shape.symbolic import sym_infer
|
||||
|
||||
def get_real_view(shape, strides, offset, mask):
|
||||
real_shape = tuple(y-x for x,y in mask) if mask else shape
|
||||
|
@ -21,40 +18,30 @@ def get_buffer_size(shape, strides, offset, mask):
|
|||
real_real_shape, strides, real_offset = get_real_view(shape, strides, offset, mask)
|
||||
return real_offset + sum((s-1)*st for s, st in zip(real_real_shape,strides)) + 1
|
||||
|
||||
def flatten_view(view: View):
|
||||
real_real_shape, strides, real_offset = get_real_view(view.shape, view.strides, view.offset, view.mask)
|
||||
def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
|
||||
ordered_shape_strides, _ = sort_by_strides(real_real_shape, strides)
|
||||
ordered_shape_strides = [list(s) for s in ordered_shape_strides]
|
||||
if strides:
|
||||
i = 0
|
||||
while i < len(ordered_shape_strides):
|
||||
if i<len(ordered_shape_strides)-1 and ordered_shape_strides[i][1] == ordered_shape_strides[i+1][0]*ordered_shape_strides[i+1][1]:
|
||||
ordered_shape_strides[i+1][0] = ordered_shape_strides[i][0]*ordered_shape_strides[i+1][0]
|
||||
else: i += 1
|
||||
flat_shape = [shape_stride[0] for shape_stride in ordered_shape_strides]
|
||||
flat_strides = [shape_stride[1] for shape_stride in ordered_shape_strides]
|
||||
return (flat_shape, flat_strides, real_offset)
|
||||
return (real_real_shape, view.strides, real_offset)
|
||||
def st_equivalent(st1: ShapeTracker, st2: ShapeTracker):
|
||||
if (idxs1:=st1.expr_idxs()) == (idxs2:=st2.expr_idxs()): return True
|
||||
idx1, valid1 = idxs1
|
||||
idx2, valid2 = idxs2
|
||||
# always invalid
|
||||
if valid1 == 0 and valid2 == 0: return True
|
||||
|
||||
def views_equivalent(v1: View, v2: View) -> bool:
|
||||
return v1 == v2 or flatten_view(v1) == flatten_view(v2)
|
||||
var1 = set(idx1.vars() + valid1.vars())
|
||||
var2 = set(idx2.vars() + valid2.vars())
|
||||
# Maybe there are cases that vars are different yet the sts are the same?
|
||||
if var1 != var2: return False
|
||||
|
||||
# brute force over the vars range
|
||||
vs = list(var1)
|
||||
for i, ranges in enumerate(itertools.product(*[range(v.min, v.max+1) for v in vs])):
|
||||
if i > 1000:
|
||||
print("WARNING: did not search all possible combinations")
|
||||
# not happening for now
|
||||
break
|
||||
var_vals = {k:v for k,v in zip(vs, ranges)}
|
||||
r1 = sym_infer(idx1, var_vals) if sym_infer(valid1, var_vals) else 0
|
||||
r2 = sym_infer(idx2, var_vals) if sym_infer(valid2, var_vals) else 0
|
||||
if r1 != r2: return False
|
||||
|
||||
def st_equivalent(st: ShapeTracker, st_rebuilt: ShapeTracker):
|
||||
views = list(st.views)
|
||||
rebuilt_views = list(st_rebuilt.views)
|
||||
i = 0
|
||||
while i < len(views):
|
||||
view, rebuilt_view = views[i], rebuilt_views[i]
|
||||
if view == rebuilt_view:
|
||||
i += 1
|
||||
continue
|
||||
elif view.shape == rebuilt_view.shape:
|
||||
i += 1
|
||||
# hack to skip expands for overlapped strides
|
||||
else:
|
||||
rebuilt_views.pop(i)
|
||||
return True
|
||||
|
||||
def test_rebuild(st: ShapeTracker):
|
||||
|
@ -81,25 +68,19 @@ def test_rebuild(st: ShapeTracker):
|
|||
else:
|
||||
raise Exception("invalid mop")
|
||||
rebuilt_st = rebuilt_st.simplify()
|
||||
if len(st.views) != len(rebuilt_st.views):
|
||||
if not set(st.views).issubset(set(rebuilt_st.views)):
|
||||
assert st_equivalent(st, rebuilt_st)
|
||||
else:
|
||||
for v1,v2 in zip(st.views, rebuilt_st.views):
|
||||
assert views_equivalent(v1, v2), f"{v1} not equivalent to {v2}"
|
||||
assert st_equivalent(st, rebuilt_st)
|
||||
last_v1 = st.views[-1]
|
||||
last_v2 = rebuilt_st.views[-1]
|
||||
assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
|
||||
|
||||
def test_interpret_ast(ast:LazyOp):
|
||||
if ast.op in BufferOps:
|
||||
test_rebuild(ast.arg.st)
|
||||
else:
|
||||
for src in ast.src: test_interpret_ast(src)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ast_strs = load_worlds(False, False, True)
|
||||
random.shuffle(ast_strs)
|
||||
ast_strs = ast_strs[:2000]
|
||||
def interpret_ast(ast):
|
||||
if ast.op in BufferOps:
|
||||
test_rebuild(ast.arg.st)
|
||||
else:
|
||||
for src in ast.src: interpret_ast(src)
|
||||
ast_strs = load_worlds(False, False, True)[:4000]
|
||||
for ast_str in tqdm(ast_strs):
|
||||
ast = eval(ast_str)
|
||||
interpret_ast(ast)
|
||||
test_interpret_ast(ast_str_to_ast(ast_str))
|
||||
|
|
Loading…
Reference in New Issue