mirror of https://github.com/commaai/tinygrad.git
refactor to list of kernels [run_process_replay] (#6403)
This commit is contained in:
parent
7df4373fd9
commit
9a67ec6174
|
@ -19,7 +19,7 @@ if __name__ == "__main__":
|
|||
no_rewrite: List[float] = []
|
||||
for k,v in output_groups.items():
|
||||
st = time.perf_counter_ns()
|
||||
lsi = _lower_lazybuffer(v, realizes)
|
||||
lsi = _lower_lazybuffer(v, realizes)[0]
|
||||
et = time.perf_counter_ns() - st
|
||||
if lsi.ast.op is UOps.EXT: continue
|
||||
no_rewrite.append(et*1e-6)
|
||||
|
@ -30,7 +30,7 @@ if __name__ == "__main__":
|
|||
with Context(AST_REWRITE=1):
|
||||
for k,v in output_groups.items():
|
||||
st = time.perf_counter_ns()
|
||||
lsi = _lower_lazybuffer(v, realizes)
|
||||
lsi = _lower_lazybuffer(v, realizes)[0]
|
||||
bufs.append(v)
|
||||
et = time.perf_counter_ns() - st
|
||||
if lsi.ast.op is UOps.EXT: continue
|
||||
|
|
|
@ -6,7 +6,7 @@ from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD
|
|||
from tinygrad.rewrite import PatternMatcher, UPat, graph_rewrite
|
||||
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
|
||||
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
|
||||
GlobalCounters, all_same, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
|
@ -202,15 +202,15 @@ reduceop_fusor = PatternMatcher([
|
|||
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE}, name="root"), push_reduceop_shape),
|
||||
])
|
||||
|
||||
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
|
||||
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> List[LBScheduleItem]:
|
||||
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
|
||||
if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
|
||||
st_uop = ShapeTracker.from_shape(out.arg).to_uop()
|
||||
rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), st_uop))
|
||||
wr = UOp(UOps.STORE, None, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), st_uop, rd))
|
||||
return LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])
|
||||
return [LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])]
|
||||
if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
|
||||
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
|
||||
return [LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])]
|
||||
reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
|
||||
if not AST_REWRITE:
|
||||
# push through all movementops between reduceops
|
||||
|
@ -245,7 +245,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
|
|||
sink = UOp(UOps.SINK, None, tuple(ast))
|
||||
if AST_REWRITE:
|
||||
sink = graph_rewrite(sink, reduceop_fusor)
|
||||
return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
|
||||
return [LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))]
|
||||
|
||||
# *** DAG creation: decide which LazyBuffers should realize ***
|
||||
|
||||
|
@ -424,7 +424,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
|
|||
"""create a graph for realizing the outputs"""
|
||||
output_groups, realizes, assign_targets = _get_output_groups(outs, seen)
|
||||
# preschedule all buffers in realizes
|
||||
prescheduled = [_lower_lazybuffer(group, realizes) for group in output_groups.values()]
|
||||
prescheduled = flatten([_lower_lazybuffer(group, realizes) for group in output_groups.values()])
|
||||
schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
|
||||
|
||||
graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)
|
||||
|
|
Loading…
Reference in New Issue