From 232edcfd4f8b388807c64fb1817a7668ce27cbad Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:12:16 +0800 Subject: [PATCH] cast bool for type verify [run_process_replay] (#6742) --- tinygrad/ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 2e4a582f..8ac32e5a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -561,7 +561,7 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp: # this is the matcher for the final rendered UOps # matcher functions returns True or False (or None to not match) -spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.bool, r) if (r:=fxn(**kw)) is not None else None, y)) for (x,y) in [ +spec = PatternMatcher([ (UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local), (UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local), (UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST, name="c"),), name="x", allow_any_len=True), @@ -631,9 +631,9 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b # PTX LOAD/STORE (UPat((UOps.LOAD, UOps.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True), (UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True), -]]) +]) def type_verify(uops:List[UOp]): for u in uops: - chk = spec.rewrite(u) - assert chk is not None and chk.arg is True, f"UOp verification failed on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}" + chk = cast(bool, spec.rewrite(u)) + assert chk is True, f"UOp verification failed on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"