refactor to list of kernels [run_process_replay] (#6403)

This commit is contained in:
qazal 2024-09-08 17:19:45 +08:00 committed by GitHub
parent 7df4373fd9
commit 9a67ec6174
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 8 deletions

View File

@@ -19,7 +19,7 @@ if __name__ == "__main__":
no_rewrite: List[float] = []
for k,v in output_groups.items():
st = time.perf_counter_ns()
-lsi = _lower_lazybuffer(v, realizes)
+lsi = _lower_lazybuffer(v, realizes)[0]
et = time.perf_counter_ns() - st
if lsi.ast.op is UOps.EXT: continue
no_rewrite.append(et*1e-6)
@@ -30,7 +30,7 @@ if __name__ == "__main__":
with Context(AST_REWRITE=1):
for k,v in output_groups.items():
st = time.perf_counter_ns()
-lsi = _lower_lazybuffer(v, realizes)
+lsi = _lower_lazybuffer(v, realizes)[0]
bufs.append(v)
et = time.perf_counter_ns() - st
if lsi.ast.op is UOps.EXT: continue

View File

@@ -6,7 +6,7 @@ from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD
from tinygrad.rewrite import PatternMatcher, UPat, graph_rewrite
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
-GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
+GlobalCounters, all_same, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
from tinygrad.lazy import LazyBuffer
@@ -202,15 +202,15 @@ reduceop_fusor = PatternMatcher([
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE}, name="root"), push_reduceop_shape),
])
-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> List[LBScheduleItem]:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
st_uop = ShapeTracker.from_shape(out.arg).to_uop()
rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), st_uop))
wr = UOp(UOps.STORE, None, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), st_uop, rd))
-return LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])
+return [LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])]
if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
-return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
+return [LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])]
reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
if not AST_REWRITE:
# push through all movementops between reduceops
@@ -245,7 +245,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
sink = UOp(UOps.SINK, None, tuple(ast))
if AST_REWRITE:
sink = graph_rewrite(sink, reduceop_fusor)
-return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
+return [LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))]
# *** DAG creation: decide which LazyBuffers should realize ***
@@ -424,7 +424,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
"""create a graph for realizing the outputs"""
output_groups, realizes, assign_targets = _get_output_groups(outs, seen)
# preschedule all buffers in realizes
-prescheduled = [_lower_lazybuffer(group, realizes) for group in output_groups.values()]
+prescheduled = flatten([_lower_lazybuffer(group, realizes) for group in output_groups.values()])
schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)