constant fold pattern match (#3696)

* constant fold pattern match

* match

* better match

* fix bug in pattern

* more folding
This commit is contained in:
George Hotz 2024-03-12 08:48:07 -07:00 committed by GitHub
parent dd1a1c12df
commit 6755a9254f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 65 additions and 28 deletions

View File

@ -1,8 +1,8 @@
from __future__ import annotations
import functools, math, operator
import functools, math, operator, itertools
from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast
from collections import defaultdict
from tinygrad.helpers import DEBUG, flatten, all_same
from tinygrad.helpers import DEBUG, flatten
from tinygrad.dtype import dtypes, DType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.shape.symbolic import sint, Variable, Node, NumNode, MulNode, DivNode, SumNode
@ -20,10 +20,13 @@ class UOps(Enum):
class UOp:
  """One node of the uop graph: an opcode, a dtype, the input uops and an op-specific argument.

  `vin` and `arg` have defaults, so a leaf uop can be built from just `uop` and `dtype`
  (this is what `const` does).
  """
  uop: UOps                       # opcode of this node
  dtype: Optional[DType]          # dtype attached to this node; may be None (printed as an empty column by __repr__)
  vin: Tuple[UOp, ...] = tuple()  # input uops; empty for leaves
  arg: Any = None                 # op-specific payload, e.g. the python value of a CONST (see const)

  def __repr__(self):
    # fixed-width columns: opcode, dtype, opcodes of the inputs, then the raw arg
    return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"

  @staticmethod
  def const(dtype, val):
    """Build a CONST uop of `dtype`, coercing `val` to the python type that fits the dtype.

    `val` becomes a float for float dtypes, an int for int dtypes and a bool for everything else.
    """
    if dtypes.is_float(dtype): val = float(val)
    elif dtypes.is_int(dtype): val = int(val)
    else: val = bool(val)
    return UOp(UOps.CONST, dtype, arg=val)
def hook_overflow(dv, fxn):
def wfxn(*args):
@ -58,6 +61,59 @@ def uop_alu_resolve(u:UOp) -> sint:
def phi_resolve_acc(u:UOp) -> UOp:
  """Follow the first input of `u` repeatedly until a DEFINE_ACC uop is reached, and return that uop."""
  node = u
  while node.uop is not UOps.DEFINE_ACC:
    node = node.vin[0]
  return node
def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:
for k,v in pattern.items():
if k == "__name__":
if v in store and store[v] != uop: return False
store[v] = uop
elif k == "vin":
# only one if it's a tuple
# try all permutations if it's a list
# repeat if it's a dict
for vp in itertools.permutations(v) if isinstance(v, list) else ([v] if isinstance(v, tuple) else [(v,)*len(uop.vin)]):
if len(uop.vin) != len(vp): return False
new_store = store.copy()
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.vin, vp)):
for k,v in new_store.items(): store[k] = v
return True
return False
else:
if uop.__getattribute__(k) != v: return False
return True
def rewrite(uop:UOp, patterns:List[Tuple[Dict[str, Any], Any]]) -> Optional[UOp]:
  """Apply the first rule in `patterns` whose pattern matches `uop`.

  Each rule is a (pattern, fxn) pair. Rules are tried in order; for the first pattern that matches,
  fxn is called with the named sub-matches as keyword arguments and its result is returned.
  Returns None when no rule matches.
  """
  for pattern, builder in patterns:
    # fresh bindings per rule: a failed attempt must not leak names into the next one
    bindings: Dict[str, UOp] = {}
    if not _match(uop, pattern, bindings): continue
    return builder(**bindings)
  return None
