split to schedule.py (#3949)

* split to schedule.py

* split
George Hotz 2024-03-26 21:02:46 -07:00 committed by GitHub
parent da07f31fd4
commit 68ca4d4276
23 changed files with 235 additions and 232 deletions
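
For downstream code, the practical change in this commit is the import path: create_schedule moves to tinygrad.engine.schedule, while run_schedule and lower_schedule_item stay in tinygrad.engine.realize. A minimal usage sketch of the new layout follows (the Tensor math, shapes, and the .lazydata/.numpy() calls are illustrative examples, not part of this diff):

from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule  # moved here by this commit
from tinygrad.engine.realize import run_schedule      # still exported from realize

a, b = Tensor.rand(4, 4), Tensor.rand(4, 4)
out = (a @ b).relu()
sched = create_schedule([out.lazydata])  # plan the kernels as ScheduleItems, nothing executes yet
run_schedule(sched)                      # execute the planned ScheduleItems
print(out.numpy())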

View File

@@ -73,7 +73,8 @@ assert out.as_buffer().cast('I')[0] == 5
print("******** third, the LazyBuffer ***********")
from tinygrad.lazy import LazyBuffer, LoadOps
-from tinygrad.engine.realize import run_schedule, create_schedule
+from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.schedule import create_schedule
# allocate some values + load in values
a = LazyBuffer.loadop(LoadOps.EMPTY, (1,), dtypes.int32, DEVICE)

View File

@@ -8,7 +8,7 @@ from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.helpers import ansilen, DEBUG, getenv
from tinygrad.shape.symbolic import sym_infer
from tinygrad.dtype import dtypes
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
if __name__ == "__main__":
  if getenv("HALF"):

View File

@@ -3,7 +3,7 @@ from tinygrad.ops import LoadOps
from tinygrad.codegen.linearizer import Linearizer
from test.external.fuzz_linearizer import run_linearizer
from tinygrad.codegen.kernel import Opt, OptOps
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
N = 17**3

View File

@@ -30,7 +30,7 @@ except ImportError:
import os
from tinygrad.tensor import Tensor
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
# define the compute
A = Tensor.rand(M, K, device="clang")

View File

@@ -16,7 +16,8 @@ from extra.onnx import get_run_onnx
from tinygrad import Tensor, Device, GlobalCounters, dtypes
from tinygrad.dtype import ImageDType
from tinygrad.helpers import partition, Context, fetch, getenv, GRAPH, DEBUG
-from tinygrad.engine.realize import run_schedule, lower_schedule_item, create_schedule
+from tinygrad.engine.realize import run_schedule, lower_schedule_item
+from tinygrad.engine.schedule import create_schedule
from tinygrad.ops import LoadOps, ScheduleItem
Device.DEFAULT = "GPU"

View File

@@ -4,7 +4,8 @@ from tinygrad.lazy import LazyBuffer
from tinygrad.ops import ReduceOps, GlobalCounters
from tinygrad.features.multi import MultiLazyBuffer, all_reduce
from tinygrad.engine.jit import TinyJit
-from tinygrad.engine.realize import create_schedule, run_schedule
+from tinygrad.engine.schedule import create_schedule
+from tinygrad.engine.realize import run_schedule
from tinygrad.helpers import getenv, Context, RING
from typing import List, Union

View File

@@ -2,7 +2,7 @@ import time, unittest
from tinygrad.runtime.driver.hip_comgr import compile_hip
from tinygrad import Tensor
from tinygrad.device import Device
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
from tinygrad.codegen.linearizer import Linearizer
class TestHIPCompileSpeed(unittest.TestCase):

View File

@@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.renderer.cstyle import OpenCLRenderer
from tinygrad.features.graph import graph_uops
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
from tinygrad.nn import Conv2d
class TestUopsGraph(unittest.TestCase):

View File

@@ -3,7 +3,7 @@ import unittest
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.nn import Conv2d
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
class TestConvShapetracker(unittest.TestCase):
  def test_conv_3x3_one_view(self):

View File

@@ -6,7 +6,7 @@ import numpy as np
from hypothesis import given, strategies as strat, settings
from tinygrad.dtype import DType
from tinygrad.helpers import CI, getenv
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
from tinygrad.ops import UnaryOps, get_lazyop_info
from test.helpers import is_dtype_supported

View File

@@ -2,7 +2,8 @@ import unittest
import time
import numpy as np
from tinygrad import Tensor, dtypes
-from tinygrad.engine.realize import run_schedule, create_schedule, lower_schedule_item
+from tinygrad.engine.schedule import create_schedule
+from tinygrad.engine.realize import run_schedule, lower_schedule_item
class TestFusionOp(unittest.TestCase):
  def test_contiguous_add(self):

View File

@@ -3,7 +3,7 @@ import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.lazy import LazyBuffer, ReduceOps
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
class TestLazyBuffer(unittest.TestCase):
  def test_fromcpu_shape_tracker(self):

View File

@@ -1,6 +1,6 @@
import unittest
from tinygrad.tensor import Tensor
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
# stuff needed to unpack a kernel
# ruff: noqa: F401

View File

@@ -10,7 +10,8 @@ from tinygrad.shape.view import View
from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node
from tinygrad.tensor import Tensor
from tinygrad.engine.jit import CacheCollector
-from tinygrad.engine.realize import create_schedule, run_schedule
+from tinygrad.engine.schedule import create_schedule
+from tinygrad.engine.realize import run_schedule
from tinygrad.helpers import prod, Context
from tinygrad.dtype import DType, dtypes
from tinygrad.codegen.uops import UOpGraph

View File

@@ -5,7 +5,7 @@ from tinygrad.device import BufferCopy
from tinygrad.ops import LoadOps, ReduceOps
from tinygrad.helpers import CI, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
from tinygrad.features.multi import all_reduce, MultiLazyBuffer
from random import randint
import numpy as np

View File

@@ -9,7 +9,7 @@ from tinygrad.ops import LoadOps
from tinygrad.helpers import DEBUG, GRAPH
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.graph import print_tree, realized_lazybuffer
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
from tinygrad import nn, dtypes
def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):

View File

@@ -1,7 +1,7 @@
import unittest
from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
from tinygrad.features.search import time_linearizer, bufs_from_lin
from tinygrad.device import Device, Buffer
from tinygrad.ops import LoadOps

View File

@@ -5,7 +5,7 @@ from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device, CompiledASTRunner
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.codegen.uops import exec_alu, UOpGraph
from test.helpers import is_dtype_supported

View File

@@ -1,6 +1,7 @@
import unittest
from tinygrad import Tensor
-from tinygrad.engine.realize import create_schedule, lower_schedule_item
+from tinygrad.engine.schedule import create_schedule
+from tinygrad.engine.realize import lower_schedule_item
# TODO: can copy this in here when we remove it
#from tinygrad.ops import get_lazyop_info

View File

@@ -3,7 +3,7 @@ from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG
from tinygrad.ops import LoadOps
from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.engine.realize import create_schedule
+from tinygrad.engine.schedule import create_schedule
class TestWinograd(unittest.TestCase):
  def setUp(self):

View File

@@ -1,16 +1,9 @@
-import sys
-from collections import defaultdict, deque
-from typing import List, Dict, Optional, cast, Set, DefaultDict
-from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
+from typing import List, Dict, Optional, cast
+from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats
-from tinygrad.features.graph import realized_lazybuffer, log_lazybuffer
-from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG, prod, dedup, all_int
+from tinygrad.features.graph import realized_lazybuffer
+from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG
from tinygrad.shape.symbolic import Variable
-from tinygrad.dtype import ImageDType, dtypes
-from tinygrad.lazy import LazyBuffer
-from tinygrad.shape.shapetracker import ShapeTracker
-# *** schedule running ***
class CustomOp(JITRunner):
  def __init__(self, fxn):
@@ -65,203 +58,3 @@ def run_schedule(schedule:List[ScheduleItem]):
    elif (out:=si.outputs[0]).size > 0: update_stats(colored(f"empty {out.st.size:10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device)
    if GRAPH:
      for out in si.outputs: realized_lazybuffer(out, GlobalCounters.kernel_count)
# *** schedule creation ***
# creation can recurse a lot
sys.setrecursionlimit(10000)
# recursively create a lazyop
def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None) -> LazyOp:
if (buf, st) in cache: return cache[(buf, st)]
if buf != buf.base:
st = buf.st + st
buf = buf.base
# all buffers here are base now
assert buf.op is not None
# consts are always fused and generated
if buf.op is LoadOps.CONST:
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st))
# if we aren't fusing it, it's a load and we add it to the inputs
if buf.realized or (buf in realizes and not first):
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
if assign_to is not None and buf is assign_to:
if not unbound_st.contiguous:
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
raise RuntimeError(f"must be contiguous for assign {unbound_st}")
return LazyOp(BufferOps.LOAD, (), MemBuffer(0, buf.dtype, unbound_st))
if buf not in inputs: inputs.append(buf)
return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
# if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
if buf.op is LoadOps.CONTIGUOUS:
assert first
return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
if buf.op is LoadOps.ASSIGN:
assert first
assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1])
# if it's a reduce, we have to change the shapetracker
if buf.op in ReduceOps:
assert st.contiguous, "ReduceOps late fusion must be contiguous"
st = ShapeTracker.from_shape(buf.srcs[0].shape)
# otherwise we fuse it like normal
cache[(buf, st)] = ret = \
LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False, assign_to) for x in buf.srcs), buf.arg)
return ret
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem:
inputs: List[LazyBuffer] = []
var_vals: Dict[Variable, int] = out.st.var_vals.copy()
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
else:
output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))
return ScheduleItem((op,), (out,), tuple(inputs), var_vals)
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
if buf in allbufs or buf.base.realized: return
if GRAPH: log_lazybuffer(buf, scheduled)
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
buf.dtype = dtypes.float32 # NOTE: this is what makes the dtype above not match
if buf.base != buf:
# realize all places where the buffer is expanded
if prod(buf.base.st.shape) < prod(buf.st.shape):
if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
simple_pads.add(buf.base)
else:
realizes.add(buf.base)
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
if buf.forced_realize: realizes.add(buf)
allbufs[buf] = None
if buf.op in LoadOps: realizes.add(buf.base)
if buf.op == LoadOps.COPY:
assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
realizes.add(buf.srcs[0].base)
for x in buf.srcs:
children[x.base][buf] = None
_recurse_lb(x, realizes, allbufs, simple_pads, children)
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
if buf in realizes or buf.realized: return True
# NOTE: this broke to_image_idx and coder with JIT
if buf.op in UNSAFE_PAD_OPS: return False
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
if seen is None: seen = set()
# start by just realizing the buffers passed in
realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
allbufs: Dict[LazyBuffer, None] = {}
simple_pads: Set[LazyBuffer] = set()
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
# check if we have to realize pads
for p in simple_pads:
if not _is_padding_okay(p, realizes):
realizes.add(p)
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
for r in allbufs.keys():
if r != r.base or r.op not in ReduceOps or r in realizes: continue
# follow the reduce down
child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
realized_children: Dict[LazyBuffer, ShapeTracker] = {}
forced_realize = False
can_chase = True
while not forced_realize and len(child_set):
next_child_set = {}
for tr,st in child_set.items():
if tr in realizes:
realized_children[tr] = st
# can only have one output buffer
# can only reduce contiguous
# max one reduceop per kernel
if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
forced_realize = True
break
continue
for tr_next in children[tr].keys():
if not tr_next.realized:
# max one reduceop per kernel
if tr_next.op in ReduceOps:
forced_realize = True
break
st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
if len(st_childs) > 1:
forced_realize = True
break
next_child_set[tr_next] = st + st_childs[0].st
child_set = next_child_set
if forced_realize:
tr = r
if can_chase:
# can chase this down to contiguous children
st = tr.st
while len(children[tr]) == 1:
tr_next = next(iter(children[tr].keys()))
st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
if len(st_childs) > 1: break
if st.size != st_childs[0].st.size: break
st = st + st_childs[0].st
if not st.contiguous or tr_next.op in ReduceOps: break
tr = tr_next
reduce_for_op[tr] = r
realizes.add(tr)
else:
assert len(realized_children) == 1
reduce_for_op[next(iter(realized_children.keys()))] = r
# preschedule all buffers in realizes
prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST}
assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}
# breadth first ordering
graph: DefaultDict[LazyBuffer,List[LazyBuffer]] = defaultdict(list)
in_degree: DefaultDict[LazyBuffer,int] = defaultdict(int)
for out, si in prescheduled.items():
for x in si.inputs:
graph[x].append(out)
if x in assign_targets:
graph[out].append(assign_targets[x])
in_degree[assign_targets[x]] += 1
if x in prescheduled: in_degree[out] += 1
queue = deque(out for out in prescheduled if in_degree[out] == 0)
schedule: List[ScheduleItem] = []
while queue:
buf = queue.popleft()
seen.add(buf)
schedule.append(prescheduled[buf])
for x in graph[buf]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
# confirm everything was scheduled
assert len(prescheduled) == len(schedule), f"prescheduled {len(prescheduled)} but only scheduled {len(schedule)}"
return schedule

tinygrad/engine/schedule.py (new file, 203 lines added)
View File

@@ -0,0 +1,203 @@
from collections import defaultdict, deque
from typing import List, Dict, Optional, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
from tinygrad.features.graph import log_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, prod, dedup, all_int
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
# recursively create a lazyop
def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None) -> LazyOp:
if (buf, st) in cache: return cache[(buf, st)]
if buf != buf.base:
st = buf.st + st
buf = buf.base
# all buffers here are base now
assert buf.op is not None
# consts are always fused and generated
if buf.op is LoadOps.CONST:
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st))
# if we aren't fusing it, it's a load and we add it to the inputs
if buf.realized or (buf in realizes and not first):
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
if assign_to is not None and buf is assign_to:
if not unbound_st.contiguous:
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
raise RuntimeError(f"must be contiguous for assign {unbound_st}")
return LazyOp(BufferOps.LOAD, (), MemBuffer(0, buf.dtype, unbound_st))
if buf not in inputs: inputs.append(buf)
return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
# if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
if buf.op is LoadOps.CONTIGUOUS:
assert first
return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
if buf.op is LoadOps.ASSIGN:
assert first
assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1])
# if it's a reduce, we have to change the shapetracker
if buf.op in ReduceOps:
assert st.contiguous, "ReduceOps late fusion must be contiguous"
st = ShapeTracker.from_shape(buf.srcs[0].shape)
# otherwise we fuse it like normal
cache[(buf, st)] = ret = \
LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False, assign_to) for x in buf.srcs), buf.arg)
return ret
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem:
inputs: List[LazyBuffer] = []
var_vals: Dict[Variable, int] = out.st.var_vals.copy()
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
else:
output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))
return ScheduleItem((op,), (out,), tuple(inputs), var_vals)
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
if buf in allbufs or buf.base.realized: return
if GRAPH: log_lazybuffer(buf, scheduled)
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
buf.dtype = dtypes.float32 # NOTE: this is what makes the dtype above not match
if buf.base != buf:
# realize all places where the buffer is expanded
if prod(buf.base.st.shape) < prod(buf.st.shape):
if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
simple_pads.add(buf.base)
else:
realizes.add(buf.base)
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
if buf.forced_realize: realizes.add(buf)
allbufs[buf] = None
if buf.op in LoadOps: realizes.add(buf.base)
if buf.op == LoadOps.COPY:
assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
realizes.add(buf.srcs[0].base)
for x in buf.srcs:
children[x.base][buf] = None
_recurse_lb(x, realizes, allbufs, simple_pads, children)
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
if buf in realizes or buf.realized: return True
# NOTE: this broke to_image_idx and coder with JIT
if buf.op in UNSAFE_PAD_OPS: return False
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
if seen is None: seen = set()
# start by just realizing the buffers passed in
realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
allbufs: Dict[LazyBuffer, None] = {}
simple_pads: Set[LazyBuffer] = set()
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
# check if we have to realize pads
for p in simple_pads:
if not _is_padding_okay(p, realizes):
realizes.add(p)
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
for r in allbufs.keys():
if r != r.base or r.op not in ReduceOps or r in realizes: continue
# follow the reduce down
child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
realized_children: Dict[LazyBuffer, ShapeTracker] = {}
forced_realize = False
can_chase = True
while not forced_realize and len(child_set):
next_child_set = {}
for tr,st in child_set.items():
if tr in realizes:
realized_children[tr] = st
# can only have one output buffer
# can only reduce contiguous
# max one reduceop per kernel
if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
forced_realize = True
break
continue
for tr_next in children[tr].keys():
if not tr_next.realized:
# max one reduceop per kernel
if tr_next.op in ReduceOps:
forced_realize = True
break
st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
if len(st_childs) > 1:
forced_realize = True
break
next_child_set[tr_next] = st + st_childs[0].st
child_set = next_child_set
if forced_realize:
tr = r
if can_chase:
# can chase this down to contiguous children
st = tr.st
while len(children[tr]) == 1:
tr_next = next(iter(children[tr].keys()))
st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
if len(st_childs) > 1: break
if st.size != st_childs[0].st.size: break
st = st + st_childs[0].st
if not st.contiguous or tr_next.op in ReduceOps: break
tr = tr_next
reduce_for_op[tr] = r
realizes.add(tr)
else:
assert len(realized_children) == 1
reduce_for_op[next(iter(realized_children.keys()))] = r
# preschedule all buffers in realizes
prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST}
assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}
# breadth first ordering
graph: DefaultDict[LazyBuffer,List[LazyBuffer]] = defaultdict(list)
in_degree: DefaultDict[LazyBuffer,int] = defaultdict(int)
for out, si in prescheduled.items():
for x in si.inputs:
graph[x].append(out)
if x in assign_targets:
graph[out].append(assign_targets[x])
in_degree[assign_targets[x]] += 1
if x in prescheduled: in_degree[out] += 1
queue = deque(out for out in prescheduled if in_degree[out] == 0)
schedule: List[ScheduleItem] = []
while queue:
buf = queue.popleft()
seen.add(buf)
schedule.append(prescheduled[buf])
for x in graph[buf]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)
# confirm everything was scheduled
assert len(prescheduled) == len(schedule), f"prescheduled {len(prescheduled)} but only scheduled {len(schedule)}"
return schedule

View File

@@ -14,7 +14,8 @@ from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.ops import LoadOps
from tinygrad.device import Buffer, Device
from tinygrad.shape.symbolic import sint
-from tinygrad.engine.realize import run_schedule, create_schedule
+from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.schedule import create_schedule
# **** start with two base classes, Tensor and Function ****