[run_process_replay] faster and simpler match function (#4876)

This commit is contained in:
George Hotz 2024-06-08 14:08:30 +02:00 committed by GitHub
parent aadab3e3da
commit 9c30889ce9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 11 additions and 13 deletions

View File

@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
import functools, itertools, heapq
from collections import defaultdict
from enum import Enum, auto
@ -77,21 +77,19 @@ class UPat:
dtype: Optional[Union[DType, Set[DType]]] = None
allow_len: Set[int] = field(default_factory=set)
T = TypeVar("T")
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool:
if isinstance(m1, set):
if m2 not in m1: return True
elif m2 != m1: return True
return False
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
if pat.name in store and store[pat.name] != uop: return False
if pat.name is not None: store[pat.name] = uop
if pat.arg is not None:
if isinstance(pat.arg, set):
if uop.arg not in pat.arg: return False
elif uop.arg != pat.arg: return False
if pat.dtype is not None:
if isinstance(pat.dtype, set):
if uop.dtype not in pat.dtype: return False
elif uop.dtype != pat.dtype: return False
if pat.uop is not None:
if isinstance(pat.uop, set):
if uop.uop not in pat.uop: return False
elif uop.uop != pat.uop: return False
if (pat.arg is not None and __unmatch(pat.arg, uop.arg)) or \
(pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype)) or \
(pat.uop is not None and __unmatch(pat.uop, uop.uop)): return False
if pat.vin is None: return True
# only one if it's a tuple
# try all permutations if it's a list