cleanups around idx_given_valid [pr] (#7074)

This commit is contained in:
chenyu 2024-10-15 16:59:01 -04:00 committed by GitHub
parent 545e79969f
commit e136cea027
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 8 deletions

View File

@ -78,14 +78,14 @@ float4_folding = PatternMatcher([
# ***** image load valid simplification *****
def is_increasing(f:UOp):
def is_increasing(f:UOp) -> bool:
# is f a monotonically increasing function with regard to its input
if f.op is UOps.CONST or is_irreducible(f): return True
if is_irreducible(f): return True
if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
return False # False if not sure
def replace_uop(uop:UOp, old:UOp, new:UOp):
def replace_uop(uop:UOp, old:UOp, new:UOp) -> UOp:
# replace every occurrence of `old` in `uop` with `new`
return new if uop is old else uop.replace(src=tuple(replace_uop(s, old, new) for s in uop.src))
@ -124,7 +124,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
for candidate in candidates:
newidxs = [replace_uop(graph_rewrite(replace_uop(idx, X, newX), sym), newX, X) for X,newX in candidate]
if idx.op is UOps.VECTORIZE:
if idx.op is UOps.VECTORIZE and len(idx.src) == 2:
# if every branch in candidate gives the same simplified output, we can rewrite the idx
if all_same([idxs.src[0] for idxs in newidxs]): idx = idx.replace(src=(newidxs[0].src[0], idx.src[1]))
if all_same([idxs.src[1] for idxs in newidxs]): idx = idx.replace(src=(idx.src[0], newidxs[0].src[1]))
@ -132,7 +132,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
return idx
def simplify_buffer_load(load:UOp):
def simplify_buffer_load(load:UOp) -> Optional[UOp]:
if not isinstance(load.src[0].dtype, PtrDType) or len(load.src) != 4: return None
buf, start_idx, invalid_val, valid = load.src
try:
@ -140,7 +140,7 @@ def simplify_buffer_load(load:UOp):
except ValueError: return None
return None if idx is start_idx else load.replace(src=((buf, idx, invalid_val, valid)))
def simplify_image_load(load:UOp):
def simplify_image_load(load:UOp) -> Optional[UOp]:
if not isinstance(buf_dtype:=load.src[0].dtype, ImageDType) or len(load.src) != 4: return None
buf, start_idx, invalid_val, valid = load.src
if (idx:=idx_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
@ -151,6 +151,7 @@ def simplify_image_load(load:UOp):
X, is_upper_bound, c = parse_valid(stmt)
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
# TODO: does not need to be add chain?
if not is_upper_bound and c == 1 and X.op is UOps.ALU and X.arg is BinaryOps.ADD and \
all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(X, BinaryOps.ADD)):
testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), _get_chain(X, BinaryOps.ADD), idx)
@ -169,7 +170,7 @@ def simplify_image_load(load:UOp):
drop_stmt.append(stmt)
break
if not drop_stmt and idx == start_idx: return None
if not drop_stmt and idx is start_idx: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx)))

View File

@ -848,7 +848,7 @@ def fold_unrolled_divs(divs:UOp):
if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
def is_irreducible(u:UOp): return u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE)
def is_irreducible(u:UOp): return u.op in (UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE)
def canonicalize_simplex(X:UOp) -> Optional[UOp]:
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.