mirror of https://github.com/commaai/tinygrad.git
_get_chain -> split_uop [pr] (#7075)
This commit is contained in:
parent
e136cea027
commit
8601115976
|
@ -3,8 +3,8 @@ from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, Defaul
|
|||
import functools, itertools, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, UOp, UOps, identity_element
|
||||
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps, symbolic_flat, is_irreducible, _get_chain
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher
|
||||
from tinygrad.ops import graph_rewrite, symbolic_flat, is_irreducible, split_uop, identity_element
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
|
||||
|
@ -105,7 +105,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
|
|||
|
||||
# first, parse valid into {expr: (lower_bound, upper_bound)}
|
||||
bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
|
||||
for stmt in _get_chain(valid, BinaryOps.AND):
|
||||
for stmt in split_uop(valid, BinaryOps.AND):
|
||||
expr, is_upper, c = parse_valid(stmt)
|
||||
bounds[expr][int(is_upper)] = c
|
||||
|
||||
|
@ -116,9 +116,9 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
|
|||
|
||||
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the idx into a same output, we rewrite idx
|
||||
candidates = []
|
||||
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)):
|
||||
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(uop, BinaryOps.ADD)):
|
||||
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
||||
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in _get_chain(uop, BinaryOps.ADD)])
|
||||
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(uop, BinaryOps.ADD)])
|
||||
# try checking the whole clause
|
||||
candidates.append([(uop, UOp.variable("fake", uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1], uop.dtype))])
|
||||
|
||||
|
@ -147,14 +147,12 @@ def simplify_image_load(load:UOp) -> Optional[UOp]:
|
|||
|
||||
# can drop valid if idx is out of bound when valid is False
|
||||
drop_stmt = []
|
||||
for stmt in _get_chain(valid, BinaryOps.AND):
|
||||
for stmt in split_uop(valid, BinaryOps.AND):
|
||||
X, is_upper_bound, c = parse_valid(stmt)
|
||||
|
||||
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
|
||||
# TODO: does not need to be add chain?
|
||||
if not is_upper_bound and c == 1 and X.op is UOps.ALU and X.arg is BinaryOps.ADD and \
|
||||
all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(X, BinaryOps.ADD)):
|
||||
testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), _get_chain(X, BinaryOps.ADD), idx)
|
||||
if not is_upper_bound and c == 1 and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(X, BinaryOps.ADD)):
|
||||
testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), split_uop(X, BinaryOps.ADD), idx)
|
||||
testidx = graph_rewrite(testidx, sym)
|
||||
if testidx.src[0].vmax < 0 or testidx.src[1].vmax < 0:
|
||||
drop_stmt.append(stmt)
|
||||
|
@ -171,7 +169,7 @@ def simplify_image_load(load:UOp) -> Optional[UOp]:
|
|||
break
|
||||
|
||||
if not drop_stmt and idx is start_idx: return None
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, BinaryOps.AND) if s not in drop_stmt]) else None
|
||||
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx)))
|
||||
|
||||
# ***** optional patterns *****
|
||||
|
|
|
@ -759,9 +759,9 @@ def type_verify(uops:List[UOp]):
|
|||
|
||||
# *** most of symbolic lives here now ***
|
||||
|
||||
def _get_chain(x:UOp, sep:BinaryOps):
|
||||
def split_uop(x:UOp, sep:BinaryOps):
|
||||
if x.op is UOps.ALU and x.arg is sep:
|
||||
for s in x.src: yield from _get_chain(s, sep)
|
||||
for s in x.src: yield from split_uop(s, sep)
|
||||
else: yield x
|
||||
|
||||
def mod_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
|
@ -771,7 +771,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
|
|||
if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
|
||||
|
||||
remainder, something_changed = [], False
|
||||
for u in _get_chain(x, BinaryOps.ADD):
|
||||
for u in split_uop(x, BinaryOps.ADD):
|
||||
if (factor:=u.const_factor())%c != factor:
|
||||
divides = u.divides(factor)*(factor%c)
|
||||
assert divides is not None
|
||||
|
@ -791,7 +791,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
|||
if 0 <= x.vmin and x.vmax < c: return x.const_like(0)
|
||||
|
||||
quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
|
||||
for u in _get_chain(x, BinaryOps.ADD):
|
||||
for u in split_uop(x, BinaryOps.ADD):
|
||||
if u.op is UOps.CONST:
|
||||
# add all const together first
|
||||
if rem_const != 0: something_changed = True
|
||||
|
@ -830,7 +830,7 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]:
|
|||
def fold_unrolled_divs(divs:UOp):
|
||||
# div pattern in unrolled arange
|
||||
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
|
||||
add_chain, denominator, seen_const, ans = list(_get_chain(divs, BinaryOps.ADD)), None, [], None
|
||||
add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None
|
||||
for u in add_chain:
|
||||
if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST): return None
|
||||
if denominator is None: denominator = u.src[1].arg
|
||||
|
@ -854,7 +854,7 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
|
|||
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
|
||||
# returns x0 + x1 + ... in such case, or None if not
|
||||
changed, ret = False, []
|
||||
for u in _get_chain(X, BinaryOps.ADD):
|
||||
for u in split_uop(X, BinaryOps.ADD):
|
||||
# assumed the const is the last src of MUL
|
||||
if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0:
|
||||
changed = True
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import Tuple, List, Optional, Dict, Set
|
|||
from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, _get_chain, symbolic_flat, Variable, sint
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShapeTracker:
|
||||
|
@ -75,7 +75,7 @@ class ShapeTracker:
|
|||
ret: List[Optional[sint]] = [None] * len(self.shape)
|
||||
idx, valid = self.to_indexed_uops()
|
||||
idx = graph_rewrite(idx, symbolic_flat)
|
||||
for c in _get_chain(idx, BinaryOps.ADD):
|
||||
for c in split_uop(idx, BinaryOps.ADD):
|
||||
if c.op is UOps.RANGE: ret[c.arg] = 1
|
||||
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[0].op is UOps.RANGE and c.src[1].op is UOps.CONST: ret[c.src[0].arg] = c.src[1].arg
|
||||
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[1].op is UOps.RANGE and c.src[0].op is UOps.CONST: ret[c.src[1].arg] = c.src[0].arg
|
||||
|
|
Loading…
Reference in New Issue