From 68ca4d4276718f84a23400e53fea6ee7a1e3114a Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 26 Mar 2024 21:02:46 -0700 Subject: [PATCH] split to schedule.py (#3949) * split to schedule.py * split --- docs/abstractions2.py | 3 +- examples/handcode_resnet50_opt.py | 2 +- extra/autopad.py | 2 +- extra/gemm/tvm_gemm.py | 2 +- openpilot/compile2.py | 3 +- ...xternal_benchmark_multitensor_allreduce.py | 3 +- test/external/external_test_hip_compile.py | 2 +- test/external/external_test_uops_graphing.py | 2 +- test/test_conv_shapetracker.py | 2 +- test/test_dtype_alu.py | 2 +- test/test_fusion_op.py | 3 +- test/test_lazybuffer.py | 2 +- test/test_lazyop.py | 2 +- test/test_linearizer.py | 3 +- test/test_multitensor.py | 2 +- test/test_schedule.py | 2 +- test/test_search.py | 2 +- test/test_uops.py | 2 +- test/test_uops_stats.py | 3 +- test/test_winograd.py | 2 +- tinygrad/engine/realize.py | 215 +----------------- tinygrad/engine/schedule.py | 203 +++++++++++++++++ tinygrad/tensor.py | 3 +- 23 files changed, 235 insertions(+), 232 deletions(-) create mode 100644 tinygrad/engine/schedule.py diff --git a/docs/abstractions2.py b/docs/abstractions2.py index 3e5deafa..a1dab09b 100644 --- a/docs/abstractions2.py +++ b/docs/abstractions2.py @@ -73,7 +73,8 @@ assert out.as_buffer().cast('I')[0] == 5 print("******** third, the LazyBuffer ***********") from tinygrad.lazy import LazyBuffer, LoadOps -from tinygrad.engine.realize import run_schedule, create_schedule +from tinygrad.engine.realize import run_schedule +from tinygrad.engine.schedule import create_schedule # allocate some values + load in values a = LazyBuffer.loadop(LoadOps.EMPTY, (1,), dtypes.int32, DEVICE) diff --git a/examples/handcode_resnet50_opt.py b/examples/handcode_resnet50_opt.py index 4ce146c2..212d3f02 100644 --- a/examples/handcode_resnet50_opt.py +++ b/examples/handcode_resnet50_opt.py @@ -8,7 +8,7 @@ from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin from tinygrad.helpers import ansilen, DEBUG, getenv from tinygrad.shape.symbolic import sym_infer from tinygrad.dtype import dtypes -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule if __name__ == "__main__": if getenv("HALF"): diff --git a/extra/autopad.py b/extra/autopad.py index 25f73968..48a97f92 100644 --- a/extra/autopad.py +++ b/extra/autopad.py @@ -3,7 +3,7 @@ from tinygrad.ops import LoadOps from tinygrad.codegen.linearizer import Linearizer from test.external.fuzz_linearizer import run_linearizer from tinygrad.codegen.kernel import Opt, OptOps -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule N = 17**3 diff --git a/extra/gemm/tvm_gemm.py b/extra/gemm/tvm_gemm.py index 1c5253b6..59613540 100644 --- a/extra/gemm/tvm_gemm.py +++ b/extra/gemm/tvm_gemm.py @@ -30,7 +30,7 @@ except ImportError: import os from tinygrad.tensor import Tensor -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule # define the compute A = Tensor.rand(M, K, device="clang") diff --git a/openpilot/compile2.py b/openpilot/compile2.py index e82119ba..6148f24b 100644 --- a/openpilot/compile2.py +++ b/openpilot/compile2.py @@ -16,7 +16,8 @@ from extra.onnx import get_run_onnx from tinygrad import Tensor, Device, GlobalCounters, dtypes from tinygrad.dtype import ImageDType from tinygrad.helpers import partition, Context, fetch, getenv, GRAPH, DEBUG -from tinygrad.engine.realize import run_schedule, lower_schedule_item, create_schedule +from tinygrad.engine.realize import run_schedule, lower_schedule_item +from tinygrad.engine.schedule import create_schedule from tinygrad.ops import LoadOps, ScheduleItem Device.DEFAULT = "GPU" diff --git a/test/external/external_benchmark_multitensor_allreduce.py b/test/external/external_benchmark_multitensor_allreduce.py index 24496fe3..f765a912 100644 --- a/test/external/external_benchmark_multitensor_allreduce.py +++ b/test/external/external_benchmark_multitensor_allreduce.py @@ -4,7 +4,8 @@ from tinygrad.lazy import LazyBuffer from tinygrad.ops import ReduceOps, GlobalCounters from tinygrad.features.multi import MultiLazyBuffer, all_reduce from tinygrad.engine.jit import TinyJit -from tinygrad.engine.realize import create_schedule, run_schedule +from tinygrad.engine.schedule import create_schedule +from tinygrad.engine.realize import run_schedule from tinygrad.helpers import getenv, Context, RING from typing import List, Union diff --git a/test/external/external_test_hip_compile.py b/test/external/external_test_hip_compile.py index 68fa2b7e..cb7a7a19 100644 --- a/test/external/external_test_hip_compile.py +++ b/test/external/external_test_hip_compile.py @@ -2,7 +2,7 @@ import time, unittest from tinygrad.runtime.driver.hip_comgr import compile_hip from tinygrad import Tensor from tinygrad.device import Device -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule from tinygrad.codegen.linearizer import Linearizer class TestHIPCompileSpeed(unittest.TestCase): diff --git a/test/external/external_test_uops_graphing.py b/test/external/external_test_uops_graphing.py index 82bdd0b1..2f64b6f2 100644 --- a/test/external/external_test_uops_graphing.py +++ b/test/external/external_test_uops_graphing.py @@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor from tinygrad.codegen.linearizer import Linearizer from tinygrad.renderer.cstyle import OpenCLRenderer from tinygrad.features.graph import graph_uops -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule from tinygrad.nn import Conv2d class TestUopsGraph(unittest.TestCase): diff --git a/test/test_conv_shapetracker.py b/test/test_conv_shapetracker.py index acd0bda0..d8447bb5 100644 --- a/test/test_conv_shapetracker.py +++ b/test/test_conv_shapetracker.py @@ -3,7 +3,7 @@ import unittest from tinygrad.tensor import Tensor from tinygrad.ops import LoadOps from tinygrad.nn import Conv2d -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule class TestConvShapetracker(unittest.TestCase): def test_conv_3x3_one_view(self): diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index c1411945..af960665 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -6,7 +6,7 @@ import numpy as np from hypothesis import given, strategies as strat, settings from tinygrad.dtype import DType from tinygrad.helpers import CI, getenv -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule from tinygrad.ops import UnaryOps, get_lazyop_info from test.helpers import is_dtype_supported diff --git a/test/test_fusion_op.py b/test/test_fusion_op.py index 5f73a500..7a2ee4e7 100644 --- a/test/test_fusion_op.py +++ b/test/test_fusion_op.py @@ -2,7 +2,8 @@ import unittest import time import numpy as np from tinygrad import Tensor, dtypes -from tinygrad.engine.realize import run_schedule, create_schedule, lower_schedule_item +from tinygrad.engine.schedule import create_schedule +from tinygrad.engine.realize import run_schedule, lower_schedule_item class TestFusionOp(unittest.TestCase): def test_contiguous_add(self): diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py index b0bb4b87..7bdf351f 100644 --- a/test/test_lazybuffer.py +++ b/test/test_lazybuffer.py @@ -3,7 +3,7 @@ import numpy as np import unittest from tinygrad import Tensor, Device, dtypes from tinygrad.lazy import LazyBuffer, ReduceOps -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule class TestLazyBuffer(unittest.TestCase): def test_fromcpu_shape_tracker(self): diff --git a/test/test_lazyop.py b/test/test_lazyop.py index 34e707b5..c639b705 100644 --- a/test/test_lazyop.py +++ b/test/test_lazyop.py @@ -1,6 +1,6 @@ import unittest from tinygrad.tensor import Tensor -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule # stuff needed to unpack a kernel # ruff: noqa: F401 diff --git a/test/test_linearizer.py b/test/test_linearizer.py index eab1b6ba..86046100 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -10,7 +10,8 @@ from tinygrad.shape.view import View from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node from tinygrad.tensor import Tensor from tinygrad.engine.jit import CacheCollector -from tinygrad.engine.realize import create_schedule, run_schedule +from tinygrad.engine.schedule import create_schedule +from tinygrad.engine.realize import run_schedule from tinygrad.helpers import prod, Context from tinygrad.dtype import DType, dtypes from tinygrad.codegen.uops import UOpGraph diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 347d45ea..fd3a3bcb 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -5,7 +5,7 @@ from tinygrad.device import BufferCopy from tinygrad.ops import LoadOps, ReduceOps from tinygrad.helpers import CI, prod, Context from tinygrad.nn.state import get_parameters, get_state_dict -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule from tinygrad.features.multi import all_reduce, MultiLazyBuffer from random import randint import numpy as np diff --git a/test/test_schedule.py b/test/test_schedule.py index e2128dd8..b86d5c98 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -9,7 +9,7 @@ from tinygrad.ops import LoadOps from tinygrad.helpers import DEBUG, GRAPH from tinygrad.codegen.linearizer import Linearizer from tinygrad.features.graph import print_tree, realized_lazybuffer -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule from tinygrad import nn, dtypes def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True): diff --git a/test/test_search.py b/test/test_search.py index 987f156c..0d2d7ff6 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -1,7 +1,7 @@ import unittest from tinygrad.codegen.linearizer import Linearizer -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule from tinygrad.features.search import time_linearizer, bufs_from_lin from tinygrad.device import Device, Buffer from tinygrad.ops import LoadOps diff --git a/test/test_uops.py b/test/test_uops.py index 5883cc20..9c6f5b69 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -5,7 +5,7 @@ from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.device import Buffer, Device, CompiledASTRunner from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.codegen.uops import exec_alu, UOpGraph from test.helpers import is_dtype_supported diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index 6fce3c5e..edeabb6d 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -1,6 +1,7 @@ import unittest from tinygrad import Tensor -from tinygrad.engine.realize import create_schedule, lower_schedule_item +from tinygrad.engine.schedule import create_schedule +from tinygrad.engine.realize import lower_schedule_item # TODO: can copy this in here when we remove it #from tinygrad.ops import get_lazyop_info diff --git a/test/test_winograd.py b/test/test_winograd.py index 8d00a396..80208499 100644 --- a/test/test_winograd.py +++ b/test/test_winograd.py @@ -3,7 +3,7 @@ from tinygrad import Tensor, GlobalCounters from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG from tinygrad.ops import LoadOps from tinygrad.codegen.linearizer import Linearizer -from tinygrad.engine.realize import create_schedule +from tinygrad.engine.schedule import create_schedule class TestWinograd(unittest.TestCase): def setUp(self): diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 23757d50..65287f48 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -1,16 +1,9 @@ -import sys -from collections import defaultdict, deque -from typing import List, Dict, Optional, cast, Set, DefaultDict -from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps +from typing import List, Dict, Optional, cast +from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats -from tinygrad.features.graph import realized_lazybuffer, log_lazybuffer -from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG, prod, dedup, all_int +from tinygrad.features.graph import realized_lazybuffer +from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG from tinygrad.shape.symbolic import Variable -from tinygrad.dtype import ImageDType, dtypes -from tinygrad.lazy import LazyBuffer -from tinygrad.shape.shapetracker import ShapeTracker - -# *** schedule running *** class CustomOp(JITRunner): def __init__(self, fxn): @@ -65,203 +58,3 @@ def run_schedule(schedule:List[ScheduleItem]): elif (out:=si.outputs[0]).size > 0: update_stats(colored(f"empty {out.st.size:10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device) if GRAPH: for out in si.outputs: realized_lazybuffer(out, GlobalCounters.kernel_count) - -# *** schedule creation *** - -# creation can recurse a lot -sys.setrecursionlimit(10000) - -# recursively create a lazyop -def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker, - realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None) -> LazyOp: - if (buf, st) in cache: return cache[(buf, st)] - if buf != buf.base: - st = buf.st + st - buf = buf.base - # all buffers here are base now - assert buf.op is not None - - # consts are always fused and generated - if buf.op is LoadOps.CONST: - unbound_st, st_var_vals = st.simplify().unbind() - var_vals.update(st_var_vals) - return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st)) - - # if we aren't fusing it, it's a load and we add it to the inputs - if buf.realized or (buf in realizes and not first): - unbound_st, st_var_vals = st.simplify().unbind() - var_vals.update(st_var_vals) - if assign_to is not None and buf is assign_to: - if not unbound_st.contiguous: - # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine - if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and - ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)): - raise RuntimeError(f"must be contiguous for assign {unbound_st}") - return LazyOp(BufferOps.LOAD, (), MemBuffer(0, buf.dtype, unbound_st)) - if buf not in inputs: inputs.append(buf) - return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st)) - - # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it - if buf.op is LoadOps.CONTIGUOUS: - assert first - return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False) - if buf.op is LoadOps.ASSIGN: - assert first - assert buf.srcs[1].base is buf.srcs[1], "assign must be to base" - assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}" - return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1]) - - # if it's a reduce, we have to change the shapetracker - if buf.op in ReduceOps: - assert st.contiguous, "ReduceOps late fusion must be contiguous" - st = ShapeTracker.from_shape(buf.srcs[0].shape) - - # otherwise we fuse it like normal - cache[(buf, st)] = ret = \ - LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False, assign_to) for x in buf.srcs), buf.arg) - return ret - -def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem: - inputs: List[LazyBuffer] = [] - var_vals: Dict[Variable, int] = out.st.var_vals.copy() - if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}: - op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs) - else: - output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape) - op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={}) - op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0])) - return ScheduleItem((op,), (out,), tuple(inputs), var_vals) - -# recursively search the entire graph for all LazyBuffers, insert realizes after expands -def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None], - simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False): - if buf in allbufs or buf.base.realized: return - if GRAPH: log_lazybuffer(buf, scheduled) - if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or - not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())): - if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32") - buf.dtype = dtypes.float32 # NOTE: this is what makes the dtype above not match - if buf.base != buf: - # realize all places where the buffer is expanded - if prod(buf.base.st.shape) < prod(buf.st.shape): - if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \ - prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]): - simple_pads.add(buf.base) - else: - realizes.add(buf.base) - return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children) - if buf.forced_realize: realizes.add(buf) - allbufs[buf] = None - if buf.op in LoadOps: realizes.add(buf.base) - if buf.op == LoadOps.COPY: - assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig" - realizes.add(buf.srcs[0].base) - for x in buf.srcs: - children[x.base][buf] = None - _recurse_lb(x, realizes, allbufs, simple_pads, children) - -UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2} -def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool: - if buf in realizes or buf.realized: return True - # NOTE: this broke to_image_idx and coder with JIT - if buf.op in UNSAFE_PAD_OPS: return False - return all(_is_padding_okay(x.base, realizes) for x in buf.srcs) - -def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]: - if seen is None: seen = set() - - # start by just realizing the buffers passed in - realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized]) - allbufs: Dict[LazyBuffer, None] = {} - simple_pads: Set[LazyBuffer] = set() - children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict) - for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True) - - # check if we have to realize pads - for p in simple_pads: - if not _is_padding_okay(p, realizes): - realizes.add(p) - - # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child) - reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {} - for r in allbufs.keys(): - if r != r.base or r.op not in ReduceOps or r in realizes: continue - - # follow the reduce down - child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st} - realized_children: Dict[LazyBuffer, ShapeTracker] = {} - forced_realize = False - can_chase = True - while not forced_realize and len(child_set): - next_child_set = {} - for tr,st in child_set.items(): - if tr in realizes: - realized_children[tr] = st - # can only have one output buffer - # can only reduce contiguous - # max one reduceop per kernel - if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r): - can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r - forced_realize = True - break - continue - for tr_next in children[tr].keys(): - if not tr_next.realized: - # max one reduceop per kernel - if tr_next.op in ReduceOps: - forced_realize = True - break - st_childs = dedup([s for s in tr_next.srcs if s.base == tr]) - if len(st_childs) > 1: - forced_realize = True - break - next_child_set[tr_next] = st + st_childs[0].st - child_set = next_child_set - if forced_realize: - tr = r - if can_chase: - # can chase this down to contiguous children - st = tr.st - while len(children[tr]) == 1: - tr_next = next(iter(children[tr].keys())) - st_childs = dedup([s for s in tr_next.srcs if s.base == tr]) - if len(st_childs) > 1: break - if st.size != st_childs[0].st.size: break - st = st + st_childs[0].st - if not st.contiguous or tr_next.op in ReduceOps: break - tr = tr_next - reduce_for_op[tr] = r - realizes.add(tr) - else: - assert len(realized_children) == 1 - reduce_for_op[next(iter(realized_children.keys()))] = r - - # preschedule all buffers in realizes - prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST} - assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None} - - # breadth first ordering - graph: DefaultDict[LazyBuffer,List[LazyBuffer]] = defaultdict(list) - in_degree: DefaultDict[LazyBuffer,int] = defaultdict(int) - for out, si in prescheduled.items(): - for x in si.inputs: - graph[x].append(out) - if x in assign_targets: - graph[out].append(assign_targets[x]) - in_degree[assign_targets[x]] += 1 - if x in prescheduled: in_degree[out] += 1 - - queue = deque(out for out in prescheduled if in_degree[out] == 0) - schedule: List[ScheduleItem] = [] - while queue: - buf = queue.popleft() - seen.add(buf) - schedule.append(prescheduled[buf]) - for x in graph[buf]: - in_degree[x] -= 1 - if in_degree[x] == 0: queue.append(x) - - # confirm everything was scheduled - assert len(prescheduled) == len(schedule), f"prescheduled {len(prescheduled)} but only scheduled {len(schedule)}" - return schedule - diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py new file mode 100644 index 00000000..2504ca75 --- /dev/null +++ b/tinygrad/engine/schedule.py @@ -0,0 +1,203 @@ +from collections import defaultdict, deque +from typing import List, Dict, Optional, Set, DefaultDict +from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps +from tinygrad.features.graph import log_lazybuffer +from tinygrad.helpers import GRAPH, DEBUG, prod, dedup, all_int +from tinygrad.shape.symbolic import Variable +from tinygrad.dtype import ImageDType, dtypes +from tinygrad.lazy import LazyBuffer +from tinygrad.shape.shapetracker import ShapeTracker + +# recursively create a lazyop +def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker, + realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None) -> LazyOp: + if (buf, st) in cache: return cache[(buf, st)] + if buf != buf.base: + st = buf.st + st + buf = buf.base + # all buffers here are base now + assert buf.op is not None + + # consts are always fused and generated + if buf.op is LoadOps.CONST: + unbound_st, st_var_vals = st.simplify().unbind() + var_vals.update(st_var_vals) + return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st)) + + # if we aren't fusing it, it's a load and we add it to the inputs + if buf.realized or (buf in realizes and not first): + unbound_st, st_var_vals = st.simplify().unbind() + var_vals.update(st_var_vals) + if assign_to is not None and buf is assign_to: + if not unbound_st.contiguous: + # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine + if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and + ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)): + raise RuntimeError(f"must be contiguous for assign {unbound_st}") + return LazyOp(BufferOps.LOAD, (), MemBuffer(0, buf.dtype, unbound_st)) + if buf not in inputs: inputs.append(buf) + return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st)) + + # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it + if buf.op is LoadOps.CONTIGUOUS: + assert first + return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False) + if buf.op is LoadOps.ASSIGN: + assert first + assert buf.srcs[1].base is buf.srcs[1], "assign must be to base" + assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}" + return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1]) + + # if it's a reduce, we have to change the shapetracker + if buf.op in ReduceOps: + assert st.contiguous, "ReduceOps late fusion must be contiguous" + st = ShapeTracker.from_shape(buf.srcs[0].shape) + + # otherwise we fuse it like normal + cache[(buf, st)] = ret = \ + LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False, assign_to) for x in buf.srcs), buf.arg) + return ret + +def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem: + inputs: List[LazyBuffer] = [] + var_vals: Dict[Variable, int] = out.st.var_vals.copy() + if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}: + op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs) + else: + output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape) + op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={}) + op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0])) + return ScheduleItem((op,), (out,), tuple(inputs), var_vals) + +# recursively search the entire graph for all LazyBuffers, insert realizes after expands +def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None], + simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False): + if buf in allbufs or buf.base.realized: return + if GRAPH: log_lazybuffer(buf, scheduled) + if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or + not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())): + if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32") + buf.dtype = dtypes.float32 # NOTE: this is what makes the dtype above not match + if buf.base != buf: + # realize all places where the buffer is expanded + if prod(buf.base.st.shape) < prod(buf.st.shape): + if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \ + prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]): + simple_pads.add(buf.base) + else: + realizes.add(buf.base) + return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children) + if buf.forced_realize: realizes.add(buf) + allbufs[buf] = None + if buf.op in LoadOps: realizes.add(buf.base) + if buf.op == LoadOps.COPY: + assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig" + realizes.add(buf.srcs[0].base) + for x in buf.srcs: + children[x.base][buf] = None + _recurse_lb(x, realizes, allbufs, simple_pads, children) + +UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2} +def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool: + if buf in realizes or buf.realized: return True + # NOTE: this broke to_image_idx and coder with JIT + if buf.op in UNSAFE_PAD_OPS: return False + return all(_is_padding_okay(x.base, realizes) for x in buf.srcs) + +def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]: + if seen is None: seen = set() + + # start by just realizing the buffers passed in + realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized]) + allbufs: Dict[LazyBuffer, None] = {} + simple_pads: Set[LazyBuffer] = set() + children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict) + for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True) + + # check if we have to realize pads + for p in simple_pads: + if not _is_padding_okay(p, realizes): + realizes.add(p) + + # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child) + reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {} + for r in allbufs.keys(): + if r != r.base or r.op not in ReduceOps or r in realizes: continue + + # follow the reduce down + child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st} + realized_children: Dict[LazyBuffer, ShapeTracker] = {} + forced_realize = False + can_chase = True + while not forced_realize and len(child_set): + next_child_set = {} + for tr,st in child_set.items(): + if tr in realizes: + realized_children[tr] = st + # can only have one output buffer + # can only reduce contiguous + # max one reduceop per kernel + if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r): + can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r + forced_realize = True + break + continue + for tr_next in children[tr].keys(): + if not tr_next.realized: + # max one reduceop per kernel + if tr_next.op in ReduceOps: + forced_realize = True + break + st_childs = dedup([s for s in tr_next.srcs if s.base == tr]) + if len(st_childs) > 1: + forced_realize = True + break + next_child_set[tr_next] = st + st_childs[0].st + child_set = next_child_set + if forced_realize: + tr = r + if can_chase: + # can chase this down to contiguous children + st = tr.st + while len(children[tr]) == 1: + tr_next = next(iter(children[tr].keys())) + st_childs = dedup([s for s in tr_next.srcs if s.base == tr]) + if len(st_childs) > 1: break + if st.size != st_childs[0].st.size: break + st = st + st_childs[0].st + if not st.contiguous or tr_next.op in ReduceOps: break + tr = tr_next + reduce_for_op[tr] = r + realizes.add(tr) + else: + assert len(realized_children) == 1 + reduce_for_op[next(iter(realized_children.keys()))] = r + + # preschedule all buffers in realizes + prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST} + assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None} + + # breadth first ordering + graph: DefaultDict[LazyBuffer,List[LazyBuffer]] = defaultdict(list) + in_degree: DefaultDict[LazyBuffer,int] = defaultdict(int) + for out, si in prescheduled.items(): + for x in si.inputs: + graph[x].append(out) + if x in assign_targets: + graph[out].append(assign_targets[x]) + in_degree[assign_targets[x]] += 1 + if x in prescheduled: in_degree[out] += 1 + + queue = deque(out for out in prescheduled if in_degree[out] == 0) + schedule: List[ScheduleItem] = [] + while queue: + buf = queue.popleft() + seen.add(buf) + schedule.append(prescheduled[buf]) + for x in graph[buf]: + in_degree[x] -= 1 + if in_degree[x] == 0: queue.append(x) + + # confirm everything was scheduled + assert len(prescheduled) == len(schedule), f"prescheduled {len(prescheduled)} but only scheduled {len(schedule)}" + return schedule diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 31959c68..4b7b8b58 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -14,7 +14,8 @@ from tinygrad.features.multi import MultiLazyBuffer from tinygrad.ops import LoadOps from tinygrad.device import Buffer, Device from tinygrad.shape.symbolic import sint -from tinygrad.engine.realize import run_schedule, create_schedule +from tinygrad.engine.realize import run_schedule +from tinygrad.engine.schedule import create_schedule # **** start with two base classes, Tensor and Function ****