mirror of https://github.com/commaai/tinygrad.git
cleanups around idx_given_valid [pr] (#7074)
This commit is contained in:
parent
545e79969f
commit
e136cea027
|
@ -78,14 +78,14 @@ float4_folding = PatternMatcher([
|
|||
|
||||
# ***** image load valid simplification *****
|
||||
|
||||
def is_increasing(f:UOp):
|
||||
def is_increasing(f:UOp) -> bool:
|
||||
# is f a monotonically increasing function with respect to its input
|
||||
if f.op is UOps.CONST or is_irreducible(f): return True
|
||||
if is_irreducible(f): return True
|
||||
if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
|
||||
if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
|
||||
return False # False if not sure
|
||||
|
||||
def replace_uop(uop:UOp, old:UOp, new:UOp):
|
||||
def replace_uop(uop:UOp, old:UOp, new:UOp) -> UOp:
|
||||
# replace all `old` in `uop` with `new`
|
||||
return new if uop is old else uop.replace(src=tuple(replace_uop(s, old, new) for s in uop.src))
|
||||
|
||||
|
@ -124,7 +124,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
|
|||
|
||||
for candidate in candidates:
|
||||
newidxs = [replace_uop(graph_rewrite(replace_uop(idx, X, newX), sym), newX, X) for X,newX in candidate]
|
||||
if idx.op is UOps.VECTORIZE:
|
||||
if idx.op is UOps.VECTORIZE and len(idx.src) == 2:
|
||||
# if every branch in candidate gives the same simplified output, we can rewrite the idx
|
||||
if all_same([idxs.src[0] for idxs in newidxs]): idx = idx.replace(src=(newidxs[0].src[0], idx.src[1]))
|
||||
if all_same([idxs.src[1] for idxs in newidxs]): idx = idx.replace(src=(idx.src[0], newidxs[0].src[1]))
|
||||
|
@ -132,7 +132,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
|
|||
|
||||
return idx
|
||||
|
||||
def simplify_buffer_load(load:UOp):
|
||||
def simplify_buffer_load(load:UOp) -> Optional[UOp]:
|
||||
if not isinstance(load.src[0].dtype, PtrDType) or len(load.src) != 4: return None
|
||||
buf, start_idx, invalid_val, valid = load.src
|
||||
try:
|
||||
|
@ -140,7 +140,7 @@ def simplify_buffer_load(load:UOp):
|
|||
except ValueError: return None
|
||||
return None if idx is start_idx else load.replace(src=((buf, idx, invalid_val, valid)))
|
||||
|
||||
def simplify_image_load(load:UOp):
|
||||
def simplify_image_load(load:UOp) -> Optional[UOp]:
|
||||
if not isinstance(buf_dtype:=load.src[0].dtype, ImageDType) or len(load.src) != 4: return None
|
||||
buf, start_idx, invalid_val, valid = load.src
|
||||
if (idx:=idx_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
|
||||
|
@ -151,6 +151,7 @@ def simplify_image_load(load:UOp):
|
|||
X, is_upper_bound, c = parse_valid(stmt)
|
||||
|
||||
# for X0 + X1 + ... >= 1, check if it's out of bounds when Xi = 0 for all i
|
||||
# TODO: does this need to be an add chain?
|
||||
if not is_upper_bound and c == 1 and X.op is UOps.ALU and X.arg is BinaryOps.ADD and \
|
||||
all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(X, BinaryOps.ADD)):
|
||||
testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), _get_chain(X, BinaryOps.ADD), idx)
|
||||
|
@ -169,7 +170,7 @@ def simplify_image_load(load:UOp):
|
|||
drop_stmt.append(stmt)
|
||||
break
|
||||
|
||||
if not drop_stmt and idx == start_idx: return None
|
||||
if not drop_stmt and idx is start_idx: return None
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
|
||||
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid is not None else (buf, idx)))
|
||||
|
||||
|
|
|
@ -848,7 +848,7 @@ def fold_unrolled_divs(divs:UOp):
|
|||
if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
|
||||
return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
|
||||
|
||||
def is_irreducible(u:UOp): return u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE)
|
||||
def is_irreducible(u:UOp): return u.op in (UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE)
|
||||
|
||||
def canonicalize_simplex(X:UOp) -> Optional[UOp]:
|
||||
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
|
||||
|
|
Loading…
Reference in New Issue