Simpler mops verify (#2325)

* rewrite the to_movement_ops check using symbolic

* tweak
This commit is contained in:
chenyu 2023-11-15 21:47:18 -05:00 committed by GitHub
parent ef67d7ff5d
commit 822d6e6f18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 56 deletions

View File

@ -8,7 +8,8 @@ inf, nan = float('inf'), float('nan')
# kernel unpacker
from tinygrad.codegen.linearizer import Linearizer
# NOTE(review): ast_str_to_lin is defined twice in this span; read as plain Python the later
# definition rebinds the name. This looks like the before/after of a diff — confirm against the real file.
# First form: eval the AST string inline and wrap the result in a Linearizer.
def ast_str_to_lin(ast_str): return Linearizer(eval(ast_str))
# Turn the textual form of an AST back into an object by eval'ing it (annotated as returning a LazyOp).
# NOTE: eval executes the string as Python code, so only pass strings from a trusted source.
def ast_str_to_ast(ast_str:str) -> LazyOp: return eval(ast_str)
# Second form: same Linearizer construction, but the eval is routed through ast_str_to_ast.
def ast_str_to_lin(ast_str:str): return Linearizer(ast_str_to_ast(ast_str))
# load worlds, a dataset of about 12k kernels
import gzip

View File

@ -1,13 +1,10 @@
import random
from tqdm import tqdm
from extra.optimization.helpers import load_worlds
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.ops import LazyOp, MovementOps, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.helpers import dtypes, prod
import itertools
from extra.optimization.helpers import load_worlds, ast_str_to_ast
from tinygrad.ops import MovementOps, BufferOps, LazyOp
from tinygrad.helpers import prod
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Node, Variable
inf, nan = float('inf'), float('nan')
from tinygrad.shape.symbolic import sym_infer
def get_real_view(shape, strides, offset, mask):
real_shape = tuple(y-x for x,y in mask) if mask else shape
@ -21,40 +18,30 @@ def get_buffer_size(shape, strides, offset, mask):
real_real_shape, strides, real_offset = get_real_view(shape, strides, offset, mask)
return real_offset + sum((s-1)*st for s, st in zip(real_real_shape,strides)) + 1
# Reduce a View to a (shape, strides, offset) triple; views_equivalent compares these triples
# to decide whether two views match.
def flatten_view(view: View):
# get_real_view is defined elsewhere in the file; its visible part uses the mask extents as the shape when a mask is set
real_real_shape, strides, real_offset = get_real_view(view.shape, view.strides, view.offset, view.mask)
# returns two orderings, both descending by (stride, -size): the (size, stride) pairs and the matching index permutation
def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
# only the sorted pairs are used; the index permutation is discarded
ordered_shape_strides, _ = sort_by_strides(real_real_shape, strides)
# tuples -> lists so sizes can be updated in place below
ordered_shape_strides = [list(s) for s in ordered_shape_strides]
if strides:
i = 0
while i < len(ordered_shape_strides):
# entry i's stride equals size*stride of entry i+1: fold entry i's size into entry i+1
# entry i is neither removed nor skipped here; the same pair is tested again on the next pass
# NOTE(review): if the condition still holds after the update (e.g. both strides are 0, or entry i
# has size 1) this loop does not terminate — confirm such inputs cannot reach this function
if i<len(ordered_shape_strides)-1 and ordered_shape_strides[i][1] == ordered_shape_strides[i+1][0]*ordered_shape_strides[i+1][1]:
ordered_shape_strides[i+1][0] = ordered_shape_strides[i][0]*ordered_shape_strides[i+1][0]
else: i += 1
flat_shape = [shape_stride[0] for shape_stride in ordered_shape_strides]
flat_strides = [shape_stride[1] for shape_stride in ordered_shape_strides]
return (flat_shape, flat_strides, real_offset)
# no strides (falsy): return the shape unsorted together with the view's own strides
return (real_real_shape, view.strides, real_offset)
# NOTE(review): this span interleaves three definitions (st_equivalent, views_equivalent and a second
# st_equivalent) with no indentation; it appears to be old and new diff lines shown together.
# The comments below describe each fragment as written.
#
# Symbolic check: two ShapeTrackers are treated as equivalent when their index/valid expressions agree.
def st_equivalent(st1: ShapeTracker, st2: ShapeTracker):
# fast path: the (idx, valid) pairs returned by expr_idxs() compare equal
if (idxs1:=st1.expr_idxs()) == (idxs2:=st2.expr_idxs()): return True
idx1, valid1 = idxs1
idx2, valid2 = idxs2
# always invalid
if valid1 == 0 and valid2 == 0: return True
# Two views match when they are equal or when their flatten_view() triples are equal.
def views_equivalent(v1: View, v2: View) -> bool:
return v1 == v2 or flatten_view(v1) == flatten_view(v2)
# variables appearing in the index and valid expressions of each tracker
var1 = set(idx1.vars() + valid1.vars())
var2 = set(idx2.vars() + valid2.vars())
# Maybe there are cases that vars are different yet the sts are the same?
if var1 != var2: return False
# brute force over the vars range
vs = list(var1)
# every assignment of the variables over their inclusive [min, max] ranges
for i, ranges in enumerate(itertools.product(*[range(v.min, v.max+1) for v in vs])):
# search cap: once the counter passes 1000 the remaining combinations are left unchecked
if i > 1000:
print("WARNING: did not search all possible combinations")
# not happening for now
break
var_vals = {k:v for k,v in zip(vs, ranges)}
# an index is only compared where its valid expression evaluates truthy; otherwise 0 is used
r1 = sym_infer(idx1, var_vals) if sym_infer(valid1, var_vals) else 0
r2 = sym_infer(idx2, var_vals) if sym_infer(valid2, var_vals) else 0
if r1 != r2: return False
# View-walking check of a tracker against its rebuilt counterpart.
# NOTE(review): as written this fragment has no path that returns False, and rebuilt_views[i] is
# indexed without a length check — confirm whether either matters for callers.
def st_equivalent(st: ShapeTracker, st_rebuilt: ShapeTracker):
views = list(st.views)
rebuilt_views = list(st_rebuilt.views)
i = 0
while i < len(views):
view, rebuilt_view = views[i], rebuilt_views[i]
# identical views: move on
if view == rebuilt_view:
i += 1
continue
# different views with the same shape are also accepted
elif view.shape == rebuilt_view.shape:
i += 1
# hack to skip expands for overlapped strides
# shapes differ: drop the rebuilt view at this position and compare again at the same index
else:
rebuilt_views.pop(i)
return True
def test_rebuild(st: ShapeTracker):
@ -81,25 +68,19 @@ def test_rebuild(st: ShapeTracker):
else:
raise Exception("invalid mop")
rebuilt_st = rebuilt_st.simplify()
if len(st.views) != len(rebuilt_st.views):
if not set(st.views).issubset(set(rebuilt_st.views)):
assert st_equivalent(st, rebuilt_st)
else:
for v1,v2 in zip(st.views, rebuilt_st.views):
assert views_equivalent(v1, v2), f"{v1} not equivalent to {v2}"
assert st_equivalent(st, rebuilt_st)
last_v1 = st.views[-1]
last_v2 = rebuilt_st.views[-1]
assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
def test_interpret_ast(ast:LazyOp):
  """Walk an AST and run test_rebuild on the ShapeTracker of every buffer op.

  Nodes whose op is not a BufferOps member are not checked themselves; the
  walk simply descends into each of their sources.
  """
  if ast.op not in BufferOps:
    # interior node: recurse into the children
    for child in ast.src:
      test_interpret_ast(child)
    return
  # buffer op: its arg carries the ShapeTracker to verify
  test_rebuild(ast.arg.st)
# Script entry point: run the rebuild check over the AST strings returned by load_worlds.
if __name__ == "__main__":
ast_strs = load_worlds(False, False, True)
# shuffle in place, then keep the first 2000 entries
random.shuffle(ast_strs)
ast_strs = ast_strs[:2000]
# NOTE(review): interpret_ast repeats the logic of test_interpret_ast (same BufferOps test and
# recursion), and ast_strs is reassigned below with a [:4000] slice. This looks like old and new
# diff lines shown together — confirm which statements are live in the real file.
def interpret_ast(ast):
if ast.op in BufferOps:
test_rebuild(ast.arg.st)
else:
for src in ast.src: interpret_ast(src)
ast_strs = load_worlds(False, False, True)[:4000]
# tqdm wraps the iteration over the AST strings
for ast_str in tqdm(ast_strs):
# eval executes the string as Python code; the input must come from a trusted source
ast = eval(ast_str)
interpret_ast(ast)
test_interpret_ast(ast_str_to_ast(ast_str))