mirror of https://github.com/commaai/tinygrad.git
constant fold pattern match (#3696)
* constant fold pattern match * match * better match * fix bug in pattern * more folding
This commit is contained in:
parent
dd1a1c12df
commit
6755a9254f
|
@ -1,8 +1,8 @@
|
|||
from __future__ import annotations
|
||||
import functools, math, operator
|
||||
import functools, math, operator, itertools
|
||||
from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast
|
||||
from collections import defaultdict
|
||||
from tinygrad.helpers import DEBUG, flatten, all_same
|
||||
from tinygrad.helpers import DEBUG, flatten
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.shape.symbolic import sint, Variable, Node, NumNode, MulNode, DivNode, SumNode
|
||||
|
@ -20,10 +20,13 @@ class UOps(Enum):
|
|||
class UOp:
|
||||
uop: UOps
|
||||
dtype: Optional[DType]
|
||||
vin: Tuple[UOp, ...]
|
||||
arg: Any
|
||||
vin: Tuple[UOp, ...] = tuple()
|
||||
arg: Any = None
|
||||
def __repr__(self):
|
||||
return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
|
||||
@staticmethod
|
||||
def const(dtype, val):
|
||||
return UOp(UOps.CONST, dtype, arg=float(val) if dtypes.is_float(dtype) else (int(val) if dtypes.is_int(dtype) else bool(val)))
|
||||
|
||||
def hook_overflow(dv, fxn):
|
||||
def wfxn(*args):
|
||||
|
@ -58,6 +61,59 @@ def uop_alu_resolve(u:UOp) -> sint:
|
|||
|
||||
def phi_resolve_acc(u:UOp) -> UOp: return u if u.uop is UOps.DEFINE_ACC else phi_resolve_acc(u.vin[0])
|
||||
|
||||
def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
|
||||
for k,v in pattern.items():
|
||||
if k == "__name__":
|
||||
if v in store and store[v] != uop: return False
|
||||
store[v] = uop
|
||||
elif k == "vin":
|
||||
# only one if it's a tuple
|
||||
# try all permutations if it's a list
|
||||
# repeat if it's a dict
|
||||
for vp in itertools.permutations(v) if isinstance(v, list) else ([v] if isinstance(v, tuple) else [(v,)*len(uop.vin)]):
|
||||
if len(uop.vin) != len(vp): return False
|
||||
new_store = store.copy()
|
||||
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
|
||||
for k,v in new_store.items(): store[k] = v
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
if uop.__getattribute__(k) != v: return False
|
||||
return True
|
||||
|
||||
def rewrite(uop:UOp, patterns:List[Tuple[Dict[str, Any], Any]]) -> Optional[UOp]:
|
||||
for p,fxn in patterns:
|
||||
store: Dict[str, UOp] = {}
|
||||
if _match(uop, p, store):
|
||||
return fxn(**store)
|
||||
return None
|
||||
|
||||
constant_fold_patterns = [
|
||||
# const rules
|
||||
({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
({"__name__": "root", "uop": UOps.CAST, "vin": {"__name__": "c", "uop": UOps.CONST}}, lambda root,c: UOp.const(root.dtype, c.arg)),
|
||||
# a phi without loops (len(vin)==2) is a noop
|
||||
({"uop": UOps.PHI, "vin": ({}, {"__name__": "x"})}, lambda x: x),
|
||||
# x+-y -> x-y
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
|
||||
lambda x, my: UOp(UOps.ALU, x.dtype, (x, my.vin[0]), BinaryOps.SUB)),
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val),
|
||||
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})},
|
||||
lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# ** constant folding **
|
||||
({"__name__": "root", "uop": UOps.ALU, "vin": {"uop": UOps.CONST}},
|
||||
lambda root: UOp(UOps.CONST, root.dtype, arg=exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
|
||||
# ** self folding **
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 0}]}, lambda x: x), # x+0 -> x or 0+x -> x
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 1}]}, lambda x: x), # x*1 -> x or 1*x -> x
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 0})}, lambda x: x), # x-0 -> x
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 1})}, lambda x: x), # x/1 -> x
|
||||
# ** zero folding **
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
|
||||
({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
|
||||
]
|
||||
|
||||
class UOpGraph:
|
||||
def __init__(self, start_uops:Optional[List[UOp]]=None):
|
||||
# list of uops
|
||||
|
@ -81,33 +137,14 @@ class UOpGraph:
|
|||
|
||||
def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None,
|
||||
simplify=True) -> UOp:
|
||||
if simplify:
|
||||
if uop is UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop
|
||||
if uop is UOps.GEP and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=vin[0].arg, insert_before=insert_before)
|
||||
if uop is UOps.CAST and all(x.uop is UOps.CONST for x in vin) and all_same([x.arg for x in vin]):
|
||||
return self.add(UOps.CONST, dtype, arg=vin[0].arg, insert_before=insert_before)
|
||||
if uop is UOps.ALU:
|
||||
# rewrites. NOTE: the rewritten NEG op is still around...
|
||||
if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG:
|
||||
return self.add(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
|
||||
# constant folding
|
||||
if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
|
||||
if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2]
|
||||
if all(x.uop is UOps.CONST for x in vin):
|
||||
return self.add(UOps.CONST, dtype, arg=exec_alu(arg, dtype, [x.arg for x in vin]), insert_before=insert_before)
|
||||
# zero folding
|
||||
for x in [0,1]:
|
||||
if arg is BinaryOps.ADD and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
|
||||
if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
|
||||
if arg is BinaryOps.MUL and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[x]
|
||||
if arg is BinaryOps.SUB and vin[1].uop is UOps.CONST and vin[1].arg == 0.0: return vin[0]
|
||||
if arg is BinaryOps.DIV and vin[1].uop is UOps.CONST and vin[1].arg == 1.0: return vin[0]
|
||||
|
||||
key = (uop, dtype, vin, arg)
|
||||
ret = UOp(uop, dtype, vin, arg)
|
||||
if simplify and (rewritten:=rewrite(ret, constant_fold_patterns)) is not None:
|
||||
if rewritten in self.uops: return rewritten # ignore cachable
|
||||
ret = rewritten
|
||||
key = (ret.uop, ret.dtype, ret.vin, ret.arg)
|
||||
if insert_before is None: insert_before = len(self.uops)
|
||||
# check if the cached expr is valid with the given insert place.
|
||||
if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
|
||||
ret = UOp(uop, dtype, vin, arg)
|
||||
self.uops.insert(insert_before, ret)
|
||||
if cachable: self.saved_exprs[key] = ret
|
||||
return ret
|
||||
|
|
Loading…
Reference in New Issue