idx_given_valid -> uop_given_valid [pr] (#7110)

This will be reused to simplify valid independently of idx.
This commit is contained in:
chenyu 2024-10-16 18:16:36 -04:00 committed by GitHub
parent 842fe444df
commit 51cd0e7c0d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 17 additions and 17 deletions

View File

@ -100,8 +100,8 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is UOps.CONST: return valid.src[0], True, valid.src[1].arg-1
raise ValueError(f"not able to parse {valid=}")
def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
# return None if valid is always False, otherwise the simplified idx (might be the same as input)
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
# return None if valid is always False, otherwise the simplified uop (might be the same as input)
# first, parse valid into {expr: (lower_bound, upper_bound)}
bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
@ -109,41 +109,41 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
expr, is_upper, c = parse_valid(stmt)
bounds[expr][int(is_upper)] = c
# simplify idx given that valid is True
for uop,v in bounds.items():
# simplify uop given that valid is True
for expr,v in bounds.items():
# some expr has lower bound > upper bound -> valid is an empty set and we return None
if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the idx into a same output, we rewrite idx
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
candidates = []
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(uop, BinaryOps.ADD)):
if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(expr, BinaryOps.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(uop, BinaryOps.ADD)])
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
# try checking the whole clause
candidates.append([(uop, UOp.variable("fake", uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1], uop.dtype))])
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
for candidate in candidates:
newidxs = [replace_uop(graph_rewrite(replace_uop(idx, X, newX), sym), newX, X) for X,newX in candidate]
if idx.op is UOps.VECTORIZE and len(idx.src) == 2:
# if every branch in candidate gives the same simplified output, we can rewrite the idx
if all_same([idxs.src[0] for idxs in newidxs]): idx = idx.replace(src=(newidxs[0].src[0], idx.src[1]))
if all_same([idxs.src[1] for idxs in newidxs]): idx = idx.replace(src=(idx.src[0], newidxs[0].src[1]))
elif all_same(newidxs): idx = newidxs[0]
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
newuops = [replace_uop(graph_rewrite(replace_uop(uop, X, newX), sym), newX, X) for X,newX in candidate]
if uop.op is UOps.VECTORIZE and len(uop.src) == 2:
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
elif all_same(newuops): uop = newuops[0]
return idx
return uop
def simplify_buffer_load(load:UOp) -> Optional[UOp]:
if not isinstance(load.src[0].dtype, PtrDType) or len(load.src) != 4: return None
buf, start_idx, invalid_val, valid = load.src
try:
if (idx:=idx_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
if (idx:=uop_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
except ValueError: return None
return None if idx is start_idx else load.replace(src=((buf, idx, invalid_val, valid)))
def simplify_image_load(load:UOp) -> Optional[UOp]:
if not isinstance(buf_dtype:=load.src[0].dtype, ImageDType) or len(load.src) != 4: return None
buf, start_idx, invalid_val, valid = load.src
if (idx:=idx_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
if (idx:=uop_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
# can drop valid if idx is out of bound when valid is False
drop_stmt = []