mirror of https://github.com/commaai/tinygrad.git
use UOp.replace and UOp.define_var in validhack (#6730)
easier to see the diff in replacement [run_process_replay]
This commit is contained in:
parent
ff25bfb1b0
commit
66af8bb54c
|
@ -186,7 +186,7 @@ def is_increasing(f:UOp):
|
|||
|
||||
def replace_uop(uop:UOp, old:UOp, new:UOp):
|
||||
# replace all `old` in `uop` to `new`
|
||||
return new if uop.key == old.key else UOp(uop.op, uop.dtype, tuple(replace_uop(s, old, new) for s in uop.src), uop.arg)
|
||||
return new if uop.key == old.key else uop.replace(src=tuple(replace_uop(s, old, new) for s in uop.src))
|
||||
|
||||
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
|
||||
# if it's X <= c, returns X, True, c
|
||||
|
@ -217,7 +217,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
|
|||
candidates = []
|
||||
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)):
|
||||
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
||||
candidates.append([(Xi, UOp(UOps.DEFINE_VAR, Xi.dtype, (), ("fake", 1, Xi.vmax))) for Xi in _get_chain(uop, BinaryOps.ADD)])
|
||||
candidates.append([(Xi, UOp.define_var("fake", Xi.dtype, 1, Xi.vmax)) for Xi in _get_chain(uop, BinaryOps.ADD)])
|
||||
# try checking the whole clause
|
||||
candidates.append([(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))])
|
||||
|
||||
|
@ -236,7 +236,7 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
|
|||
def simplify_valid_image_load(load:UOp, buf:UOp):
|
||||
if not isinstance(buf_dtype:=buf.dtype, ImageDType) or len(load.src) < 4: return None
|
||||
buf, start_idx, invalid_val, valid = load.src
|
||||
if (idx:=idx_given_valid(valid, start_idx)) is None: return UOp(UOps.LOAD, load.dtype, (buf, start_idx, invalid_val, valid.const_like(False)))
|
||||
if (idx:=idx_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
|
||||
|
||||
# can drop valid if idx is out of bound when valid is False
|
||||
drop_stmt = []
|
||||
|
@ -264,7 +264,7 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
|
|||
|
||||
if not drop_stmt and idx.key == start_idx.key: return None
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
|
||||
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, new_valid)) if new_valid else UOp(UOps.LOAD, load.dtype, (buf, idx))
|
||||
return load.replace(src=((buf, idx, invalid_val, new_valid) if new_valid else (buf, idx)))
|
||||
|
||||
# ***** transcendental *****
|
||||
|
||||
|
|
Loading…
Reference in New Issue