mirror of https://github.com/commaai/tinygrad.git
cast bool for type verify [run_process_replay] (#6742)
This commit is contained in:
parent
cb22ef379a
commit
232edcfd4f
|
@ -561,7 +561,7 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
|||
|
||||
# this is the matcher for the final rendered UOps
|
||||
# matcher functions returns True or False (or None to not match)
|
||||
spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.bool, r) if (r:=fxn(**kw)) is not None else None, y)) for (x,y) in [
|
||||
spec = PatternMatcher([
|
||||
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
|
||||
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
|
||||
(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST, name="c"),), name="x", allow_any_len=True),
|
||||
|
@ -631,9 +631,9 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b
|
|||
# PTX LOAD/STORE
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
||||
(UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
|
||||
]])
|
||||
])
|
||||
|
||||
def type_verify(uops:List[UOp]):
|
||||
for u in uops:
|
||||
chk = spec.rewrite(u)
|
||||
assert chk is not None and chk.arg is True, f"UOp verification failed on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
|
||||
chk = cast(bool, spec.rewrite(u))
|
||||
assert chk is True, f"UOp verification failed on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
|
||||
|
|
Loading…
Reference in New Issue