diff --git a/test/external/external_benchmark_ast.py b/test/external/external_benchmark_ast.py
index fa7ce0c6..b7d6dc89 100644
--- a/test/external/external_benchmark_ast.py
+++ b/test/external/external_benchmark_ast.py
@@ -19,7 +19,7 @@ if __name__ == "__main__":
   no_rewrite: List[float] = []
   for k,v in output_groups.items():
     st = time.perf_counter_ns()
-    lsi = _lower_lazybuffer(v, realizes)
+    lsi = _lower_lazybuffer(v, realizes)[0]
     et = time.perf_counter_ns() - st
     if lsi.ast.op is UOps.EXT: continue
     no_rewrite.append(et*1e-6)
@@ -30,7 +30,7 @@ if __name__ == "__main__":
   with Context(AST_REWRITE=1):
     for k,v in output_groups.items():
       st = time.perf_counter_ns()
-      lsi = _lower_lazybuffer(v, realizes)
+      lsi = _lower_lazybuffer(v, realizes)[0]
       bufs.append(v)
       et = time.perf_counter_ns() - st
       if lsi.ast.op is UOps.EXT: continue
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index ac898e41..2ccfe41e 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -6,7 +6,7 @@ from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD
 from tinygrad.rewrite import PatternMatcher, UPat, graph_rewrite
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
-  GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
+  GlobalCounters, all_same, colored, flatten, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
 from tinygrad.lazy import LazyBuffer
@@ -202,15 +202,15 @@ reduceop_fusor = PatternMatcher([
   (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE}, name="root"), push_reduceop_shape),
 ])
 
-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> List[LBScheduleItem]:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
     st_uop = ShapeTracker.from_shape(out.arg).to_uop()
     rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), st_uop))
     wr = UOp(UOps.STORE, None, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), st_uop, rd))
-    return LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])
+    return [LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])]
   if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
-    return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
+    return [LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])]
   reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
   if not AST_REWRITE:
     # push through all movementops between reduceops
@@ -245,7 +245,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
   sink = UOp(UOps.SINK, None, tuple(ast))
   if AST_REWRITE:
     sink = graph_rewrite(sink, reduceop_fusor)
-  return LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
+  return [LBScheduleItem(sink, outs, list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))]
 
 # *** DAG creation: decide which LazyBuffers should realize ***
 
@@ -424,7 +424,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
   """create a graph for realizing the outputs"""
   output_groups, realizes, assign_targets = _get_output_groups(outs, seen)
   # preschedule all buffers in realizes
-  prescheduled = [_lower_lazybuffer(group, realizes) for group in output_groups.values()]
+  prescheduled = flatten([_lower_lazybuffer(group, realizes) for group in output_groups.values()])
   schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
 
   graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)