# Rewrite rules consumed by rewrite(): each entry is a (pattern, fxn) pair.
# - pattern is a dict checked by _match(): "__name__" binds the matched uop under that name, "vin" describes the inputs
#   (tuple: positional, list: any permutation, dict: the same sub-pattern for every input), and any other key is
#   compared against the attribute of that name on the uop. An empty dict {} matches any uop.
# - fxn is called with the bound names as keyword arguments, so its parameter names must equal the "__name__" values.
# Rules are tried in list order and the first match wins, so the order of the entries matters.
constant_fold_patterns = [
# const rules
# GEP of a single CONST input, or CAST whose inputs all match the CONST sub-pattern under the one name "c":
# replaced by a CONST of the root's dtype
({"__name__": "root", "uop": UOps.GEP, "vin": ({"__name__": "c", "uop": UOps.CONST},)}, lambda root, c: UOp.const(root.dtype, c.arg)),
({"__name__": "root", "uop": UOps.CAST, "vin": {"__name__": "c", "uop": UOps.CONST}}, lambda root,c: UOp.const(root.dtype, c.arg)),
# a phi without loops (len(vin)==2) is a noop
({"uop": UOps.PHI, "vin": ({}, {"__name__": "x"})}, lambda x: x),
# x+-y -> x-y
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": ({"__name__": "x"}, {"__name__": "my", "uop": UOps.ALU, "arg": UnaryOps.NEG})},
lambda x, my: UOp(UOps.ALU, x.dtype, (x, my.vin[0]), BinaryOps.SUB)),
# a conditional with the same results either way is a noop, also fold const conditionals
# (the name "val" is used for both branches, so they must bind to the same uop)
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({}, {"__name__": "val"}, {"__name__": "val"})}, lambda val: val),
({"uop": UOps.ALU, "arg": TernaryOps.WHERE, "vin": ({"__name__": "gate", "uop": UOps.CONST}, {"__name__": "c0"}, {"__name__": "c1"})},
lambda gate, c0, c1: c0 if gate.arg else c1),
# ** constant folding **
# any ALU whose inputs are all CONST is evaluated with exec_alu and replaced by the resulting CONST
({"__name__": "root", "uop": UOps.ALU, "vin": {"uop": UOps.CONST}},
lambda root: UOp(UOps.CONST, root.dtype, arg=exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
# ** self folding **
# list patterns cover both operand orders; tuple patterns only match the written order
({"uop": UOps.ALU, "arg": BinaryOps.ADD, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 0}]}, lambda x: x), # x+0 -> x or 0+x -> x
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{"__name__": "x"}, {"uop": UOps.CONST, "arg": 1}]}, lambda x: x), # x*1 -> x or 1*x -> x
({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 0})}, lambda x: x), # x-0 -> x
({"uop": UOps.ALU, "arg": BinaryOps.DIV, "vin": ({"__name__": "x"}, {"uop": UOps.CONST, "arg": 1})}, lambda x: x), # x/1 -> x
# ** zero folding **
({"uop": UOps.ALU, "arg": BinaryOps.MUL, "vin": [{}, {"__name__": "c", "uop": UOps.CONST, "arg": 0}]}, lambda c: c), # x*0 -> 0 or 0*x -> 0
({"uop": UOps.ALU, "arg": BinaryOps.SUB, "vin": ({"__name__": "x"}, {"__name__": "x"})}, lambda x: UOp.const(x.dtype, 0)), # x-x -> 0
]
class UOpGraph:
def __init__(self, start_uops:Optional[List[UOp]]=None):
# list of uops
@ -81,33 +137,14 @@ class UOpGraph:
def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None,
        simplify=True) -> UOp:
  """Create a uop, optionally simplify it, and insert it into `self.uops`.

  Args:
    uop, dtype, vin, arg: fields of the uop to create.
    cachable: when true, an equal uop already recorded in `self.saved_exprs` is reused (if it sits at or before
      the insert position) and the new uop is recorded there.
    insert_before: index in `self.uops` to insert at; None appends at the end.
    simplify: when true, the uop is first run through `rewrite` with `constant_fold_patterns`.

  Returns:
    The uop that represents the requested expression: an existing uop of the graph (returned by a rewrite
    rule or found in the cache) or the newly inserted one.
  """
  ret = UOp(uop, dtype, vin, arg)
  if simplify and (rewritten:=rewrite(ret, constant_fold_patterns)) is not None:
    # a rule may hand back a uop that is already in the graph (e.g. one of the inputs): reuse it untouched
    if rewritten in self.uops: return rewritten  # ignore cachable
    ret = rewritten
  # key on the uop that is actually inserted, so the cache never maps a pre-rewrite key to a different uop
  key = (ret.uop, ret.dtype, ret.vin, ret.arg)
  if insert_before is None: insert_before = len(self.uops)
  # check if the cached expr is valid with the given insert place.
  if cachable and (expr:=self.saved_exprs.get(key, None)) is not None and self.uops.index(expr) <= insert_before: return expr
  self.uops.insert(insert_before, ret)
  if cachable: self.saved_exprs[key] = ret
  return ret