Fix repr upat (#5705)

* test

* fix

* x fix

* simpler

* rm extra space
This commit is contained in:
kormann 2024-07-25 18:05:48 +02:00 committed by GitHub
parent 1c992de257
commit 1e2eac755d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 4 deletions

View File

@ -190,6 +190,8 @@ class TestPatternMatcher(TestUOps):
dtypes._float2 = dtypes.float.vec(2)
upat = UPat(UOps.CONST, name="x", dtype=dtypes.float)
assert str(upat) == str(eval(str(upat)))
evpat:UPat = eval(repr(UPat(src = [UPat(name='a'), UPat(name='b')])))
assert len(evpat.src) == 2
for i in range(20): upat = UPat(UOps.ALU, name="x", src=[upat, upat], arg=BinaryOps.ADD)
assert len(str(upat)) < 10_000
assert str(eval(str(upat))) == str(upat)

View File

@ -142,9 +142,9 @@ class UPat:
name, u.dtype, allow_any_len=(isinstance(name, str) and 'allow_any_len' in name))
def __repr__(self):
def rep(x):
form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=(%s))"
return form % (('{%s}'%', '.join(map(str,x.op))) if isinstance(x.op, tuple) else x.op, x.arg,
repr(x.name), set(x.dtype) if x.dtype else None, x.allowed_len == 0, "%s")
form = "UPat(%s, %s, name=%s, dtype=%s, allow_any_len=%s, src=%s)"
return form % (None if x.op is None else ('(%s)'%', '.join(map(str, x.op))), x.arg, repr(x.name),
set(x.dtype) if x.dtype else None, x.allowed_len == 0, "[%s]" if x.src and len(x.src)>1 else "(%s)")
return pretty_print(self, rep, srcfn=lambda x:None if x.src is None else [next(x.src[0])] if isinstance(x.src[0], itertools.repeat) else x.src[0])
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> List[Dict[str, UOp]]:

View File

@ -332,4 +332,4 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
if cache is None: dfs(x, cache:={})
if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
cx[2], srcs = True, ('None' if srcfn(x) is None else''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
return f"{' '*d} {f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs