add LBScheduleItem type [run_process_replay] (#5944)

* add LBScheduleItem type [run_process_replay]

* minor cleanups

* fix

* fix fuzz tests

* add group cache type
This commit is contained in:
George Hotz 2024-08-06 14:49:40 -07:00 committed by GitHub
parent 1dab75ae37
commit 73d4d51845
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 32 deletions

View File

@ -30,16 +30,16 @@ def fuzz_schedule(outs:List[LazyBuffer]):
seed = Tensor._seed
ts, (_, prescheduled) = toposorts[0]
for key in ts:
for out in (ps:=prescheduled[key])[0]:
for out in (ps:=prescheduled[key]).outputs:
# freeze assign state before exec
if out.op is MetaOps.ASSIGN:
prerealized[out] = out.buffer.as_buffer()
assign_targets[out.srcs[1]] = out
for x in ps[2]:
for x in ps.inputs:
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
si = ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0))
si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs+ps.inputs if x.size != 0))
_exec_si(si, seed)
for out in ps[0]:
for out in ps.outputs:
ground_truth[out] = out.buffer.as_buffer()
del out.srcs # only schedule the LazyBuffer in this fuzz run
@ -48,19 +48,19 @@ def fuzz_schedule(outs:List[LazyBuffer]):
if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
rawbufs: Dict[LazyBuffer, Buffer] = {}
for key in ts:
for out in (ps:=prescheduled[key])[0]:
for out in (ps:=prescheduled[key]).outputs:
rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype)
if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
for x in ps[2]:
for x in ps.inputs:
if x not in rawbufs:
# override the assign_target after ASSIGN
if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
elif x.device == "NPY": rawbufs[x] = x.buffer
# copy the pre realized input
else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x])
si = ScheduleItem(ps[1], tuple(rawbufs[x] for x in ps[0]+ps[2] if x.size != 0))
si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in ps.outputs+ps.inputs if x.size != 0))
_exec_si(si, seed)
for out in ps[0]:
for out in ps.outputs:
outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2)
except Exception as e:

View File

@ -1,6 +1,6 @@
import sys, pickle, atexit
from collections import defaultdict, deque
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
@ -35,11 +35,20 @@ class ScheduleItem:
"""Read only buffers in the schedule."""
return self.bufs[len(self.ast.src):] if self.ast.op is MetaOps.KERNEL else self.bufs[1:]
@dataclass(frozen=True)
class LBScheduleItem:
ast: LazyOp
outputs: List[LazyBuffer]
inputs: List[LazyBuffer]
var_vals: Dict[Variable, int] = field(default_factory=dict)
metadata: List[Metadata] = field(default_factory=list)
# *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
def _recursive_lazyop(buf:LazyBuffer, inputs:Dict[LazyBuffer, int], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> LazyOp:
reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
cache:Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp]) -> LazyOp:
"""recursively create a lazyop"""
if buf is not buf.base: st, buf = buf.st+st, buf.base
if (buf, st) in cache: return cache[(buf, st)]
@ -89,8 +98,10 @@ def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeT
tmp = input_st.permute(permute_axis)
return tmp, tmp.shape[-len(axis):]
def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],\
reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> Optional[Tuple[LazyBuffer, ShapeTracker]]:
def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],
reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
cache:Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]]) -> \
Optional[Tuple[LazyBuffer, ShapeTracker]]:
if (buf, st) in cache: return cache[(buf, st)]
if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs): return None
if buf is not buf.base: st, buf = buf.st+st, buf.base
@ -125,15 +136,17 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
return (buf, st)
return cache.setdefault((buf, st), top_reduce)
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> Tuple[LazyOp, List[LazyBuffer], Dict[Variable, int], List[Metadata]]:
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
"""describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
wr = LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))
return LBScheduleItem(LazyOp(MetaOps.KERNEL, (wr,)), outs, [x.base for x in out.srcs])
if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
return LBScheduleItem(LazyOp(out.op, (), out.arg), outs, [x.base for x in out.srcs])
# push through all movementops between reduceops
reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]] = {}
for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
# pad all reduceops to the max of each dimension
shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
@ -158,13 +171,14 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
output_st, vv = output_st.simplify().unbind()
if vv: var_vals.update(vv)
ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_st)))
return LazyOp(MetaOps.KERNEL, tuple(ast)), list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
return LBScheduleItem(LazyOp(MetaOps.KERNEL, tuple(ast)), outs, list(inputs), var_vals,
dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
# *** DAG creation: decide which LazyBuffers should realize ***
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None],\
children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer],\
double_reduces:Dict[LazyBuffer, None], scheduled=False) -> None:
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None],
children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer],
double_reduces:Dict[LazyBuffer, None], scheduled=False) -> None:
"""recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
if buf in allbufs or buf.base.realized is not None: return
if GRAPH: log_lazybuffer(buf, scheduled)
@ -205,7 +219,8 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None], cache:Set):
realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None],
cache:Set[Tuple[LazyBuffer, ShapeTracker]]) -> None:
"""recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
if (tr, st) in cache: return
cache.add((tr, st))
@ -239,7 +254,7 @@ SCHEDULES: List = []
def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], # this is the graph
DefaultDict[LazyBuffer, int], # this is the in-degree of the graph
Dict[LazyBuffer, Tuple[List[LazyBuffer], LazyOp, List[LazyBuffer], Dict[Variable, int], List[Metadata]]]]: # this is ???
Dict[LazyBuffer, LBScheduleItem]]: # this is the schedule item, but still in LazyBuffer
"""create a graph for realizing the outputs"""
# start by just realizing the buffers passed in
realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
@ -330,20 +345,20 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
buf.buffer.options = None
# preschedule all buffers in realizes
prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes)) for group in output_groups.values()}
schedule_targets = {out:ps for ps in prescheduled.values() for out in ps[0]}
prescheduled = {group[0]:_lower_lazybuffer(group, realizes) for group in output_groups.values()}
schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
for key, lsi in prescheduled.items():
if key not in in_degree: in_degree[key] = 0
# realize outputs after all parents are realized
scheduled_parents = set(schedule_targets[x][0][0] for x in lsi[2] if x in schedule_targets)
scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets)
for x in scheduled_parents:
graph[x].append(key)
in_degree[key] += 1
# realize outputs before a parent is assigned to
parents_assigns = set(schedule_targets[assign_targets[x]][0][0] for x in lsi[2] if x in assign_targets)
parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets)
for assign in parents_assigns:
graph[key].append(assign)
in_degree[assign] += 1
@ -373,15 +388,15 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
kernel_number = GlobalCounters.kernel_count
while queue:
ps = queue.popleft()
for buf in ps[0]: seen.add(buf)
for buf in ps.outputs: seen.add(buf)
if GRAPH:
kernel_number += 1
for out in ps[0]: realized_lazybuffer(out, kernel_number)
var_vals = merge_dicts([var_vals, ps[3]])
for out in ps[0]: del out.srcs # can only schedule once
schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0), ps[4]))
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
var_vals = merge_dicts([var_vals, ps.var_vals])
for out in ps.outputs: del out.srcs # can only schedule once
schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs+ps.inputs if x.size != 0), ps.metadata))
if logops and si.ast.op is MetaOps.KERNEL and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
for x in graph[ps[0][0]]:
for x in graph[ps.outputs[0]]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(prescheduled[x])