# tinygrad/extra/to_movement_ops.py
# (web-view residue: 149 lines, 8.1 KiB, Python)

import itertools
from enum import Enum, auto
from collections import defaultdict
from typing import List, Tuple, DefaultDict
from extra.optimization.helpers import load_worlds, ast_str_to_ast
from tinygrad.helpers import prod, tqdm
from tinygrad.ops import UOp, UOps
from tinygrad.shape.shapetracker import ShapeTracker
# (web-view residue: blame timestamp 2024-10-12 18:44:44 +08:00)
from tinygrad.ops import sym_infer, Node
class MovementOps(Enum):
  """Kinds of movement op used in this file's (op, arg) pairs.

  Values come from ``auto()`` in declaration order. AS_STRIDED is declared here but
  apply_mop in this file raises ValueError for it.
  """
  RESHAPE = auto()
  PERMUTE = auto()
  EXPAND = auto()
  PAD = auto()
  SHRINK = auto()
  STRIDE = auto()
  AS_STRIDED = auto()
def apply_mop(st: ShapeTracker, mop_arg: Tuple[MovementOps, Tuple]) -> ShapeTracker:
  """Apply one ``(MovementOps, arg)`` pair to ``st`` and return the resulting ShapeTracker.

  Two conveniences on top of the plain ShapeTracker methods:
    * RESHAPE with ``arg == (-1,)`` flattens to the product of the last view's shape.
    * EXPAND first reshapes ``st`` to ``(1, *st.shape)`` when ``arg`` has a different
      number of dims than ``st.shape``.

  Raises:
    ValueError: for any op other than RESHAPE, PERMUTE, EXPAND, PAD, SHRINK or STRIDE.
  """
  kind, payload = mop_arg
  if kind == MovementOps.RESHAPE:
    # shapetracker doesn't allow flattening with -1, so the flat size is spelled out here
    target = (prod(st.views[-1].shape),) if payload == (-1,) else payload
    return st.reshape(target)
  elif kind == MovementOps.PERMUTE:
    return st.permute(payload)
  elif kind == MovementOps.EXPAND:
    # a rank mismatch is handled by prepending a single size-1 axis before expanding
    base = st if len(payload) == len(st.shape) else st.reshape((1, *st.shape))
    return base.expand(payload)
  elif kind == MovementOps.PAD:
    return st.pad(payload)
  elif kind == MovementOps.SHRINK:
    return st.shrink(payload)
  elif kind == MovementOps.STRIDE:
    return st.stride(payload)
  raise ValueError("invalid mop")
def make_scratch_st(st: ShapeTracker) -> ShapeTracker:
  """Return a fresh 1-D ShapeTracker whose length is get_buffer_size of ``st``'s first view."""
  base = st.views[0]
  flat_len = get_buffer_size(base.shape, base.strides, base.offset, base.mask)
  return ShapeTracker.from_shape((flat_len,))
