validhack is_irreducible helper (#6664)

[run_process_replay]
This commit is contained in:
chenyu 2024-09-22 23:42:47 -04:00 committed by GitHub
parent 1923932339
commit 2d4d594994
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 2 deletions

View File

@ -163,6 +163,8 @@ def fold_unrolled_divs(divs:UOp):
# ***** image load valid simplification ***** # ***** image load valid simplification *****
def is_irreducible(u:UOp): return u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE)
def canonicalize_simplex(X:UOp) -> Optional[UOp]: def canonicalize_simplex(X:UOp) -> Optional[UOp]:
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints. # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
# returns x0 + x1 + ... in such case, or None if not # returns x0 + x1 + ... in such case, or None if not
@ -172,13 +174,13 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0: if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0:
changed = True changed = True
u = u.src[0] u = u.src[0]
if not (u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE) and u.vmin >= 0): return None if not (is_irreducible(u) and u.vmin >= 0): return None
ret.append(u) ret.append(u)
return functools.reduce(operator.add, ret) if changed else None return functools.reduce(operator.add, ret) if changed else None
def is_increasing(f:UOp): def is_increasing(f:UOp):
# is f a monotonically increasing function regards its input # is f a monotonically increasing function regards its input
if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE]: return True if f.op is UOps.CONST or is_irreducible(f): return True
if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1]) if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0]) if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
return False # False if not sure return False # False if not sure