mirror of https://github.com/commaai/tinygrad.git
remove defaultdict from PatternMatcher [run_process_replay] (#6754)
* remove defaultdict from PatternMatcher [run_process_replay] * nicer way to write that * same line count * tpm too
This commit is contained in:
parent
7e73c7b3cc
commit
717b394391
|
@ -1,8 +1,7 @@
|
|||
from __future__ import annotations
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
|
||||
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
|
||||
import sys, time, functools, itertools, math, operator, hashlib
|
||||
from enum import auto, IntEnum, Enum
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import _CURRENT_KERNEL, ContextVar, pretty_print, prod, getenv, all_same
|
||||
|
@ -466,18 +465,19 @@ class UPatAny(UPat):
|
|||
class PatternMatcher:
|
||||
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
|
||||
self.patterns = patterns
|
||||
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = defaultdict(list)
|
||||
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
|
||||
self.pdict: Dict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = {}
|
||||
# uop is required, arg is optional
|
||||
for p,fxn in self.patterns:
|
||||
assert p.op is not None
|
||||
for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn, p.early_reject))
|
||||
for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, fxn, p.early_reject))
|
||||
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
|
||||
|
||||
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
|
||||
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
|
||||
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
|
||||
for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
|
||||
if not early_reject.issubset(ler): continue
|
||||
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret
|
||||
return None
|
||||
|
@ -502,7 +502,7 @@ class TrackedPatternMatcher(PatternMatcher):
|
|||
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
|
||||
ret = None
|
||||
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
|
||||
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
|
||||
for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
|
||||
st = time.perf_counter()
|
||||
if not early_reject.issubset(ler):
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
|
|
Loading…
Reference in New Issue