lower schedule (#2559)

* lower schedule

* remove RAND, and don't put load in the JIT yet

* better fix for that test
George Hotz 2023-12-01 19:17:46 -08:00 committed by GitHub
parent 077567f62d
commit 6733425095
9 changed files with 105 additions and 93 deletions

View File

@@ -98,13 +98,13 @@ class LazyOp:
  src: Tuple[Union[LazyOp, LazyBuffer], ...] # the sources
  arg: Optional[Any] = None # and an optional static argument

-# there's currently 27 Ops you have to implement for an accelerator.
+# there's currently 26 Ops you have to implement for an accelerator.
class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto()
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto(); MAX = auto()
class ReduceOps(Enum): SUM = auto(); MAX = auto()
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto()
class TernaryOps(Enum): MULACC = auto(); WHERE = auto()
-class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
# NOTE: if you have a CompiledBuffer(DeviceBuffer)
# you do not need to implement the MovementOps
# as they are handled by the ShapeTracker(in tinygrad/shape/shapetracker.py, code 7/10)
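
The updated comment is easy to sanity-check against the enums above; a quick tally using nothing beyond the listed classes:

```python
# 5 unary + 6 binary + 2 reduce + 6 movement + 2 ternary + 5 load = 26
counts = {"UnaryOps": 5, "BinaryOps": 6, "ReduceOps": 2, "MovementOps": 6, "TernaryOps": 2, "LoadOps": 5}
assert sum(counts.values()) == 26  # was 27 before RAND was dropped from LoadOps
```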

View File

@@ -11,7 +11,7 @@
unary_op (NOOP, EXP2, LOG2, CAST, SIN, SQRT) # A -> A
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
binary_op (ADD, SUB, MUL, DIV, CMPEQ, MAX) # A + A -> A (all the same size)
movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE) # A -> B (different size)
-load_op (EMPTY, RAND, CONST, FROM, CONTIGUOUS, CUSTOM) # -> A (initialize data on device)
+load_op (EMPTY, CONST, FROM, CONTIGUOUS, CUSTOM) # -> A (initialize data on device)
ternary_op (WHERE) # A, A, A -> A
ternary_op [[optional]] (MULACC) # A * A -> B
```
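
The shape conventions in those comments map directly onto the Tensor front end; a small illustration using the standard tinygrad API, nothing specific to this commit:

```python
from tinygrad.tensor import Tensor

a = Tensor.ones(4, 4)
b = a + a                        # binary_op: A + A -> A, same shape
s = a.sum(axis=1, keepdim=True)  # reduce_op: A -> B, the reduced dim becomes 1
assert b.shape == (4, 4) and s.shape == (4, 1)
```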

View File

@@ -88,8 +88,9 @@ class TestInferenceMinKernels(unittest.TestCase):
    args_tiny = {"dim": 512, "hidden_dim": 1024, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
    model = Transformer(**args_tiny)
    for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
+    inp = Tensor([[1,2,3,4]])
    with CLCache(100):
-      model(Tensor([[1,2,3,4]]), 0).realize()
+      model(inp, 0).realize()

@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestOptBinOp(unittest.TestCase):

View File

@@ -5,8 +5,8 @@ from tinygrad.nn.state import get_parameters

# for speed
def derandomize(x):
  if isinstance(x, LazyOp):
-    new_op = LoadOps.EMPTY if x.op == LoadOps.RAND else x.op
-    return LazyOp(new_op, tuple([derandomize(s) for s in x.src]), None if x.op == LoadOps.RAND else x.arg)
+    new_op = LoadOps.EMPTY if x.op == LoadOps.CUSTOM else x.op
+    return LazyOp(new_op, tuple([derandomize(s) for s in x.src]), None if x.op == LoadOps.CUSTOM else x.arg)
  x.op = derandomize(x.op)
  return x
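
With RAND gone, the randomness a test wants to strip now arrives as a LoadOps.CUSTOM op (see the tensor.py change below), so derandomize rewrites CUSTOM to EMPTY and drops the RNG callback held in arg. A minimal sketch of the behavior; the LazyOp here is constructed by hand purely for illustration:

```python
from tinygrad.ops import LazyOp, LoadOps

ast = LazyOp(LoadOps.CUSTOM, (), arg=lambda out: None)  # stand-in for custom_random
stripped = derandomize(ast)
assert stripped.op is LoadOps.EMPTY and stripped.arg is None
```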

View File

@@ -33,6 +33,31 @@ class _Device:
    return "CPU"
Device = _Device()

+# **************** base Runner + helpers ****************
+
+class JITRunner:
+  def __init__(self):
+    self.op_estimate, self.mem_estimate = 0, 0
+  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
+    var_vals = var_vals if var_vals is not None else {}
+    from tinygrad.jit import CacheCollector
+    et = self(rawbufs, var_vals)
+    CacheCollector.add(self, rawbufs, var_vals)
+    return et
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+    raise NotImplementedError("override this")
+
+def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None):
+  if var_vals is None: var_vals = {}
+  op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
+  if DEBUG >= 2:
+    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+          (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
+  GlobalCounters.kernel_count += num_kernels
+  GlobalCounters.global_ops += op_estimate
+  GlobalCounters.global_mem += mem_estimate
+  if et is not None: GlobalCounters.time_sum_s += et
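
Every executable step (compiled kernel, buffer copy, or custom function) now shares this JITRunner interface: __call__ does the work, and exec additionally registers the call with the CacheCollector so the JIT can capture it. A minimal hypothetical subclass, not part of the diff, just to show the contract:

```python
# hypothetical: a do-nothing runner that illustrates the JITRunner contract
class NoopRunner(JITRunner):
  def __call__(self, rawbufs, var_vals, wait=False, jit=False):
    return 0.0  # runners may return an elapsed time in seconds

NoopRunner().exec([])  # runs, and is recorded by CacheCollector if a JIT capture is active
```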
# **************** Buffer / Allocator ****************
class Buffer:
@@ -47,27 +72,6 @@ class Buffer:
    if self.device == Device.DEFAULT: GlobalCounters.mem_used -= self.size * self.dtype.itemsize
    self.allocator.free(self._buf, self.size, self.dtype)
  def __repr__(self): return f"<buf device:{self.device} size:{self.size}>"
-  def copy_(self, src:Buffer):
-    assert self.size == src.size and self.dtype == src.dtype, "buffer copy size/dtype mismatch"
-    if hasattr(self.allocator, 'transfer') and type(self.allocator) is type(src.allocator):
-      # fast path, used on HIP between GPUs
-      self.allocator.transfer(self._buf, src._buf, self.size*self.dtype.itemsize)
-      return
-    if getenv("FROM_BUFFER") and hasattr(self.allocator, 'from_buffer') and hasattr(self.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
-      # fast path, used on Metal in OS X Sonoma
-      # NOTE: this is *only* faster if the pages from disk are already loaded into memory
-      fb = self.allocator.from_buffer(src.allocator.as_buffer(src._buf))
-      if fb:
-        self.allocator.transfer(self._buf, fb, self.size*self.dtype.itemsize)
-        return
-    if hasattr(self.allocator, 'as_buffer'):
-      # fast(ish) path, uses readinto in diskbuffers
-      src.allocator.copyout(self.allocator.as_buffer(self._buf), src._buf)
-    elif hasattr(src.allocator, 'as_buffer'):
-      self.allocator.copyin(self._buf, src.allocator.as_buffer(src._buf))
-    else:
-      # slow path, allocates a CPU buffer
-      self.copyin(src.toCPU().data)
  def copyin(self, mv:memoryview):
    mv = flat_mv(mv)
    assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
@@ -82,6 +86,33 @@ class Buffer:
    if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
    return ret

+class _BufferCopy(JITRunner):
+  # TODO: make wait work
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
+    dest, src = rawbufs
+    assert dest.size == src.size and dest.dtype == src.dtype, "buffer copy size/dtype mismatch"
+    if DEBUG >= 2: print(f"*** copy {dest.device} <- {src.device} size {dest.size:<16d} dtype {dest.dtype}")
+    if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator):
+      # fast path, used on HIP between GPUs
+      dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize)
+      return
+    if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
+      # fast path, used on Metal in OS X Sonoma
+      # NOTE: this is *only* faster if the pages from disk are already loaded into memory
+      fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
+      if fb:
+        dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize)
+        return
+    if hasattr(dest.allocator, 'as_buffer'):
+      # fast(ish) path, uses readinto in diskbuffers
+      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
+    elif hasattr(src.allocator, 'as_buffer'):
+      dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
+    else:
+      # slow path, allocates a CPU buffer
+      dest.copyin(src.toCPU().data)
+BufferCopy = _BufferCopy()
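
These are the same copy heuristics as the removed Buffer.copy_, now packaged as a JITRunner so device-to-device copies are schedulable (and JIT-capturable) like any kernel. A hedged sketch of driving it directly; note rawbufs order is [dest, src], matching what lower_schedule passes for LoadOps.FROM:

```python
from tinygrad.device import Buffer, BufferCopy
from tinygrad.helpers import dtypes  # dtypes lives in helpers at this point

# "CPU" is the default fallback device here; any registered device name works
src = Buffer("CPU", 4, dtypes.float32)
dst = Buffer("CPU", 4, dtypes.float32)
src.copyin(memoryview(bytearray(16)))  # 4 float32s of zeros
BufferCopy.exec([dst, src])            # picks the fastest available copy path
```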
# TODO: size, dest, src are the same type. can we enforce this?
class Allocator:
  def alloc(self, size:int, dtype:DType):
@@ -115,31 +146,6 @@ class _MallocAllocator(LRUAllocator):
  def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
MallocAllocator = _MallocAllocator()

-# **************** base Runner + helpers ****************
-
-class JITRunner:
-  def __init__(self):
-    self.op_estimate, self.mem_estimate = 0, 0
-  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
-    var_vals = var_vals if var_vals is not None else {}
-    from tinygrad.jit import CacheCollector
-    et = self(rawbufs, var_vals)
-    CacheCollector.add(self, rawbufs, var_vals)
-    return et
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
-    raise NotImplementedError("override this")
-
-def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None):
-  if var_vals is None: var_vals = {}
-  op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
-  if DEBUG >= 2:
-    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-          (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-  GlobalCounters.kernel_count += num_kernels
-  GlobalCounters.global_ops += op_estimate
-  GlobalCounters.global_mem += mem_estimate
-  if et is not None: GlobalCounters.time_sum_s += et
# **************** for Interpreted Devices ****************
class InterpretedASTRunner(JITRunner):

View File

@@ -86,7 +86,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dty
  if 0 in st.shape: return LazyBuffer(device, ShapeTracker.from_shape(st.shape), LoadOps, LazyOp(LoadOps.CONST, tuple(), 0.0), dtype)

  # fromcpu aren't cached
-  if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)
+  if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.CUSTOM, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)

  # wop is the deduping key. i feel this used to compare more deeply
  wop = (device, dtype, optype, ref(op), ref(base) if base else None)
@@ -166,8 +166,8 @@ class LazyBuffer:
    op, base_bufs = _replace_bufferops(op)

-    # add the store
    if op.op not in LoadOps:
+      # add the store
      info = get_lazyop_info(op)
      assert info.dtype == self.dtype or isinstance(self.dtype, ImageDType), f"dtype mismatch {info.dtype=} != {self.dtype=}"
@@ -179,6 +179,9 @@ class LazyBuffer:
      # TODO: why doesn't this match?
      #assert info.shape == self.shape, f"shape mismatch {info.shape=} != {self.shape=}"
      op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, self.dtype, ShapeTracker.from_shape(info.shape)))
+    else:
+      # check loadop validity of bufferops
+      for i,s in enumerate(op.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
    return ret + [ScheduleItem(op, self, tuple(base_bufs), {k:var_vals[k] for k in vars_from_ast(op)})]
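
So every compute ScheduleItem leaves here with a BufferOps.STORE at the root (output index 0) and LOADs as leaves, while LoadOps items keep their op at the root; the LoadOps-src check moved up here from run_schedule. A hedged peek at the result, assuming LazyBuffer.schedule() as the tests of this commit use it:

```python
from tinygrad.tensor import Tensor
from tinygrad.ops import BufferOps

out = (Tensor.empty(4) + Tensor.empty(4)).lazydata
si = out.schedule()[-1]                # the last item is the add kernel
assert si.ast.op is BufferOps.STORE    # compute ASTs are rooted at the output store
assert si.ast.arg.idx == 0             # ...which writes buffer index 0
```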

View File

@@ -17,7 +17,7 @@
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
# Ops below this line are not allowed in ASTs
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
-class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]

View File

@@ -1,17 +1,32 @@
-from typing import List, cast, Dict, Callable
-import numpy as np
-from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, BufferOps
-from tinygrad.device import Device, Buffer
+from typing import List, Dict, Tuple, Generator, Optional
+from tinygrad.ops import ScheduleItem, LoadOps, BufferOps
+from tinygrad.device import Device, Buffer, BufferCopy, JITRunner
from tinygrad.graph import log_schedule_item, print_tree
-from tinygrad.helpers import DEBUG, prod
+from tinygrad.helpers import prod
+from tinygrad.shape.symbolic import Variable

-def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
-  # NOTE: if you for loop the schedule it's slow because nothing frees
+class CustomOp(JITRunner):
+  def __init__(self, fxn):
+    self.fxn = fxn
+    super().__init__()
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
+
+def lower_schedule(schedule:List[ScheduleItem]) -> Generator[Tuple[ScheduleItem, Optional[JITRunner]], None, None]:
  while len(schedule):
    si = schedule.pop(0)
+    assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.FROM, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
+    if si.ast.op is LoadOps.EMPTY: yield si, None
+    elif si.ast.op is LoadOps.FROM: yield si, BufferCopy
+    elif si.ast.op is LoadOps.CUSTOM: yield si, CustomOp(si.ast.arg)
+    else: yield si, Device[si.out.device].get_runner(si.ast)
+    del si.out.op
+    for v in si.out.views: del v.op
+
+def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
+  for si, fxn in lower_schedule(schedule):
    if not disable_logging: log_schedule_item(si)
    assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
-    assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.FROM, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
    # check if we can reuse the output buffer
    # if it's aliased, don't use it
    # TODO: this is pretty wrong actually, who knows where else this buffer is used?
@@ -24,35 +39,15 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
          if any(not x.arg.st.contiguous for x in si.ast.get_lazyops() if x.op == BufferOps.LOAD and x.arg.idx == i+1):
            si.out.output_buffer = None
            break

    # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
    si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
      Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype)
-    if si.ast.op in LoadOps:
-      if DEBUG >= 2: print(f"*** {si.ast.op:>15s} {f'{si.out.device} <- {si.inputs[0].device}' if si.ast.op is LoadOps.FROM else si.out.device:25s} sz {si.out.realized.size:5d} shape {si.out.shape} dtype {si.out.dtype} arg {si.ast.arg}")
-      # confirm the LoadOps are contiguous and in order
-      for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
-      kwargs = {"arg": si.ast.arg} if si.ast.arg is not None else {}
-      LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out.realized, *[x.realized for x in si.inputs], **kwargs)
-    else:
-      Device[si.out.device].get_runner(si.ast).exec([si.out.realized] + [x.realized for x in si.inputs], si.var_vals)
-    del si.out.op
-    for v in si.out.views: del v.op
-    #assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
-    assert si.out.realized.dtype == si.out.dtype, f"realized dtype is incorrect, {si.out.realized.dtype} != {si.out.dtype}"
-
-# *** LoadOps implementation ***
+    # get all the buffers
+    rawbufs = [si.out.realized] + [x.realized for x in si.inputs]
-
-# TODO: remove this and write the RNG in tinygrad
-def _realize_rand(buffer: Buffer, arg) -> None:
-  rng = np.random.default_rng(arg)
-  rng_np_buffer = rng.random(size=buffer.size, dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False)
-  buffer.copyin(rng_np_buffer.data)
-
-def _realize_custom(*buffers: Buffer, arg) -> None: arg(*buffers)
-
-LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
-  LoadOps.EMPTY: lambda x: None,
-  LoadOps.RAND: _realize_rand,
-  LoadOps.FROM: Buffer.copy_,
-  LoadOps.CUSTOM: _realize_custom
-}
+    # run the function and put in JIT
+    if fxn: fxn.exec(rawbufs, si.var_vals)
+    assert si.out.realized.device == si.out.device, f"realized device is incorrect, {si.out.realized.device=} != {si.out.device=}"
+    assert si.out.realized.dtype == si.out.dtype, f"realized dtype is incorrect, {si.out.realized.dtype=} != {si.out.dtype=}"
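
The split means a schedule can now be inspected without being executed: lower_schedule picks the runner for each item (None for EMPTY, BufferCopy for FROM, a CustomOp for CUSTOM, a device runner otherwise), while run_schedule only allocates outputs and calls exec. A hedged sketch of peeking at the lowering, assuming lazydata.schedule() as this commit's tests use it:

```python
from tinygrad.tensor import Tensor
from tinygrad.realize import lower_schedule

t = Tensor.rand(16) + 1
for si, fxn in lower_schedule(t.lazydata.schedule()):
  # e.g. the LoadOps.CUSTOM item -> CustomOp (the RNG), the add -> a device runner
  print(si.ast.op, None if fxn is None else type(fxn).__name__)
```

Note the compute item prints BufferOps.STORE as its root op, per the lazy.py change above.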

View File

@@ -10,7 +10,7 @@ import numpy as np
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int, round_up
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import LoadOps
-from tinygrad.device import Device
+from tinygrad.device import Device, Buffer
from tinygrad.shape.symbolic import sint
from tinygrad.realize import run_schedule
@@ -156,8 +156,7 @@ class Tensor:
  @staticmethod
  def rand(*shape, **kwargs):
-    Tensor._seed += 1
-    return Tensor._loadop(LoadOps.RAND, prod((shape:=argfix(*shape))), arg=Tensor._seed, **kwargs).reshape(shape)
+    return Tensor._loadop(LoadOps.CUSTOM, prod((shape:=argfix(*shape))), arg=custom_random, **kwargs).reshape(shape)
# ***** creation helper functions *****
@@ -828,3 +827,11 @@ if IMAGE:
  from tinygrad.features.image import image_conv2d, image_dot
  setattr(Tensor, "conv2d", image_conv2d)
  setattr(Tensor, "dot", image_dot)
+
+# TODO: remove the custom op and replace with threefry
+def custom_random(out:Buffer):
+  Tensor._seed += 1
+  if DEBUG >= 2: print(f"*** rand {out.device} seed {Tensor._seed} size {out.size:<16d} dtype {out.dtype}")
+  rng = np.random.default_rng(Tensor._seed)
+  rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
+  out.copyin(rng_np_buffer.data)
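
For users nothing changes: rand still bumps the class-level seed and fills the buffer from numpy, it just does so via a CUSTOM op at realize time instead of a dedicated RAND load. A small sketch of the (unchanged) determinism, using only the public API:

```python
from tinygrad.tensor import Tensor

Tensor._seed = 1337
a = Tensor.rand(3, 3).realize()   # custom_random runs here, seeding numpy with 1338
Tensor._seed = 1337
b = Tensor.rand(3, 3).realize()   # same seed -> same numpy stream
assert (a.numpy() == b.numpy()).all()
```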