mirror of https://github.com/commaai/tinygrad.git
idx_given_valid -> uop_given_valid [pr] (#7110)
will reuse this to simplify valid independent of idx
This commit is contained in:
parent
842fe444df
commit
51cd0e7c0d
|
@ -100,8 +100,8 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
|
|||
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is UOps.CONST: return valid.src[0], True, valid.src[1].arg-1
|
||||
raise ValueError(f"not able to parse {valid=}")
|
||||
|
||||
def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
|
||||
# return None if valid is always False, otherwise the simplified idx (might be the same as input)
|
||||
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
|
||||
# return None if valid is always False, otherwise the simplified uop (might be the same as input)
|
||||
|
||||
# first, parse valid into {expr: (lower_bound, upper_bound)}
|
||||
bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None])
|
||||
|
@ -109,41 +109,41 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
|
|||
expr, is_upper, c = parse_valid(stmt)
|
||||
bounds[expr][int(is_upper)] = c
|
||||
|
||||
# simplify idx given that valid is True
|
||||
for uop,v in bounds.items():
|
||||
# simplify uop given that valid is True
|
||||
for expr,v in bounds.items():
|
||||
# some expr has lower bound > upper bound -> valid is an empty set and we return None
|
||||
if v[0] is not None and v[1] is not None and v[0] > v[1]: return None
|
||||
|
||||
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the idx into a same output, we rewrite idx
|
||||
# every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
|
||||
candidates = []
|
||||
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(uop, BinaryOps.ADD)):
|
||||
if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in split_uop(expr, BinaryOps.ADD)):
|
||||
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
||||
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(uop, BinaryOps.ADD)])
|
||||
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
|
||||
# try checking the whole clause
|
||||
candidates.append([(uop, UOp.variable("fake", uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1], uop.dtype))])
|
||||
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
|
||||
|
||||
for candidate in candidates:
|
||||
newidxs = [replace_uop(graph_rewrite(replace_uop(idx, X, newX), sym), newX, X) for X,newX in candidate]
|
||||
if idx.op is UOps.VECTORIZE and len(idx.src) == 2:
|
||||
# if every branch in candidate gives the same simplified output, we can rewrite the idx
|
||||
if all_same([idxs.src[0] for idxs in newidxs]): idx = idx.replace(src=(newidxs[0].src[0], idx.src[1]))
|
||||
if all_same([idxs.src[1] for idxs in newidxs]): idx = idx.replace(src=(idx.src[0], newidxs[0].src[1]))
|
||||
elif all_same(newidxs): idx = newidxs[0]
|
||||
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
|
||||
newuops = [replace_uop(graph_rewrite(replace_uop(uop, X, newX), sym), newX, X) for X,newX in candidate]
|
||||
if uop.op is UOps.VECTORIZE and len(uop.src) == 2:
|
||||
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
|
||||
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
|
||||
elif all_same(newuops): uop = newuops[0]
|
||||
|
||||
return idx
|
||||
return uop
|
||||
|
||||
def simplify_buffer_load(load:UOp) -> Optional[UOp]:
|
||||
if not isinstance(load.src[0].dtype, PtrDType) or len(load.src) != 4: return None
|
||||
buf, start_idx, invalid_val, valid = load.src
|
||||
try:
|
||||
if (idx:=idx_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
|
||||
if (idx:=uop_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
|
||||
except ValueError: return None
|
||||
return None if idx is start_idx else load.replace(src=((buf, idx, invalid_val, valid)))
|
||||
|
||||
def simplify_image_load(load:UOp) -> Optional[UOp]:
|
||||
if not isinstance(buf_dtype:=load.src[0].dtype, ImageDType) or len(load.src) != 4: return None
|
||||
buf, start_idx, invalid_val, valid = load.src
|
||||
if (idx:=idx_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
|
||||
if (idx:=uop_given_valid(valid, start_idx)) is None: return load.replace(src=(buf, start_idx, invalid_val, valid.const_like(False)))
|
||||
|
||||
# can drop valid if idx is out of bound when valid is False
|
||||
drop_stmt = []
|
||||
|
|
Loading…
Reference in New Issue