diff --git a/tinygrad/ops.py b/tinygrad/ops.py index ca6a9068..08075e44 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,8 +1,7 @@ from __future__ import annotations -from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict +from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar import sys, time, functools, itertools, math, operator, hashlib from enum import auto, IntEnum, Enum -from collections import defaultdict from dataclasses import dataclass, field from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate from tinygrad.helpers import _CURRENT_KERNEL, ContextVar, pretty_print, prod, getenv, all_same @@ -466,18 +465,19 @@ class UPatAny(UPat): class PatternMatcher: def __init__(self, patterns:List[Tuple[UPat, Callable]]): self.patterns = patterns - self.pdict: DefaultDict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = defaultdict(list) + # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher! + self.pdict: Dict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set]]] = {} # uop is required, arg is optional for p,fxn in self.patterns: assert p.op is not None - for uop in p.op: self.pdict[(uop, p.arg)].append((p, fxn, p.early_reject)) + for uop in p.op: self.pdict.setdefault((uop, p.arg), []).append((p, fxn, p.early_reject)) @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns) def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]: ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))]) - for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]): + for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])): if not early_reject.issubset(ler): continue if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret return None @@ -502,7 +502,7 @@ class TrackedPatternMatcher(PatternMatcher): def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]: ret = None ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))]) - for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]): + for p,fxn,early_reject in self.pdict.get((uop.op, uop.arg), []) + ([] if uop.arg is None else self.pdict.get((uop.op, None), [])): st = time.perf_counter() if not early_reject.issubset(ler): match_stats[p][2] += time.perf_counter()-st