Merge pull request #4222 from Qazalin/fuzz-multi0

Tunable multi output fusion
This commit is contained in:
qazal 2024-04-19 08:07:45 +03:00 committed by GitHub
commit 43841a32b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 2 deletions

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.lazy import LazyBuffer
@ -213,7 +213,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
output_groups: DefaultDict[Tuple, List[LazyBuffer]] = defaultdict(list)
for r in realizes:
if r.realized is not None or r.op is LoadOps.CONST or r in seen: continue
output_groups[(reduce_for_op[r], ) if r in reduce_for_op else (r, )].append(r)
output_groups[(reduce_for_op[r], ) if r in reduce_for_op and MULTIOUTPUT else (r, )].append(r)
# preschedule all buffers in realizes
prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()}

View File

@ -97,6 +97,7 @@ class ContextVar:
DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
WINO, THREEFRY, CACHECOLLECTING = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CACHECOLLECTING", 1)
GRAPH, GRAPHPATH, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("RING", 1)
MULTIOUTPUT = ContextVar("MULTIOUTPUT", 1)
# **************** global state Counters ****************