mirror of https://github.com/commaai/tinygrad.git
parent
1923932339
commit
2d4d594994
|
@ -163,6 +163,8 @@ def fold_unrolled_divs(divs:UOp):
|
|||
|
||||
# ***** image load valid simplification *****
|
||||
|
||||
def is_irreducible(u:UOp): return u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE)
|
||||
|
||||
def canonicalize_simplex(X:UOp) -> Optional[UOp]:
|
||||
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
|
||||
# returns x0 + x1 + ... in such case, or None if not
|
||||
|
@ -172,13 +174,13 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
|
|||
if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0:
|
||||
changed = True
|
||||
u = u.src[0]
|
||||
if not (u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE) and u.vmin >= 0): return None
|
||||
if not (is_irreducible(u) and u.vmin >= 0): return None
|
||||
ret.append(u)
|
||||
return functools.reduce(operator.add, ret) if changed else None
|
||||
|
||||
def is_increasing(f:UOp):
|
||||
# is f a monotonically increasing function regards its input
|
||||
if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE]: return True
|
||||
if f.op is UOps.CONST or is_irreducible(f): return True
|
||||
if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
|
||||
if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
|
||||
return False # False if not sure
|
||||
|
|
Loading…
Reference in New Issue