From 73d4d51845df0c82b4897aa0161cd27437a9a2fb Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 6 Aug 2024 14:49:40 -0700 Subject: [PATCH] add LBScheduleItem type [run_process_replay] (#5944) * add LBScheduleItem type [run_process_replay] * minor cleanups * fix * fix fuzz tests * add group cache type --- test/external/fuzz_schedule.py | 16 ++++----- tinygrad/engine/schedule.py | 63 +++++++++++++++++++++------------- 2 files changed, 47 insertions(+), 32 deletions(-) diff --git a/test/external/fuzz_schedule.py b/test/external/fuzz_schedule.py index e6dc6faf..a9d53b42 100644 --- a/test/external/fuzz_schedule.py +++ b/test/external/fuzz_schedule.py @@ -30,16 +30,16 @@ def fuzz_schedule(outs:List[LazyBuffer]): seed = Tensor._seed ts, (_, prescheduled) = toposorts[0] for key in ts: - for out in (ps:=prescheduled[key])[0]: + for out in (ps:=prescheduled[key]).outputs: # freeze assign state before exec if out.op is MetaOps.ASSIGN: prerealized[out] = out.buffer.as_buffer() assign_targets[out.srcs[1]] = out - for x in ps[2]: + for x in ps.inputs: if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer() - si = ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0)) + si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs+ps.inputs if x.size != 0)) _exec_si(si, seed) - for out in ps[0]: + for out in ps.outputs: ground_truth[out] = out.buffer.as_buffer() del out.srcs # only schedule the LazyBuffer in this fuzz run @@ -48,19 +48,19 @@ def fuzz_schedule(outs:List[LazyBuffer]): if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow")) rawbufs: Dict[LazyBuffer, Buffer] = {} for key in ts: - for out in (ps:=prescheduled[key])[0]: + for out in (ps:=prescheduled[key]).outputs: rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype) if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out]) - for x in ps[2]: + for x in ps.inputs: if x not in rawbufs: # override the assign_target after ASSIGN if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]] elif x.device == "NPY": rawbufs[x] = x.buffer # copy the pre realized input else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x]) - si = ScheduleItem(ps[1], tuple(rawbufs[x] for x in ps[0]+ps[2] if x.size != 0)) + si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in ps.outputs+ps.inputs if x.size != 0)) _exec_si(si, seed) - for out in ps[0]: + for out in ps.outputs: outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype)) try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2) except Exception as e: diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 67e9a2ff..fe105c84 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,6 +1,6 @@ import sys, pickle, atexit from collections import defaultdict, deque -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer @@ -35,11 +35,20 @@ class ScheduleItem: """Read only buffers in the schedule.""" return self.bufs[len(self.ast.src):] if self.ast.op is MetaOps.KERNEL else self.bufs[1:] +@dataclass(frozen=True) +class LBScheduleItem: + ast: LazyOp + outputs: List[LazyBuffer] + inputs: List[LazyBuffer] + var_vals: Dict[Variable, int] = field(default_factory=dict) + metadata: List[Metadata] = field(default_factory=list) + # *** DAG transformation: List[LazyBuffer] -> ScheduleItem *** def _recursive_lazyop(buf:LazyBuffer, inputs:Dict[LazyBuffer, int], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker, realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], - reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> LazyOp: + reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], + cache:Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp]) -> LazyOp: """recursively create a lazyop""" if buf is not buf.base: st, buf = buf.st+st, buf.base if (buf, st) in cache: return cache[(buf, st)] @@ -89,8 +98,10 @@ def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeT tmp = input_st.permute(permute_axis) return tmp, tmp.shape[-len(axis):] -def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],\ - reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> Optional[Tuple[LazyBuffer, ShapeTracker]]: +def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer], + reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], + cache:Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]]) -> \ + Optional[Tuple[LazyBuffer, ShapeTracker]]: if (buf, st) in cache: return cache[(buf, st)] if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs): return None if buf is not buf.base: st, buf = buf.st+st, buf.base @@ -125,15 +136,17 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer return (buf, st) return cache.setdefault((buf, st), top_reduce) -def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> Tuple[LazyOp, List[LazyBuffer], Dict[Variable, int], List[Metadata]]: +def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem: """describe the computation for a LazyBuffer with LazyOp + inputs + var_vals""" if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]: rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,)))) - return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, [] - if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, [] + wr = LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)) + return LBScheduleItem(LazyOp(MetaOps.KERNEL, (wr,)), outs, [x.base for x in out.srcs]) + if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: + return LBScheduleItem(LazyOp(out.op, (), out.arg), outs, [x.base for x in out.srcs]) # push through all movementops between reduceops reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {} - seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {} + seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]] = {} for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops) # pad all reduceops to the max of each dimension shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])] @@ -158,13 +171,14 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> output_st, vv = output_st.simplify().unbind() if vv: var_vals.update(vv) ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_st))) - return LazyOp(MetaOps.KERNEL, tuple(ast)), list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]) + return LBScheduleItem(LazyOp(MetaOps.KERNEL, tuple(ast)), outs, list(inputs), var_vals, + dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])) # *** DAG creation: decide which LazyBuffers should realize *** -def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None],\ - children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer],\ - double_reduces:Dict[LazyBuffer, None], scheduled=False) -> None: +def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None], + children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer], + double_reduces:Dict[LazyBuffer, None], scheduled=False) -> None: """recursively search the entire graph for all LazyBuffers, insert realizes after expands""" if buf in allbufs or buf.base.realized is not None: return if GRAPH: log_lazybuffer(buf, scheduled) @@ -205,7 +219,8 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool: return all(_is_padding_okay(x.base, realizes) for x in buf.srcs) def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], - realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None], cache:Set): + realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None], + cache:Set[Tuple[LazyBuffer, ShapeTracker]]) -> None: """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group""" if (tr, st) in cache: return cache.add((tr, st)) @@ -239,7 +254,7 @@ SCHEDULES: List = [] def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \ Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], # this is the graph DefaultDict[LazyBuffer, int], # this is the in-degree of the graph - Dict[LazyBuffer, Tuple[List[LazyBuffer], LazyOp, List[LazyBuffer], Dict[Variable, int], List[Metadata]]]]: # this is ??? + Dict[LazyBuffer, LBScheduleItem]]: # this is the schedule item, but still in LazyBuffer """create a graph for realizing the outputs""" # start by just realizing the buffers passed in realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None} @@ -330,20 +345,20 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \ buf.buffer.options = None # preschedule all buffers in realizes - prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes)) for group in output_groups.values()} - schedule_targets = {out:ps for ps in prescheduled.values() for out in ps[0]} + prescheduled = {group[0]:_lower_lazybuffer(group, realizes) for group in output_groups.values()} + schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs} graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list) in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int) for key, lsi in prescheduled.items(): if key not in in_degree: in_degree[key] = 0 # realize outputs after all parents are realized - scheduled_parents = set(schedule_targets[x][0][0] for x in lsi[2] if x in schedule_targets) + scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets) for x in scheduled_parents: graph[x].append(key) in_degree[key] += 1 # realize outputs before a parent is assigned to - parents_assigns = set(schedule_targets[assign_targets[x]][0][0] for x in lsi[2] if x in assign_targets) + parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets) for assign in parents_assigns: graph[key].append(assign) in_degree[assign] += 1 @@ -373,15 +388,15 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe kernel_number = GlobalCounters.kernel_count while queue: ps = queue.popleft() - for buf in ps[0]: seen.add(buf) + for buf in ps.outputs: seen.add(buf) if GRAPH: kernel_number += 1 - for out in ps[0]: realized_lazybuffer(out, kernel_number) - var_vals = merge_dicts([var_vals, ps[3]]) - for out in ps[0]: del out.srcs # can only schedule once - schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0), ps[4])) + for out in ps.outputs: realized_lazybuffer(out, kernel_number) + var_vals = merge_dicts([var_vals, ps.var_vals]) + for out in ps.outputs: del out.srcs # can only schedule once + schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs+ps.inputs if x.size != 0), ps.metadata)) if logops and si.ast.op is MetaOps.KERNEL and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n") - for x in graph[ps[0][0]]: + for x in graph[ps.outputs[0]]: in_degree[x] -= 1 if in_degree[x] == 0: queue.append(prescheduled[x])