mirror of https://github.com/commaai/tinygrad.git
move some pm rules to uopgraph.py [run_process_replay] (#6831)
* move some pm rules to uopgraph.py [run_process_replay] * move more * move lt and clean * end maybe * put back
This commit is contained in:
parent
0cb82f308c
commit
e907b25792
|
@ -440,7 +440,7 @@ class TestIndexingOrdering(unittest.TestCase):
|
|||
|
||||
class TestUPatHelpers(unittest.TestCase):
|
||||
def test_location(self):
|
||||
self.assertEqual(sym.patterns[0][0].location[0].split("/")[-1], "uopgraph.py")
|
||||
self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py")
|
||||
self.assertEqual(reduceop_fusor.patterns[0][0].location[0].split("/")[-1], "schedule.py")
|
||||
self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py")
|
||||
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
|
||||
|
|
|
@ -64,12 +64,28 @@ class TestUOpResolve(unittest.TestCase):
|
|||
u = UOp.define_var("b", dtypes.bool, False, True) & False
|
||||
self.assertFalse(u)
|
||||
|
||||
def test_max(self):
  # presumably the last two define_var arguments are the variable's min/max bounds — confirm against UOp.define_var
  first = UOp.define_var("x", dtypes.pyint, 1, 10)
  second = UOp.define_var("y", dtypes.pyint, 5, 10)
  larger = first.max(second)
  # both comparisons must resolve to a single bool, otherwise truth-testing the UOp raises
  self.assertTrue(larger < 20)
  self.assertFalse(larger < 3)
|
||||
|
||||
def test_x_lt_x(self):
  # comparing a variable against itself is expected to evaluate to False
  var = UOp.define_var("i", dtypes.pyint, 1, 10)
  self.assertFalse(var < var)
|
||||
|
||||
@unittest.expectedFailure
def test_x_lt_xp1(self):
  # marked as an expected failure: the assertion below is not expected to pass yet
  var = UOp.define_var("i", dtypes.pyint, 1, 10)
  successor = var+1
  self.assertTrue(var < successor)
|
||||
|
||||
def test_and_true(self):
  # truth-testing the result is expected to raise ValueError instead of yielding a bool
  with self.assertRaises(ValueError):
    flag = UOp.define_var("b", dtypes.bool, False, True)
    masked = flag & True
    self.assertFalse(masked)
|
||||
|
||||
@unittest.skip("too fancy to be supported right now")
|
||||
@unittest.expectedFailure
|
||||
def test_var_cmp_range(self):
|
||||
v = UOp.define_var("i", dtypes.pyint, 1, 10)
|
||||
u = v > 4 or v < 6
|
||||
|
|
|
@ -3,8 +3,8 @@ from typing import Optional, Tuple, Dict, List, Set, cast, TYPE_CHECKING, Any, D
|
|||
import functools, itertools, heapq, math, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType, DType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
|
||||
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
|
||||
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps, simple_pm
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
if TYPE_CHECKING: from tinygrad.renderer import Renderer
|
||||
|
@ -370,11 +370,7 @@ def no_vectorized_wmma(wmma:UOp):
|
|||
return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
|
||||
|
||||
# this is symbolic 2.0
|
||||
sym = PatternMatcher([
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
|
||||
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),
|
||||
(UPat.var('x', dtype=dtypes.bool).max(UPat.var('y')), lambda x,y: x|y),
|
||||
sym = simple_pm+PatternMatcher([
|
||||
# self ASSIGN is just self
|
||||
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
|
||||
# ASSIGN to global is just self
|
||||
|
@ -432,57 +428,18 @@ sym = PatternMatcher([
|
|||
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),
|
||||
name="ld"), UPat.const(None, 0.0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
|
||||
# max folding
|
||||
(UPat.max(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
||||
# GEP/CAST const rules
|
||||
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
||||
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# ** constant folding **
|
||||
(UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
|
||||
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
|
||||
# ** self folding **
|
||||
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
(UPat(UOps.REDUCE, src=(UPat.var("x"),)), lambda x: x), # a REDUCE without ranges is a NOOP
|
||||
(UPat.var("x") + 0, lambda x: x), # x+0 -> x
|
||||
(UPat.var("x") * 1, lambda x: x), # x*1 -> x
|
||||
(UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
|
||||
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
|
||||
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
|
||||
(UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
|
||||
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
||||
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
|
||||
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
|
||||
# ** zero folding **
|
||||
# x*0 -> 0 or 0*x -> 0
|
||||
# if x is nan or inf it should render the nan value.
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# ALU min==max -> CONST (slow!)
|
||||
(UPat(UOps.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# ** load/store folding **
|
||||
(UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.load(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
|
||||
# ** two stage add/mul folding **
|
||||
((UPat.var("x") + UPat.cvar("c1")) + UPat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
|
||||
((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
|
||||
((UPat.var("x") & UPat.cvar("c1")) & UPat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
|
||||
((UPat.var("x") | UPat.cvar("c1")) | UPat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
|
||||
# *** rules from symbolic ***
|
||||
# ** lt **
|
||||
# c0*x<c1 for positive int c0,c1
|
||||
((UPat.cvar("c0", vec=False)*UPat.var("x", dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
|
||||
lambda x,c0,c1: x.lt(math.ceil(c1.arg/c0.arg)) if c0.arg > 0 and c1.arg > 0 else None),
|
||||
# c0*x<c1 for negative int c0 and non-positive c1
|
||||
((UPat.cvar("c0", vec=False)*UPat.var("x", dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
|
||||
lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
|
||||
# x//c0<c1 for positive int c0
|
||||
((UPat.var("x", dtypes.ints)//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)),
|
||||
lambda x,c0,c1: x.lt(c1.arg*c0.arg) if c0.arg > 0 else None),
|
||||
# mul add lt
|
||||
(((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
|
||||
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
|
||||
# generic lt folding
|
||||
(UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
|
||||
# canonicalize a simplex with positive coefficients > 0
|
||||
|
@ -494,22 +451,16 @@ sym = PatternMatcher([
|
|||
# ** mod **
|
||||
# mod folding
|
||||
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
||||
# ** combine terms **
|
||||
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
|
||||
# ** combine terms (opinionated) **
|
||||
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
||||
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
|
||||
((UPat.var("x") // UPat.cvar("c0")) // UPat.cvar("c1"), lambda x,c0,c1: x//(c0*c1)), # (x//c0)//c1 -> x//(c0*c1)
|
||||
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
|
||||
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
|
||||
((UPat.cvar("c0") + UPat.var("x")).lt(UPat.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)), # c0 + x < c1 -> x < c1 - c0
|
||||
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
|
||||
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
|
||||
# x!=0 -> (bool)x
|
||||
(UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
|
||||
# bitwise noops
|
||||
((UPat.var("x") & UPat.var("x")), lambda x: x),
|
||||
((UPat.var("x") | UPat.var("x")), lambda x: x),
|
||||
# TODO: can do the invert of this (flip alt/load) when we fix double ops
|
||||
(UPat.store(UPat.var("buf"), UPat.var("idx"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat.var("buf"), UPat.var("idx")))),
|
||||
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
|
||||
|
|
|
@ -173,8 +173,9 @@ class UOp(MathTrait):
|
|||
# *** uop evaluation ***
|
||||
def _eval(self, dtype, expected_type) -> ConstType:
  # Resolve this UOp to a single constant value.
  #   dtype: container of accepted dtypes, self.dtype must be in it
  #   expected_type: exact Python type the resulting constant must have
  # Raises ValueError when the simplified expression still spans more than one value.
  assert self.dtype in dtype, f"eval with wrong dtype {self}"
  # simplify BEFORE looking at the range: checking self._min_max first would raise on expressions
  # that only collapse to one value after rewriting (simple_pm has a rule for x < x -> False)
  simple_self = graph_rewrite(self, simple_pm)
  vmin, vmax = simple_self._min_max
  if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self}")
  assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}"
  return vmin
|
||||
def __bool__(self):
  # truth value is whatever single bool _eval resolves to; _eval raises ValueError if the range is not one value
  accepted_dtypes = (dtypes.bool,)
  return self._eval(accepted_dtypes, bool)
|
||||
|
@ -685,3 +686,57 @@ def type_verify(uops:List[UOp]):
|
|||
for u in uops:
|
||||
chk = cast(bool, spec.rewrite(u))
|
||||
assert chk is True, f"UOp verification failed on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
|
||||
|
||||
# rewrite rules as (pattern, rewrite function) pairs. the lambda parameter names are the names given to the
# pattern variables, so they cannot be renamed independently.
# this matcher is applied by UOp._eval through graph_rewrite, and sym is built as simple_pm+PatternMatcher([...])
# NOTE(review): rule order is kept exactly as it was, an earlier rule may be relied on to match first — confirm
simple_pm = PatternMatcher([
  # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
  (UPat.var('x', dtype=dtypes.bool) * UPat.var('y'), lambda x,y: x&y),
  (UPat.var('x', dtype=dtypes.bool) + UPat.var('y'), lambda x,y: x|y),
  (UPat.var('x', dtype=dtypes.bool).max(UPat.var('y')), lambda x,y: x|y),
  # ** self folding **
  (UPat.var("x") + 0, lambda x: x),    # x+0 -> x
  (UPat.var("x") * 1, lambda x: x),    # x*1 -> x
  (UPat.var("x") // UPat.var("x"), lambda x: x.const_like(1)), # x//x -> 1
  (UPat.var("x") // 1, lambda x: x),   # x//1 -> x
  (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
  (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
  (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
  ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
  (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
  (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
  ((UPat.var("x") & UPat.var("x")), lambda x: x),
  ((UPat.var("x") | UPat.var("x")), lambda x: x),
  # ** zero folding **
  # x*0 -> 0 or 0*x -> 0
  # if x is nan or inf it should render the nan value.
  # NOTE: this can be wrong for loaded NaN
  (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
  # ** constant folding **
  (UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
    lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
  # ALU min==max -> CONST (slow!)
  (UPat(UOps.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
  # max folding
  (UPat.max(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
  # ** two stage ALU folding **
  ((UPat.var("x") + UPat.cvar("c1")) + UPat.cvar("c2"), lambda x,c1,c2: x+(c1+c2)),
  ((UPat.var("x") * UPat.cvar("c1")) * UPat.cvar("c2"), lambda x,c1,c2: x*(c1*c2)),
  ((UPat.var("x") & UPat.cvar("c1")) & UPat.cvar("c2"), lambda x,c1,c2: x&(c1&c2)),
  ((UPat.var("x") | UPat.cvar("c1")) | UPat.cvar("c2"), lambda x,c1,c2: x|(c1|c2)),
  ((UPat.cvar("c0") + UPat.var("x")) < UPat.cvar("c1"), lambda x,c0,c1: x<(c1-c0)),  # c0 + x < c1 -> x < c1 - c0
  ((UPat.var("x") // UPat.cvar("c1")) // UPat.cvar("c2"), lambda x,c1,c2: x//(c1*c2)), # (x//c1)//c2 -> x//(c1*c2)
  # mod folding
  (UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
  # ** lt **
  # c0*x<c1 for positive int c0,c1
  # the bound is ceil(c1/c0), written as -(-c1//c0) so it stays in exact integer arithmetic
  # (math.ceil(c1/c0) goes through float division and is off for values that don't fit in a float's 53-bit mantissa)
  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
    lambda x,c0,c1: x.lt(-(-c1.arg//c0.arg)) if c0.arg > 0 and c1.arg > 0 else None),
  # c0*x<c1 for negative int c0 and non-positive c1
  # the bound is -floor((-c1)/(-c0)); both operands are non-negative here, so // is the exact integer floor
  ((UPat.cvar("c0", vec=False)*UPat.var("x", dtype=dtypes.ints)).lt(UPat.cvar("c1", vec=False)),
    lambda x,c0,c1: (-x).lt(-((-c1.arg)//(-c0.arg))) if c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
  # x//c0<c1 for positive int c0
  ((UPat.var("x", dtype=dtypes.ints)//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)),
    lambda x,c0,c1: x.lt(c1.arg*c0.arg) if c0.arg > 0 else None),
  # mul add lt
  (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
    lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
])
|
||||
|
|
Loading…
Reference in New Issue