mirror of https://github.com/commaai/tinygrad.git
add LBScheduleItem type [run_process_replay] (#5944)
* add LBScheduleItem type [run_process_replay]
* minor cleanups
* fix
* fix fuzz tests
* add group cache type
parent 1dab75ae37
commit 73d4d51845
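In short, the commit replaces the anonymous per-kernel tuple (outputs, ast, inputs, var_vals, metadata) that the scheduler passed around with a frozen LBScheduleItem dataclass, so call sites read named fields instead of positional indices. A minimal sketch of the access-pattern change, with placeholder strings standing in for real LazyOp/LazyBuffer objects (toy code, not from the repo):

from dataclasses import dataclass, field
from typing import Any, Dict, List

# before: one prescheduled entry was a plain tuple, indexed positionally
ps_old = (["out0"], "kernel_ast", ["in0", "in1"], {}, [])  # (outputs, ast, inputs, var_vals, metadata)
assert ps_old[1] == "kernel_ast" and ps_old[2] == ["in0", "in1"]

# after: the same data travels in a named, frozen dataclass (fields mirror the diff below)
@dataclass(frozen=True)
class LBScheduleItem:
  ast: Any
  outputs: List[Any]
  inputs: List[Any]
  var_vals: Dict[Any, int] = field(default_factory=dict)
  metadata: List[Any] = field(default_factory=list)

ps_new = LBScheduleItem("kernel_ast", ["out0"], ["in0", "in1"])
assert ps_new.ast == ps_old[1] and ps_new.outputs == ps_old[0] and ps_new.inputs == ps_old[2]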
@@ -30,16 +30,16 @@ def fuzz_schedule(outs:List[LazyBuffer]):
   seed = Tensor._seed
   ts, (_, prescheduled) = toposorts[0]
   for key in ts:
-    for out in (ps:=prescheduled[key])[0]:
+    for out in (ps:=prescheduled[key]).outputs:
       # freeze assign state before exec
       if out.op is MetaOps.ASSIGN:
         prerealized[out] = out.buffer.as_buffer()
         assign_targets[out.srcs[1]] = out
-    for x in ps[2]:
+    for x in ps.inputs:
       if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
-    si = ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0))
+    si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs+ps.inputs if x.size != 0))
     _exec_si(si, seed)
-    for out in ps[0]:
+    for out in ps.outputs:
       ground_truth[out] = out.buffer.as_buffer()
       del out.srcs # only schedule the LazyBuffer in this fuzz run
 
@@ -48,19 +48,19 @@ def fuzz_schedule(outs:List[LazyBuffer]):
     if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
     rawbufs: Dict[LazyBuffer, Buffer] = {}
     for key in ts:
-      for out in (ps:=prescheduled[key])[0]:
+      for out in (ps:=prescheduled[key]).outputs:
         rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype)
         if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
-      for x in ps[2]:
+      for x in ps.inputs:
         if x not in rawbufs:
           # override the assign_target after ASSIGN
           if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
           elif x.device == "NPY": rawbufs[x] = x.buffer
           # copy the pre realized input
           else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x])
-      si = ScheduleItem(ps[1], tuple(rawbufs[x] for x in ps[0]+ps[2] if x.size != 0))
+      si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in ps.outputs+ps.inputs if x.size != 0))
       _exec_si(si, seed)
-      for out in ps[0]:
+      for out in ps.outputs:
         outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
         try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2)
         except Exception as e:
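Both fuzz hunks, and create_schedule_with_vars further down, now turn a prescheduled entry into a ScheduleItem with the same expression: outputs first, then inputs, skipping zero-size buffers. A hedged sketch of that conversion pulled out into a helper; lsi_to_si is a hypothetical name that does not exist in the codebase, and the import path tinygrad.engine.schedule is assumed from the hunk headers:

# hypothetical helper, not part of the commit; shows the repeated conversion pattern
from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem

def lsi_to_si(ps: LBScheduleItem) -> ScheduleItem:
  # outputs come before inputs and zero-size buffers are dropped, as in the diff;
  # the fuzzer omits the metadata argument, the real scheduler passes ps.metadata
  return ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs + ps.inputs if x.size != 0), ps.metadata)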
@@ -1,6 +1,6 @@
 import sys, pickle, atexit
 from collections import defaultdict, deque
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Tuple, List, Dict, Optional, Set, DefaultDict, cast, get_args
 from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps, reduce_st
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
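The only change in this hunk is the extra field import, which the new dataclass needs for its mutable defaults: a bare = {} or = [] default is rejected by @dataclass, while default_factory builds a fresh container per instance. A standalone toy example (not tinygrad code) of the difference:

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass(frozen=True)
class Item:
  name: str
  var_vals: Dict[str, int] = field(default_factory=dict)  # fresh dict per instance
  metadata: List[str] = field(default_factory=list)        # fresh list per instance
  # metadata: List[str] = []  # would raise ValueError: mutable default is not allowed

a, b = Item("a"), Item("b")
a.metadata.append("tag")   # mutate the default list of one instance
assert b.metadata == []    # the other instance got its own list from the factory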
@@ -35,11 +35,20 @@ class ScheduleItem:
     """Read only buffers in the schedule."""
     return self.bufs[len(self.ast.src):] if self.ast.op is MetaOps.KERNEL else self.bufs[1:]
 
+@dataclass(frozen=True)
+class LBScheduleItem:
+  ast: LazyOp
+  outputs: List[LazyBuffer]
+  inputs: List[LazyBuffer]
+  var_vals: Dict[Variable, int] = field(default_factory=dict)
+  metadata: List[Metadata] = field(default_factory=list)
+
 # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
 
 def _recursive_lazyop(buf:LazyBuffer, inputs:Dict[LazyBuffer, int], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
                       realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
-                      reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> LazyOp:
+                      reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
+                      cache:Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp]) -> LazyOp:
   """recursively create a lazyop"""
   if buf is not buf.base: st, buf = buf.st+st, buf.base
   if (buf, st) in cache: return cache[(buf, st)]
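Because the class is declared @dataclass(frozen=True), a prescheduled entry cannot have its fields rebound after it is built; only the contents of its dict/list fields can still be mutated. A small stand-in showing that behavior, using toy types rather than real LazyOps/LazyBuffers:

import dataclasses
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass(frozen=True)
class FrozenItem:  # stand-in with the same shape as LBScheduleItem
  ast: Any
  outputs: List[Any]
  inputs: List[Any]
  var_vals: Dict[Any, int] = field(default_factory=dict)
  metadata: List[Any] = field(default_factory=list)

item = FrozenItem("ast", ["out"], ["in"])
try:
  item.ast = "other"        # rebinding a field raises on a frozen dataclass
except dataclasses.FrozenInstanceError:
  pass
item.var_vals["i"] = 4      # mutating the contents of a field is still allowed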
@@ -89,8 +98,10 @@ def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeT
   tmp = input_st.permute(permute_axis)
   return tmp, tmp.shape[-len(axis):]
 
-def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],\
-                       reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> Optional[Tuple[LazyBuffer, ShapeTracker]]:
+def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],
+                       reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
+                       cache:Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]]) -> \
+                       Optional[Tuple[LazyBuffer, ShapeTracker]]:
   if (buf, st) in cache: return cache[(buf, st)]
   if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs): return None
   if buf is not buf.base: st, buf = buf.st+st, buf.base
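The remaining signature edits just spell out what the cache parameters hold: _recursive_lazyop memoizes LazyOps, and _recurse_reduceops memoizes its optional (buffer, shapetracker) result via cache.setdefault, so the annotations now match what the recursion stores. The pattern of threading an explicitly typed memo dict through a recursion, sketched on a toy function (not scheduler code):

from typing import Dict

def collatz_steps(n: int, cache: Dict[int, int]) -> int:
  # the caller owns the cache and passes it down, like the scheduler's `cache` arguments
  if n in cache: return cache[n]
  if n == 1: return cache.setdefault(n, 0)
  nxt = n // 2 if n % 2 == 0 else 3 * n + 1
  return cache.setdefault(n, 1 + collatz_steps(nxt, cache))

cache: Dict[int, int] = {}
assert collatz_steps(6, cache) == 8
assert cache[3] == 7   # intermediate results are kept for reuse by later calls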
@@ -125,15 +136,17 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
       return (buf, st)
   return cache.setdefault((buf, st), top_reduce)
 
-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> Tuple[LazyOp, List[LazyBuffer], Dict[Variable, int], List[Metadata]]:
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
   """describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
   if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
     rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
-    return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
-  if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
+    wr = LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))
+    return LBScheduleItem(LazyOp(MetaOps.KERNEL, (wr,)), outs, [x.base for x in out.srcs])
+  if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
+    return LBScheduleItem(LazyOp(out.op, (), out.arg), outs, [x.base for x in out.srcs])
   # push through all movementops between reduceops
   reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
-  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
+  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]] = {}
   for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
   # pad all reduceops to the max of each dimension
   shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
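_lower_lazybuffer now hands back a single self-describing LBScheduleItem, and its early returns for MetaOps.CUSTOM/COPY/EMPTY/VIEW pass only the leading fields; the explicit trailing {} and [] of the old tuple returns now come from the dataclass defaults. A tiny toy illustration of relying on those defaults (stand-in class, not the real one):

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass(frozen=True)
class Lowered:  # toy stand-in: just the defaulted tail of an item
  ast: str
  var_vals: Dict[str, int] = field(default_factory=dict)
  metadata: List[str] = field(default_factory=list)

item = Lowered("EMPTY")   # early-return style: trailing fields omitted
assert item.var_vals == {} and item.metadata == []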
@@ -158,13 +171,14 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
     output_st, vv = output_st.simplify().unbind()
     if vv: var_vals.update(vv)
     ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_st)))
-  return LazyOp(MetaOps.KERNEL, tuple(ast)), list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
+  return LBScheduleItem(LazyOp(MetaOps.KERNEL, tuple(ast)), outs, list(inputs), var_vals,
+                        dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
 
 # *** DAG creation: decide which LazyBuffers should realize ***
 
-def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None],\
-                children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer],\
+def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Dict[LazyBuffer, None],
+                children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], assign_targets:Dict[LazyBuffer, LazyBuffer],
                 double_reduces:Dict[LazyBuffer, None], scheduled=False) -> None:
   """recursively search the entire graph for all LazyBuffers, insert realizes after expands"""
   if buf in allbufs or buf.base.realized is not None: return
   if GRAPH: log_lazybuffer(buf, scheduled)
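The metadata handed to the new LBScheduleItem is collected with dedup, which keeps first occurrences in order rather than using an unordered set. A quick sketch of that behavior; the import path tinygrad.helpers is an assumption here, and the dict.fromkeys line shows a plain-Python equivalent:

from tinygrad.helpers import dedup  # assumed location of the helper

assert dedup(["matmul", "relu", "matmul", "add"]) == ["matmul", "relu", "add"]
assert list(dict.fromkeys(["matmul", "relu", "matmul", "add"])) == ["matmul", "relu", "add"]  # same idea without the helper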
@@ -205,7 +219,8 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool:
   return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
 
 def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]],
-                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None], cache:Set):
+                     realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Dict[LazyBuffer, None],
+                     cache:Set[Tuple[LazyBuffer, ShapeTracker]]) -> None:
   """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group"""
   if (tr, st) in cache: return
   cache.add((tr, st))
@@ -239,7 +254,7 @@ SCHEDULES: List = []
 def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
   Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], # this is the graph
         DefaultDict[LazyBuffer, int], # this is the in-degree of the graph
-        Dict[LazyBuffer, Tuple[List[LazyBuffer], LazyOp, List[LazyBuffer], Dict[Variable, int], List[Metadata]]]]: # this is ???
+        Dict[LazyBuffer, LBScheduleItem]]: # this is the schedule item, but still in LazyBuffer
   """create a graph for realizing the outputs"""
   # start by just realizing the buffers passed in
   realizes: Dict[LazyBuffer, None] = {x.base:None for x in outs if x.base.realized is None}
@@ -330,20 +345,20 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
       buf.buffer.options = None
 
   # preschedule all buffers in realizes
-  prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes)) for group in output_groups.values()}
-  schedule_targets = {out:ps for ps in prescheduled.values() for out in ps[0]}
+  prescheduled = {group[0]:_lower_lazybuffer(group, realizes) for group in output_groups.values()}
+  schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
 
   graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
   in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
   for key, lsi in prescheduled.items():
     if key not in in_degree: in_degree[key] = 0
     # realize outputs after all parents are realized
-    scheduled_parents = set(schedule_targets[x][0][0] for x in lsi[2] if x in schedule_targets)
+    scheduled_parents = set(schedule_targets[x].outputs[0] for x in lsi.inputs if x in schedule_targets)
     for x in scheduled_parents:
       graph[x].append(key)
       in_degree[key] += 1
     # realize outputs before a parent is assigned to
-    parents_assigns = set(schedule_targets[assign_targets[x]][0][0] for x in lsi[2] if x in assign_targets)
+    parents_assigns = set(schedule_targets[assign_targets[x]].outputs[0] for x in lsi.inputs if x in assign_targets)
     for assign in parents_assigns:
       graph[key].append(assign)
       in_degree[assign] += 1
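The graph and in_degree maps built here drive the loop in the final hunk, which is a plain Kahn-style topological sort: pop an item whose in-degree reached zero, emit its ScheduleItem, and decrement each child. A self-contained sketch of that pattern on a toy dependency graph, with strings standing in for LazyBuffers:

from collections import defaultdict, deque
from typing import DefaultDict, List

# toy dependency graph: a kernel can only run after all of its parents
graph: DefaultDict[str, List[str]] = defaultdict(list)
in_degree: DefaultDict[str, int] = defaultdict(int)
for parent, child in [("load", "matmul"), ("load", "add"), ("matmul", "add")]:
  graph[parent].append(child)
  in_degree[child] += 1
  in_degree.setdefault(parent, 0)

queue = deque(k for k, d in in_degree.items() if d == 0)  # start from items with no pending parents
order = []
while queue:
  node = queue.popleft()
  order.append(node)                  # the real loop appends a ScheduleItem here
  for child in graph[node]:
    in_degree[child] -= 1
    if in_degree[child] == 0: queue.append(child)

assert order == ["load", "matmul", "add"]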
@@ -373,15 +388,15 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
   kernel_number = GlobalCounters.kernel_count
   while queue:
     ps = queue.popleft()
-    for buf in ps[0]: seen.add(buf)
+    for buf in ps.outputs: seen.add(buf)
     if GRAPH:
       kernel_number += 1
-      for out in ps[0]: realized_lazybuffer(out, kernel_number)
-    var_vals = merge_dicts([var_vals, ps[3]])
-    for out in ps[0]: del out.srcs # can only schedule once
-    schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0), ps[4]))
+      for out in ps.outputs: realized_lazybuffer(out, kernel_number)
+    var_vals = merge_dicts([var_vals, ps.var_vals])
+    for out in ps.outputs: del out.srcs # can only schedule once
+    schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs+ps.inputs if x.size != 0), ps.metadata))
     if logops and si.ast.op is MetaOps.KERNEL and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
-    for x in graph[ps[0][0]]:
+    for x in graph[ps.outputs[0]]:
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(prescheduled[x])
 