ods/tinygrad_repo/tinygrad/realize.py

75 lines
4.3 KiB
Python

from typing import List, cast, Dict, Callable
import numpy as np
from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, Device, BufferOps
from tinygrad.graph import log_schedule_item, print_tree
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE
from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.features.image import fix_schedule_for_images
def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
    """Run every item of ``schedule`` front to back, realizing each item's output.

    Items are popped off the front of the list as they execute, so the list
    being processed is consumed in place. When ``IMAGE >= 2`` the list that
    gets consumed is whatever ``fix_schedule_for_images`` returns.

    Args:
        schedule: the schedule items to execute, in execution order.
        disable_logging: when true, ``log_schedule_item`` is not called.

    Raises:
        AssertionError: if an input of an item is not realized, if a LoadOps
            item has a malformed source, or if the realized output does not
            match the expected device buffer class or dtype.
    """
    # HACK: images may be unusable due to their shape
    if IMAGE >= 2:
        schedule = fix_schedule_for_images(schedule)

    # NOTE: items are popped one at a time on purpose -- a plain for loop over
    # the schedule is slow because nothing frees
    while schedule:
        si = schedule.pop(0)
        if not disable_logging:
            log_schedule_item(si)
        assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
        if DEBUG >= 3:
            print_tree(si.ast)

        if si.ast.op not in LoadOps:
            # regular AST: the output's device executes it and hands back the realized buffer
            si.out.realized = Device[si.out.device].exec_ast(
                si.ast,
                output=si.out,
                inputs=si.inputs,
                var_vals=si.var_vals,
                **si.out._device_extra_args(),
            )
        else:
            # confirm the LoadOps sources are contiguous MEM ops, numbered 1..n in order
            for i, s in enumerate(si.ast.src):
                assert (
                    isinstance(s, LazyOp)
                    and s.op == BufferOps.MEM
                    and s.arg.idx == i+1
                    and s.arg.st.contiguous
                ), f"bad LoadOps src {i}: {s}"
            # the handler sets si.out.realized itself
            realize_fn = LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)]
            realize_fn(si.out, *si.inputs)

        # the output is realized now: remove the op attribute from it and from each of its views
        del si.out.op
        for v in si.out.views:
            del v.op

        # sanity-check what was produced
        assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
        assert si.out.realized.dtype == si.out.dtype, "realized dtype is incorrect"
# *** zero op LoadOps ***
def _realize_empty(buffer: LazyBuffer) -> None:
    """Realize ``buffer`` by constructing a device buffer without supplying data.

    The buffer class of ``buffer.device`` is instantiated with
    ``prod(buffer.shape)`` elements of ``buffer.dtype`` and stored on
    ``buffer.realized``.

    Raises:
        AssertionError: if the shape is not made of plain ints (symbolic shape).
    """
    assert all_int(buffer.shape), "does not support symbolic shape"
    if DEBUG >= 2:
        print(f"*** empty {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")

    buffer_cls = Device[buffer.device].buffer
    numel = prod(buffer.shape)
    buffer.realized = buffer_cls(numel, buffer.dtype, **buffer._device_extra_args())
def _realize_rand(buffer: LazyBuffer) -> None:
    """Realize ``buffer`` with random values generated on the host with NumPy.

    A NumPy generator seeded with ``buffer.op.arg`` draws ``prod(buffer.shape)``
    float32 samples, which are converted to ``buffer.dtype.np`` and handed to
    the device buffer class via ``fromCPU``.

    Raises:
        AssertionError: if the shape is not made of plain ints (symbolic shape).
    """
    assert all_int(buffer.shape), "does not support symbolic shape"
    if DEBUG >= 2:
        print(f"*** rand {buffer.device} seed {buffer.op.arg:<10d} shape {str(buffer.shape):23s} dtype {buffer.dtype}")

    generator = np.random.default_rng(buffer.op.arg)
    from_cpu = Device[buffer.device].buffer.fromCPU
    # always sample as float32 first, then convert (without copying when the dtype already matches)
    samples = generator.random(size=prod(buffer.shape), dtype=np.float32)
    host_data = samples.astype(dtype=buffer.dtype.np, copy=False)
    buffer.realized = from_cpu(host_data, **buffer._device_extra_args())
# *** one op LoadOps ***
def _realize_from(buffer: LazyBuffer, src: LazyBuffer) -> None:
    """Realize ``buffer`` by copying the realized contents of ``src`` to ``buffer.device``.

    One of three paths is taken, checked in this order:

    1. ``src`` is realized as a ``RawDiskBuffer`` and the destination buffer
       class is a ``RawBufferMapped``: a destination buffer is created and the
       source is read directly into it with ``readinto``.
    2. Both sides are ``RawBufferTransfer`` and the ``P2P`` environment value
       is at least 1: the destination class's ``transfer`` is used.
    3. Otherwise: the source is brought to the host with ``toCPU`` and loaded
       into the destination with ``fromCPU``.

    Raises:
        AssertionError: on a size mismatch, if either side is not contiguous,
            or (disk path only) if the shape is symbolic.
    """
    assert src.realized.size == buffer.st.size(), f"size mismatch on FROM {src.realized.size} != {buffer.st.size()}"
    assert src.st.contiguous and buffer.st.contiguous, "all must be contiguous for from"
    if DEBUG >= 2:
        print(f"*** copy {buffer.device} <- {src.device} size {src.realized.size:<16d} shape {str(buffer.shape):23s} dtype {src.realized.dtype}")

    src_buf = src.realized

    # TODO: make this generic
    # path 1: disk source into a mapped destination
    if isinstance(src_buf, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
        assert all_int(buffer.shape), "does not support symbolic shape"
        buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())
        mapped_dest = cast(RawBufferMapped, buffer.realized)
        src_buf.readinto(mapped_dest._buffer())
        return

    # path 2: direct transfer, only when P2P is enabled through the environment
    if (
        isinstance(src_buf, RawBufferTransfer)
        and issubclass(Device[buffer.device].buffer, RawBufferTransfer)
        and getenv("P2P", 0) >= 1
    ):
        transfer_cls = cast(RawBufferTransfer, Device[buffer.device].buffer)
        buffer.realized = transfer_cls.transfer(src_buf, buffer.shape, buffer.dtype, **buffer._device_extra_args())
        return

    # path 3: round trip through the host
    # TODO: schedule this as FROM to go to CPU, and a FROM to go to device
    from_cpu = Device[buffer.device].buffer.fromCPU
    buffer.realized = from_cpu(src_buf.toCPU(), **buffer._device_extra_args())
# *** n op LoadOps ***
def _realize_custom(buffer: LazyBuffer, *inputs: LazyBuffer) -> None:
    """Realize ``buffer`` by calling the callable stored in ``buffer.op.arg``.

    The callable receives ``buffer`` followed by every entry of ``inputs``;
    whatever it returns is stored on ``buffer.realized``.
    """
    if DEBUG >= 2:
        print(f"*** custom {buffer.device} shape {str(buffer.shape):23s} dtype {buffer.dtype}")

    custom_fn = buffer.op.arg
    buffer.realized = custom_fn(buffer, *inputs)
# Maps each LoadOps member to the function that realizes it.
# Every handler is called as handler(out, *inputs) and sets out.realized itself.
LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
    # zero op LoadOps
    LoadOps.EMPTY: _realize_empty,
    LoadOps.RAND: _realize_rand,
    # one op LoadOps
    LoadOps.FROM: _realize_from,
    # n op LoadOps
    LoadOps.CUSTOM: _realize_custom,
}