use UOp.replace and UOp.define_var in validhack (#6730)

easier to see the diff in replacement
[run_process_replay]
This commit is contained in:
chenyu 2024-09-25 02:51:34 -04:00 committed by GitHub
parent ff25bfb1b0
commit 66af8bb54c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 4 deletions

View File

@ -186,7 +186,7 @@ def is_increasing(f:UOp):
def replace_uop(uop:UOp, old:UOp, new:UOp):
# replace all `old` in `uop` to `new`
return new if uop.key == old.key else UOp(uop.op, uop.dtype, tuple(replace_uop(s, old, new) for s in uop.src), uop.arg)
return new if uop.key == old.key else uop.replace(src=tuple(replace_uop(s, old, new) for s in uop.src))
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
# if it's X <= c, returns X, True, c
@ -217,7 +217,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
candidates = []
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp(UOps.DEFINE_VAR, Xi.dtype, (), ("fake", 1, Xi.vmax))) for Xi in _get_chain(uop, BinaryOps.ADD)])
candidates.append([(Xi, UOp.define_var("fake", Xi.dtype, 1, Xi.vmax)) for Xi in _get_chain(uop, BinaryOps.ADD)])
# try checking the whole clause
candidates.append([(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))])
@ -236,7 +236,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
def simplify_valid_image_load(load:UOp, buf:UOp):
if not isinstance(buf_dtype:=buf.dtype, ImageDType) or len(load.src) < 4: return None
buf, start_idx, invalid_val, valid = load.src
if (idx:=idx_given_valid(valid, start_idx)) is None: return UOp(UOps.LOAD, load.dtype, (buf, start_idx, invalid_val, valid.const_like(False)))
if (idx:=idx_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
@ -264,7 +264,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
if not drop_stmt and idx.key == start_idx.key: return None
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, new_valid)) if new_valid else UOp(UOps.LOAD, load.dtype, (buf, idx))
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid else (buf, idx)))
# ***** transcendental *****