add LBScheduleItem type [run_process_replay] (#5944)

* add LBScheduleItem type [run_process_replay]

* minor cleanups

* fix

* fix fuzz tests

* add group cache type
George Hotz 2024-08-06 14:49:40 -07:00 committed by GitHub
parent 1dab75ae37
commit 73d4d51845
2 changed files with 47 additions and 32 deletions
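
The gist of the change: the values in prescheduled used to be positional tuples built as (group, *_lower_lazybuffer(group, realizes)), and every call site indexed into them; they are now instances of the new LBScheduleItem dataclass with named fields. The index-to-field mapping, as it reads from the diff below:

# old positional access   new named access
ps[0]                  -> ps.outputs
ps[1]                  -> ps.ast
ps[2]                  -> ps.inputs
ps[3]                  -> ps.var_vals
ps[4]                  -> ps.metadata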


@@ -30,16 +30,16 @@ def fuzz_schedule(outs:List[LazyBuffer]):
   seed = Tensor._seed
   ts, (_, prescheduled) = toposorts[0]
   for key in ts:
-    for out in (ps:=prescheduled[key])[0]:
+    for out in (ps:=prescheduled[key]).outputs:
       # freeze assign state before exec
       if out.op is MetaOps.ASSIGN:
         prerealized[out] = out.buffer.as_buffer()
         assign_targets[out.srcs[1]] = out
-    for x in ps[2]:
+    for x in ps.inputs:
       if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
-    si = ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0))
+    si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs+ps.inputs if x.size != 0))
     _exec_si(si, seed)
-    for out in ps[0]:
+    for out in ps.outputs:
       ground_truth[out] = out.buffer.as_buffer()
       del out.srcs # only schedule the LazyBuffer in this fuzz run
@@ -48,19 +48,19 @@ def fuzz_schedule(outs:List[LazyBuffer]):
     if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
     rawbufs: Dict[LazyBuffer, Buffer] = {}
     for key in ts:
-      for out in (ps:=prescheduled[key])[0]:
+      for out in (ps:=prescheduled[key]).outputs:
         rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype)
         if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
-      for x in ps[2]:
+      for x in ps.inputs:
         if x not in rawbufs:
           # override the assign_target after ASSIGN
           if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
           elif x.device == "NPY": rawbufs[x] = x.buffer
           # copy the pre realized input
           else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x])
-      si = ScheduleItem(ps[1], tuple(rawbufs[x] for x in ps[0]+ps[2] if x.size != 0))
+      si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in ps.outputs+ps.inputs if x.size != 0))
       _exec_si(si, seed)
-      for out in ps[0]:
+      for out in ps.outputs:
        outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
        try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2)
        except Exception as e:
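
For context, the assertion at the end of this hunk is plain numpy: reinterpret both raw byte buffers with the output dtype and compare under loose tolerances. A minimal, self-contained sketch of that comparison pattern, using toy arrays rather than tinygrad Buffers:

import numpy as np

# two byte-level views of "the same" result, differing only by small float noise
ground_truth = np.arange(8, dtype=np.float32).tobytes()
candidate = (np.arange(8, dtype=np.float32) + 1e-3).tobytes()

# reinterpret the raw bytes with the known dtype, then compare with loose tolerances
outbuf = np.frombuffer(candidate, dtype=np.float32)
np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth, dtype=np.float32), atol=1e-2, rtol=1e-2)  # passes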


@@ -1,6 +1,6 @@
 import sys, pickle, atexit
 from collections import defaultdict, deque
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
 from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
@@ -35,11 +35,20 @@ class ScheduleItem:
     """Read only buffers in the schedule."""
     return self.bufs[len(self.ast.src):] if self.ast.op is MetaOps.KERNEL else self.bufs[1:]

+@dataclass(frozen=True)
+class LBScheduleItem:
+  ast: LazyOp
+  outputs: List[LazyBuffer]
+  inputs: List[LazyBuffer]
+  var_vals: Dict[Variable, int] = field(default_factory=dict)
+  metadata: List[Metadata] = field(default_factory=list)
+
 # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***

 def _recursive_lazyop(buf:LazyBuffer, inputs:Dict[LazyBuffer, int], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
                       realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
-                      reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> LazyOp:
+                      reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
+                      cache:Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp]) -> LazyOp:
   """recursively create a lazyop"""
   if buf is not buf.base: st, buf = buf.st+st, buf.base
   if (buf, st) in cache: return cache[(buf, st)]
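
The field import added in the first hunk exists for the two defaulted fields above: a dataclass cannot take {} or [] directly as defaults (dataclasses raises ValueError for mutable defaults), so default_factory builds a fresh dict/list per instance, and frozen=True still allows mutating their contents. A standalone toy mirror of the pattern, with plain Python types standing in for LazyOp/LazyBuffer/Variable/Metadata:

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass(frozen=True)
class ToyScheduleItem:
  ast: str                      # stands in for LazyOp
  outputs: List[str]            # stands in for List[LazyBuffer]
  inputs: List[str]
  var_vals: Dict[str, int] = field(default_factory=dict)  # fresh dict per instance
  metadata: List[str] = field(default_factory=list)       # fresh list per instance

a = ToyScheduleItem("kernel_a", ["out0"], ["in0"])
b = ToyScheduleItem("kernel_b", ["out1"], ["in1"])
a.var_vals["n"] = 4     # frozen blocks reassignment of fields, not mutation of their contents
assert b.var_vals == {} # defaults are not shared between instances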
@@ -89,8 +98,10 @@ def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeT
   tmp = input_st.permute(permute_axis)
   return tmp, tmp.shape[-len(axis):]

-def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],\
-                       reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> Optional[Tuple[LazyBuffer, ShapeTracker]]:
+def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],
+                       reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
+                       cache:Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]]) -> \
+      Optional[Tuple[LazyBuffer, ShapeTracker]]:
   if (buf, st) in cache: return cache[(buf, st)]
   if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs): return None
   if buf is not buf.base: st, buf = buf.st+st, buf.base
@@ -125,15 +136,17 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
     return (buf, st)
   return cache.setdefault((buf, st), top_reduce)

-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> Tuple[LazyOp, List[LazyBuffer], Dict[Variable, int], List[Metadata]]:
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
   """describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
   if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
     rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
-    return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
-  if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
+    wr = LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))
+    return LBScheduleItem(LazyOp(MetaOps.KERNEL, (wr,)), outs, [x.base for x in out.srcs])
+  if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
+    return LBScheduleItem(LazyOp(out.op, (), out.arg), outs, [x.base for x in out.srcs])
   # push through all movementops between reduceops
   reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
-  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
+  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]] = {}
   for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
   # pad all reduceops to the max of each dimension
   shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
@@ -158,13 +171,14 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
     output_st, vv = output_st.simplify().unbind()
     if vv: var_vals.update(vv)
     ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_st)))
-  return LazyOp(MetaOps.KERNEL, tuple(ast)), list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
+  return LBScheduleItem(LazyOp(MetaOps.KERNEL, tuple(ast)), outs, list(inputs), var_vals,
+                        dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))

 # *** DAG creation: decide which LazyBuffers should realize ***

-def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None],\
-                children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer],\
+def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None],
+                children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer],
                 double_reduces:Dict[LazyBuffer, None], scheduled=False) -> None:
   """recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
   if buf in allbufs or buf.base.realized is not None: return
   if GRAPH: log_lazybuffer(buf, scheduled)
@@ -205,7 +219,8 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
   return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)

 def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
-                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None], cache:Set):
+                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None],
+                     cache:Set[Tuple[LazyBuffer, ShapeTracker]]) -> None:
   """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
   if (tr, st) in cache: return
   cache.add((tr, st))
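
This is presumably what the "add group cache type" bullet in the commit message refers to: the bare cache:Set becomes cache:Set[Tuple[LazyBuffer, ShapeTracker]], making explicit that the recursion memoizes on (buffer, view) pairs. A generic sketch of that visited-set pattern over (node, state) keys, with toy strings rather than the scheduler's types:

from typing import Dict, List, Set, Tuple

def visit(node: str, depth: int, children: Dict[str, List[str]], cache: Set[Tuple[str, int]]) -> None:
  # memoize on the (node, state) pair, mirroring the (tr, st) keys above
  if (node, depth) in cache: return
  cache.add((node, depth))
  for child in children.get(node, []):
    visit(child, depth + 1, children, cache)

children = {"r": ["a", "b"], "a": ["c"], "b": ["c"]}
seen: Set[Tuple[str, int]] = set()
visit("r", 0, children, seen)
print(sorted(seen))  # [('a', 1), ('b', 1), ('c', 2), ('r', 0)]: ('c', 2) is expanded once despite two paths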
@@ -239,7 +254,7 @@ SCHEDULES: List = []
 def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
     Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], # this is the graph
           DefaultDict[LazyBuffer, int], # this is the in-degree of the graph
-          Dict[LazyBuffer, Tuple[List[LazyBuffer], LazyOp, List[LazyBuffer], Dict[Variable, int], List[Metadata]]]]: # this is ???
+          Dict[LazyBuffer, LBScheduleItem]]: # this is the schedule item, but still in LazyBuffer
   """create a graph for realizing the outputs"""
   # start by just realizing the buffers passed in
   realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
@@ -330,20 +345,20 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
       buf.buffer.options = None
   # preschedule all buffers in realizes
-  prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes)) for group in output_groups.values()}
-  schedule_targets = {out:ps for ps in prescheduled.values() for out in ps[0]}
+  prescheduled = {group[0]:_lower_lazybuffer(group, realizes) for group in output_groups.values()}
+  schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
   graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
   in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
   for key, lsi in prescheduled.items():
     if key not in in_degree: in_degree[key] = 0
     # realize outputs after all parents are realized
-    scheduled_parents = set(schedule_targets[x][0][0] for x in lsi[2] if x in schedule_targets)
+    scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets)
     for x in scheduled_parents:
       graph[x].append(key)
       in_degree[key] += 1
     # realize outputs before a parent is assigned to
-    parents_assigns = set(schedule_targets[assign_targets[x]][0][0] for x in lsi[2] if x in assign_targets)
+    parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets)
     for assign in parents_assigns:
       graph[key].append(assign)
       in_degree[assign] += 1
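
The loop above is the setup half of a Kahn-style topological sort: every prescheduled item is a node, each input produced by another item adds an edge producer -> consumer and bumps the consumer's in-degree, and the assign edges point from a reader of a buffer to the kernel that assigns to it, so the read happens before the buffer is overwritten. A minimal sketch of the graph/in-degree bookkeeping, with strings standing in for LazyBuffers:

from collections import defaultdict
from typing import DefaultDict, Dict, List

# toy dependency info: item -> the items whose outputs it reads
reads: Dict[str, List[str]] = {"a": [], "b": ["a"], "c": ["a", "b"]}

graph: DefaultDict[str, List[str]] = defaultdict(list)
in_degree: DefaultDict[str, int] = defaultdict(int)
for key, parents in reads.items():
  if key not in in_degree: in_degree[key] = 0   # every item shows up, even with no pending parents
  for p in parents:
    graph[p].append(key)                        # edge producer -> consumer
    in_degree[key] += 1                         # consumer waits on one more producer

print(dict(graph))      # {'a': ['b', 'c'], 'b': ['c']}
print(dict(in_degree))  # {'a': 0, 'b': 1, 'c': 2}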
@@ -373,15 +388,15 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
   kernel_number = GlobalCounters.kernel_count
   while queue:
     ps = queue.popleft()
-    for buf in ps[0]: seen.add(buf)
+    for buf in ps.outputs: seen.add(buf)
     if GRAPH:
       kernel_number += 1
-      for out in ps[0]: realized_lazybuffer(out, kernel_number)
+      for out in ps.outputs: realized_lazybuffer(out, kernel_number)
-    var_vals = merge_dicts([var_vals, ps[3]])
-    for out in ps[0]: del out.srcs # can only schedule once
-    schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0), ps[4]))
+    var_vals = merge_dicts([var_vals, ps.var_vals])
+    for out in ps.outputs: del out.srcs # can only schedule once
+    schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs+ps.inputs if x.size != 0), ps.metadata))
     if logops and si.ast.op is MetaOps.KERNEL and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
-    for x in graph[ps[0][0]]:
+    for x in graph[ps.outputs[0]]:
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(prescheduled[x])
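
This last hunk is the consumption half of that topological sort: ready items are popped off a deque, turned into ScheduleItems, and each child's in-degree is decremented, with the child enqueued once it reaches zero. A self-contained sketch continuing the toy graph from the previous example:

from collections import defaultdict, deque
from typing import DefaultDict, List

graph: DefaultDict[str, List[str]] = defaultdict(list, {"a": ["b", "c"], "b": ["c"]})
in_degree = {"a": 0, "b": 1, "c": 2}

queue = deque(k for k, d in in_degree.items() if d == 0)  # start with items that wait on nothing
order: List[str] = []
while queue:
  ps = queue.popleft()
  order.append(ps)                 # the real scheduler appends a ScheduleItem here instead
  for x in graph[ps]:              # every item that reads this one's outputs
    in_degree[x] -= 1
    if in_degree[x] == 0: queue.append(x)

print(order)  # ['a', 'b', 'c']: every item comes after the items it reads from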