From 66af8bb54c75e555a476d7d68f1d9c33ec295d69 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 25 Sep 2024 02:51:34 -0400 Subject: [PATCH] use UOp.replace and UOp.define_var in validhack (#6730) easier to see the diff in replacement [run_process_replay] --- tinygrad/codegen/uopgraph.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index b6d93f6b..4ea50003 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -186,7 +186,7 @@ def is_increasing(f:UOp): def replace_uop(uop:UOp, old:UOp, new:UOp): # replace all `old` in `uop` to `new` - return new if uop.key == old.key else UOp(uop.op, uop.dtype, tuple(replace_uop(s, old, new) for s in uop.src), uop.arg) + return new if uop.key == old.key else uop.replace(src=tuple(replace_uop(s, old, new) for s in uop.src)) def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]: # if it's X <= c, returns X, True, c @@ -217,7 +217,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]: candidates = [] if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)): # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output - candidates.append([(Xi, UOp(UOps.DEFINE_VAR, Xi.dtype, (), ("fake", 1, Xi.vmax))) for Xi in _get_chain(uop, BinaryOps.ADD)]) + candidates.append([(Xi, UOp.define_var("fake", Xi.dtype, 1, Xi.vmax)) for Xi in _get_chain(uop, BinaryOps.ADD)]) # try checking the whole clause candidates.append([(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))]) @@ -236,7 +236,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]: def simplify_valid_image_load(load:UOp, buf:UOp): if not isinstance(buf_dtype:=buf.dtype, ImageDType) or len(load.src) < 4: return None buf, start_idx, invalid_val, valid = load.src - if (idx:=idx_given_valid(valid, start_idx)) is None: return UOp(UOps.LOAD, load.dtype, (buf, start_idx, invalid_val, valid.const_like(False))) + if (idx:=idx_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False))) # can drop valid if idx is out of bound when valid is False drop_stmt = [] @@ -264,7 +264,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp): if not drop_stmt and idx.key == start_idx.key: return None new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None - return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, new_valid)) if new_valid else UOp(UOps.LOAD, load.dtype, (buf, idx)) + return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid else (buf, idx))) # ***** transcendental *****