mirror of https://github.com/commaai/tinygrad.git
[run_process_replay] faster and simpler match function (#4876)
This commit is contained in:
parent
aadab3e3da
commit
9c30889ce9
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast
|
||||
from typing import Iterator, Optional, Tuple, Any, Dict, List, DefaultDict, Set, Callable, Union, cast, TypeVar
|
||||
import functools, itertools, heapq
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
|
@ -77,21 +77,19 @@ class UPat:
|
|||
dtype: Optional[Union[DType, Set[DType]]] = None
|
||||
allow_len: Set[int] = field(default_factory=set)
|
||||
|
||||
T = TypeVar("T")
|
||||
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool:
|
||||
if isinstance(m1, set):
|
||||
if m2 not in m1: return True
|
||||
elif m2 != m1: return True
|
||||
return False
|
||||
|
||||
def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
|
||||
if pat.name in store and store[pat.name] != uop: return False
|
||||
if pat.name is not None: store[pat.name] = uop
|
||||
if pat.arg is not None:
|
||||
if isinstance(pat.arg, set):
|
||||
if uop.arg not in pat.arg: return False
|
||||
elif uop.arg != pat.arg: return False
|
||||
if pat.dtype is not None:
|
||||
if isinstance(pat.dtype, set):
|
||||
if uop.dtype not in pat.dtype: return False
|
||||
elif uop.dtype != pat.dtype: return False
|
||||
if pat.uop is not None:
|
||||
if isinstance(pat.uop, set):
|
||||
if uop.uop not in pat.uop: return False
|
||||
elif uop.uop != pat.uop: return False
|
||||
if (pat.arg is not None and __unmatch(pat.arg, uop.arg)) or \
|
||||
(pat.dtype is not None and uop.dtype is not None and __unmatch(pat.dtype, uop.dtype)) or \
|
||||
(pat.uop is not None and __unmatch(pat.uop, uop.uop)): return False
|
||||
if pat.vin is None: return True
|
||||
# only one if it's a tuple
|
||||
# try all permutations if it's a list
|
||||
|
|
Loading…
Reference in New Issue