# ShapeTracker to an equivalent series of MovementOps (https://github.com/tinygrad/tinygrad/pull/2216)
def to_movement_ops(st: ShapeTracker) -> List[Tuple[MovementOps, Tuple]]:
  """Decompose ``st`` into a list of ``(MovementOps, arg)`` pairs.

  One group of ops is generated per view in ``st.views``. The full list is then replayed
  with apply_mop on top of ``make_scratch_st(st)``; whenever the replayed tracker equals
  one that was already produced, the ops emitted since that earlier state are dropped.

  Returns:
    The pruned op list. Only RESHAPE, PERMUTE, EXPAND, PAD, SHRINK and STRIDE are emitted.
  """
  to_apply:List[Tuple[MovementOps, Tuple]] = []
  for i, v in enumerate(st.views):
    # extent of the masked region when a mask is set, otherwise the view's shape
    real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
    # shift the offset by stride*(size-1) for every negative-stride axis
    # (presumably so it points at the lowest element those axes touch -- TODO confirm)
    offset = v.offset + sum(st*(s-1) for s,st in zip(real_shape, v.strides) if st<0)
    # additionally skip past the mask's lower bounds
    real_offset = offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
    # axes with a falsy (zero) stride are left out here; int strides are made non-negative
    real_real_shape = [s for s,st in zip(real_shape, v.strides) if st]
    strides: List[Node|int] = [abs(st) if isinstance(st,int) else st for st in v.strides if st]
    buffer_size = sum((s-1)*st for s,st in zip(real_real_shape,strides)) + 1
    # for every view after the first, the size is instead derived from the previous view's element count
    if i: buffer_size = prod(st.views[i-1].shape) - real_offset
    # returns (shape, stride) pairs sorted by (stride, -shape) descending, plus the matching index order;
    # note the index sort reads real_real_shape from the enclosing scope, not the `shape` argument
    def sort_by_strides(shape, strides): return sorted(zip(shape, strides), key=lambda k: (k[1],-k[0]), reverse=True), sorted(range(len(strides)), key=lambda k: (strides[k],-real_real_shape[k]), reverse=True)
    ordered_shape_strides, order = sort_by_strides(real_real_shape, strides)
    # flatten, then keep only the [real_offset, real_offset+buffer_size) window
    to_apply.extend([(MovementOps.RESHAPE, (-1,)), (MovementOps.SHRINK, ((real_offset, real_offset+buffer_size),))])
    if strides:
      # pad the tail when the largest-stride axis spans more elements than the window holds
      if (ordered_shape_strides[0][0]*ordered_shape_strides[0][1])-buffer_size>0: to_apply.append((MovementOps.PAD, ((0, (ordered_shape_strides[0][0] * ordered_shape_strides[0][1]) - buffer_size),)))
      # NOTE: this loop reuses the name `i`, shadowing the view index; the view index is not read again in this iteration
      for i, shape_stride in enumerate(ordered_shape_strides):
        # taken when this axis' stride is smaller than the span (size*stride) of the next axis in sorted order
        if i<len(ordered_shape_strides)-1 and shape_stride[1] < ordered_shape_strides[i+1][0]*ordered_shape_strides[i+1][1]:
          remaining_buffer = ordered_shape_strides[i-1][1] if i>0 else buffer_size
          to_apply.append((MovementOps.EXPAND, (shape_stride[0], *(s[0] for s in ordered_shape_strides[:i]), remaining_buffer)))
          to_apply.append((MovementOps.PERMUTE, (*range(1,i+1), 0, i+1)))
          to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i]), shape_stride[0]*remaining_buffer)))
          to_apply.append((MovementOps.PAD, (*((0,0) for _ in range(i)), (0, shape_stride[0]*shape_stride[1]))))
          to_apply.append((MovementOps.RESHAPE, (*(s[0] for s in ordered_shape_strides[:i+1]), remaining_buffer+shape_stride[1])))
          # record the widened trailing extent; the next iteration reads it as its remaining_buffer
          ordered_shape_strides[i] = (ordered_shape_strides[i][0], remaining_buffer+shape_stride[1])
        else:
          # otherwise: cut the trailing axis to size*stride elements and split it into (size, stride)
          to_apply.append((MovementOps.SHRINK, (*((0, s[0]) for s in ordered_shape_strides[:i]), (0, shape_stride[0]*shape_stride[1]))))
          to_apply.append((MovementOps.RESHAPE, (*[s[0] for s in ordered_shape_strides[:i+1]], shape_stride[1])))
      # keep a single element of the trailing axis, then drop that axis
      to_apply.extend([(MovementOps.SHRINK, (*[(0, s[0]) for s in ordered_shape_strides], (0,1))), (MovementOps.RESHAPE, tuple(s[0] for s in ordered_shape_strides))])
      # undo the stride sort so axes are back in the view's order
      if order != list(range(len(order))): to_apply.append((MovementOps.PERMUTE, tuple(order.index(i) for i in range(len(strides)))))
    # re-insert the zero-stride axes as size-1 dims
    to_apply.append((MovementOps.RESHAPE, tuple(s if st else 1 for s,st in zip(real_shape, v.strides))))
    # flip the axes whose original stride was negative
    if any(i<0 for i in v.strides): to_apply.append((MovementOps.STRIDE, tuple(-1 if st<0 else 1 for st in v.strides)))
    # then, we apply pre expand pads
    if v.mask is not None:
      # mask padding is split by axis: non-zero-stride axes are padded before the expand, zero-stride axes after
      pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
      post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
      if any(x != (0,0) for x in pre_expand_pads):
        to_apply.append((MovementOps.PAD, pre_expand_pads))
        real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
    # then, we do any expands
    if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
    # lastly, we apply post expand pads
    # (post_expand_pads only exists when v.mask is not None; the first operand guards the lookup)
    if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
  # replay the ops on a scratch tracker to prune ops that lead back to an already-seen state
  scratch_st = make_scratch_st(st)
  ret = []
  seen = {} # {shapetracker: list of mops to generate that shapetracker}
  for mop_arg in to_apply:
    scratch_st = apply_mop(scratch_st, mop_arg)
    if scratch_st in seen:
      # same tracker as an earlier point: fall back to the op list recorded for it
      ret = seen[scratch_st][:]
    else:
      ret.append(mop_arg)
      seen[scratch_st] = ret[:]
  return ret
