remove defaultdict from PatternMatcher [run_process_replay] (#6754)

* remove defaultdict from PatternMatcher [run_process_replay]

* nicer way to write that

* same line count

* tpm too
This commit is contained in:
George Hotz 2024-09-26 11:25:01 +08:00 committed by GitHub
parent 7e73c7b3cc
commit 717b394391
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 6 deletions

View File

@ -1,8 +1,7 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar
import sys, time, functools, itertools, math, operator, hashlib
from enum import auto, IntEnum, Enum
from collections import defaultdict
from dataclasses import dataclass, field
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import _CURRENT_KERNEL, ContextVar, pretty_print, prod, getenv, all_same
@ -466,18 +465,19 @@ class UPatAny(UPat):
class PatternMatcher:
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
self.patterns = patterns
self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = defaultdict(list)
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
self.pdict: Dict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = {}
# uop is required, arg is optional
for p,fxn in self.patterns:
assert p.op is not None
for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn, p.early_reject))
for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, fxn, p.early_reject))
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
if not early_reject.issubset(ler): continue
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret
return None
@ -502,7 +502,7 @@ class TrackedPatternMatcher(PatternMatcher):
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
ret = None
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])):
st = time.perf_counter()
if not early_reject.issubset(ler):
match_stats[p][2] += time.perf_counter()-st