_get_chain -> split_uop [pr] (#7075)

This commit is contained in:
chenyu 2024-10-15 17:31:25 -04:00 committed by GitHub
parent e136cea027
commit 8601115976
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 17 additions and 19 deletions

View File

@ -3,8 +3,8 @@ from typing import Optional, Tuple, Dict, List, cast, TYPE_CHECKING, Any, Defaul
import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType
from tinygrad.ops import UnaryOps, BinaryOps, UOp, UOps, identity_element
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps, symbolic_flat, is_irreducible, _get_chain
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher
from tinygrad.ops import graph_rewrite, symbolic_flat, is_irreducible, split_uop, identity_element
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@ -105,7 +105,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
# first, parse valid into {expr: (lower_bound, upper_bound)}
bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
for stmt in _get_chain(valid, BinaryOps.AND):
for stmt in split_uop(valid, BinaryOps.AND):
expr, is_upper, c = parse_valid(stmt)
bounds[expr][int(is_upper)] = c
@ -116,9 +116,9 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
# every candidate is a set of constrained UOps based on valid; if every item in a set simplifies the idx into the same output, we rewrite idx
candidates = []
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)):
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(uop, BinaryOps.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in _get_chain(uop, BinaryOps.ADD)])
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(uop, BinaryOps.ADD)])
# try checking the whole clause
candidates.append([(uop, UOp.variable("fake", uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1], uop.dtype))])
@ -147,14 +147,12 @@ def simplify_image_load(load:UOp) -> Optional[UOp]:
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
for stmt in _get_chain(valid, BinaryOps.AND):
for stmt in split_uop(valid, BinaryOps.AND):
X, is_upper_bound, c = parse_valid(stmt)
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
# TODO: does not need to be add chain?
if not is_upper_bound and c == 1 and X.op is UOps.ALU and X.arg is BinaryOps.ADD and \
all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(X, BinaryOps.ADD)):
testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), _get_chain(X, BinaryOps.ADD), idx)
if not is_upper_bound and c == 1 and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(X, BinaryOps.ADD)):
testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), split_uop(X, BinaryOps.ADD), idx)
testidx = graph_rewrite(testidx, sym)
if testidx.src[0].vmax < 0 or testidx.src[1].vmax < 0:
drop_stmt.append(stmt)
@ -171,7 +169,7 @@ def simplify_image_load(load:UOp) -> Optional[UOp]:
break
if not drop_stmt and idx is start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, BinaryOps.AND) if s not in drop_stmt]) else None
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx)))
# ***** optional patterns *****

View File

@ -759,9 +759,9 @@ def type_verify(uops:List[UOp]):
# *** most of symbolic lives here now ***
def _get_chain(x:UOp, sep:BinaryOps):
def split_uop(x:UOp, sep:BinaryOps):
if x.op is UOps.ALU and x.arg is sep:
for s in x.src: yield from _get_chain(s, sep)
for s in x.src: yield from split_uop(s, sep)
else: yield x
def mod_folding(x:UOp, c:int) -> Optional[UOp]:
@ -771,7 +771,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c
remainder, something_changed = [], False
for u in _get_chain(x, BinaryOps.ADD):
for u in split_uop(x, BinaryOps.ADD):
if (factor:=u.const_factor())%c != factor:
divides = u.divides(factor)*(factor%c)
assert divides is not None
@ -791,7 +791,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
if 0 <= x.vmin and x.vmax < c: return x.const_like(0)
quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
for u in _get_chain(x, BinaryOps.ADD):
for u in split_uop(x, BinaryOps.ADD):
if u.op is UOps.CONST:
# add all const together first
if rem_const != 0: something_changed = True
@ -830,7 +830,7 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]:
def fold_unrolled_divs(divs:UOp):
# div pattern in unrolled arange
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4) -> x
add_chain, denominator, seen_const, ans = list(_get_chain(divs, BinaryOps.ADD)), None, [], None
add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None
for u in add_chain:
if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST): return None
if denominator is None: denominator = u.src[1].arg
@ -854,7 +854,7 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
# returns x0 + x1 + ... in such case, or None if not
changed, ret = False, []
for u in _get_chain(X, BinaryOps.ADD):
for u in split_uop(X, BinaryOps.ADD):
# assumed the const is the last src of MUL
if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0:
changed = True

View File

@ -5,7 +5,7 @@ from typing import Tuple, List, Optional, Dict, Set
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, _get_chain, symbolic_flat, Variable, sint
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint
@dataclass(frozen=True)
class ShapeTracker:
@ -75,7 +75,7 @@ class ShapeTracker:
ret: List[Optional[sint]] = [None] * len(self.shape)
idx, valid = self.to_indexed_uops()
idx = graph_rewrite(idx, symbolic_flat)
for c in _get_chain(idx, BinaryOps.ADD):
for c in split_uop(idx, BinaryOps.ADD):
if c.op is UOps.RANGE: ret[c.arg] = 1
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[0].op is UOps.RANGE and c.src[1].op is UOps.CONST: ret[c.src[0].arg] = c.src[1].arg
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[1].op is UOps.RANGE and c.src[0].op is UOps.CONST: ret[c.src[1].arg] = c.src[0].arg