def get_real_view(shape, strides, offset, mask):
  """Return ``(sizes, steps, start)`` for a view given as shape/strides/offset/mask.

  * ``sizes``: per-axis extents (mask extents when ``mask`` is truthy, else ``shape``),
    keeping only axes whose stride is truthy.
  * ``steps``: the truthy strides, with ``int`` strides replaced by their absolute value.
  * ``start``: ``offset`` plus ``stride*(size-1)`` for each negative-stride axis, plus
    ``lower_bound*stride`` for each masked axis.
  """
  extents = shape if not mask else tuple(hi - lo for lo, hi in mask)
  neg_shift = sum(step * (size - 1) for size, step in zip(extents, strides) if step < 0)
  mask_shift = sum(lo * step for (lo, _), step in zip(mask, strides)) if mask else 0
  start = offset + neg_shift + mask_shift
  kept_sizes = [size for size, step in zip(extents, strides) if step]
  kept_steps = []
  for step in strides:
    if not step: continue  # zero-stride axes are dropped
    kept_steps.append(abs(step) if isinstance(step, int) else step)
  return kept_sizes, kept_steps, start
def get_buffer_size(shape, strides, offset, mask):
  """Return ``start + sum((size-1)*step) + 1`` using the values produced by get_real_view."""
  sizes, steps, start = get_real_view(shape, strides, offset, mask)
  span = sum((size - 1) * step for size, step in zip(sizes, steps))
  return start + span + 1
def st_equivalent(st1: ShapeTracker, st2: ShapeTracker):
  """Return True when ``st1`` and ``st2`` are judged to index the same way.

  True is returned when their ``expr_idxs()`` results compare equal, when both valid
  expressions equal 0, or when evaluating both over the variables' ranges finds no
  differing result. At most 1001 assignments are evaluated; a warning is printed if
  more exist. False is returned when the variable sets differ or any evaluated
  assignment disagrees.
  """
  lhs, rhs = st1.expr_idxs(), st2.expr_idxs()
  if lhs == rhs: return True
  (idx1, valid1), (idx2, valid2) = lhs, rhs
  # always invalid
  if valid1 == 0 and valid2 == 0: return True
  var1 = idx1.vars() | valid1.vars()
  var2 = idx2.vars() | valid2.vars()
  # Maybe there are cases that vars are different yet the sts are the same?
  if var1 != var2: return False
  # brute force over the vars range
  names = list(var1)
  domains = [range(v.min, v.max+1) for v in names]
  for count, assignment in enumerate(itertools.product(*domains)):
    if count > 1000:
      print("WARNING: did not search all possible combinations")
      break
    env = dict(zip(names, assignment))
    # NOTE(review): an invalid access is folded to 0, so it compares equal to a valid access of index 0
    got1 = sym_infer(idx1, env) if sym_infer(valid1, env) else 0
    got2 = sym_infer(idx2, env) if sym_infer(valid2, env) else 0
    if got1 != got2: return False
  return True
# histogram: number of movement ops -> how many ShapeTrackers decomposed into that many ops
c: DefaultDict[int, int] = defaultdict(int)

def test_rebuild(st: ShapeTracker):
  """Decompose ``st`` into movement ops, replay them on a scratch tracker and check the result.

  Records ``len(mops)`` in the module-level counter ``c``.

  Raises:
    AssertionError: if the simplified rebuilt tracker is not st_equivalent to ``st``
      (unless all strides of its last view are 0), or if the last views' shapes differ.
  """
  replayed = make_scratch_st(st)
  mops = to_movement_ops(st)
  c[len(mops)] += 1
  for step in mops:
    replayed = apply_mop(replayed, step)
  replayed = replayed.simplify()
  # why is the "all(x == 0 for x in rebuilt_st.views[-1].strides)" hack needed?
  assert st_equivalent(st, replayed) or all(x == 0 for x in replayed.views[-1].strides), f"mismatch {st} {replayed}"
  want_view, got_view = st.views[-1], replayed.views[-1]
  assert want_view.shape == got_view.shape, f"{want_view.shape} != {got_view.shape}"
def test_rebuild_bufferop_st(ast:UOp):
  """Run test_rebuild on the ``arg`` of every SHAPETRACKER node reachable from ``ast`` via ``src``.

  Nodes are visited parent first, then children in ``src`` order.
  """
  pending = [ast]
  while pending:
    node = pending.pop()
    if node.op is UOps.SHAPETRACKER:
      test_rebuild(node.arg)
    # push children reversed so the first child is popped (visited) next
    pending.extend(reversed(tuple(node.src)))
if __name__ == "__main__":
  # Rebuild every ShapeTracker found in the first 2000 loaded ASTs, then report the
  # average number of movement ops per ShapeTracker from the counter `c`.
  ast_strs = load_worlds(False, False, True)[:2000]
  for ast_str in tqdm(ast_strs):
    test_rebuild_bufferop_st(ast_str_to_ast(ast_str))
  rebuilt_total = sum(c.values())
  if rebuilt_total:
    print(f"avg length of mop = {sum(k*v for k,v in c.items()) / rebuilt_total:.2f}")
  else:
    # nothing was counted (no ASTs loaded, or none contained a SHAPETRACKER node):
    # report that instead of dividing by zero
    print("avg length of mop = n/a (no ShapeTrackers were rebuilt)")