diff --git a/docs/abstractions.py b/docs/abstractions.py
index 6d66104b..5892a28f 100644
--- a/docs/abstractions.py
+++ b/docs/abstractions.py
@@ -98,13 +98,13 @@ class LazyOp:
   src: Tuple[Union[LazyOp, LazyBuffer], ...]   # the sources
   arg: Optional[Any] = None                    # and an optional static argument
 
-# there's currently 27 Ops you have to implement for an accelerator.
+# there's currently 26 Ops you have to implement for an accelerator.
 class UnaryOps(Enum):    EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto()
 class BinaryOps(Enum):   ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); CMPLT = auto(); MAX = auto()
 class ReduceOps(Enum):   SUM = auto(); MAX = auto()
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto()
 class TernaryOps(Enum):  MULACC = auto(); WHERE = auto()
-class LoadOps(Enum):     EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
+class LoadOps(Enum):     EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto()
 
 # NOTE: if you have a CompiledBuffer(DeviceBuffer)
 # you do not need to implement the MovementOps
 # as they are handled by the ShapeTracker(in tinygrad/shape/shapetracker.py, code 7/10)
diff --git a/docs/adding_new_accelerators.md b/docs/adding_new_accelerators.md
index 9cd9ddbd..cf62ba42 100644
--- a/docs/adding_new_accelerators.md
+++ b/docs/adding_new_accelerators.md
@@ -11,7 +11,7 @@
 unary_op  (NOOP, EXP2, LOG2, CAST, SIN, SQRT)                # A -> A
 reduce_op (SUM, MAX)                                         # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, CMPEQ, MAX)                   # A + A -> A (all the same size)
 movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, STRIDE)  # A -> B (different size)
-load_op (EMPTY, RAND, CONST, FROM, CONTIGUOUS, CUSTOM)       # -> A (initialize data on device)
+load_op (EMPTY, CONST, FROM, CONTIGUOUS, CUSTOM)             # -> A (initialize data on device)
 ternary_op (WHERE)                                           # A, A, A -> A
 ternary_op [[optional]] (MULACC)                             # A * A -> B
 ```
diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py
index f53e6fff..f2052ea5 100644
--- a/test/external/external_test_opt.py
+++ b/test/external/external_test_opt.py
@@ -88,8 +88,9 @@ class TestInferenceMinKernels(unittest.TestCase):
     args_tiny = {"dim": 512, "hidden_dim": 1024, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
     model = Transformer(**args_tiny)
     for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
+    inp = Tensor([[1,2,3,4]])
     with CLCache(100):
-      model(Tensor([[1,2,3,4]]), 0).realize()
+      model(inp, 0).realize()
 
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptBinOp(unittest.TestCase):
diff --git a/test/helpers.py b/test/helpers.py
index e9ec794e..6f31dd19 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -5,8 +5,8 @@ from tinygrad.nn.state import get_parameters
 
 # for speed
 def derandomize(x):
   if isinstance(x, LazyOp):
-    new_op = LoadOps.EMPTY if x.op == LoadOps.RAND else x.op
-    return LazyOp(new_op, tuple([derandomize(s) for s in x.src]), None if x.op == LoadOps.RAND else x.arg)
+    new_op = LoadOps.EMPTY if x.op == LoadOps.CUSTOM else x.op
+    return LazyOp(new_op, tuple([derandomize(s) for s in x.src]), None if x.op == LoadOps.CUSTOM else x.arg)
   x.op = derandomize(x.op)
   return x
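
With LoadOps.RAND gone, randomness is just another LoadOps.CUSTOM whose arg is a callable that fills the output Buffer; that is exactly how Tensor.rand is rebuilt in tinygrad/tensor.py at the end of this diff. A minimal sketch of that contract, using the private Tensor._loadop the same way rand does; fill_with_arange is a hypothetical helper, not part of this change:

    import numpy as np
    from tinygrad.tensor import Tensor
    from tinygrad.ops import LoadOps

    # hypothetical: any callable taking the realized output Buffer can back a CUSTOM loadop
    def fill_with_arange(out):
      out.copyin(np.arange(out.size, dtype=out.dtype.np).data)

    t = Tensor._loadop(LoadOps.CUSTOM, 6, arg=fill_with_arange).reshape(2, 3)
    print(t.numpy())  # [[0. 1. 2.] [3. 4. 5.]]
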
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 97d38bea..cb1e99d9 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -33,6 +33,31 @@ class _Device:
     return "CPU"
 Device = _Device()
 
+# **************** base Runner + helpers ****************
+
+class JITRunner:
+  def __init__(self):
+    self.op_estimate, self.mem_estimate = 0, 0
+  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
+    var_vals = var_vals if var_vals is not None else {}
+    from tinygrad.jit import CacheCollector
+    et = self(rawbufs, var_vals)
+    CacheCollector.add(self, rawbufs, var_vals)
+    return et
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+    raise NotImplementedError("override this")
+
+def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None):
+  if var_vals is None: var_vals = {}
+  op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
+  if DEBUG >= 2:
+    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+          (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
+  GlobalCounters.kernel_count += num_kernels
+  GlobalCounters.global_ops += op_estimate
+  GlobalCounters.global_mem += mem_estimate
+  if et is not None: GlobalCounters.time_sum_s += et
+
 # **************** Buffer / Allocator ****************
 
 class Buffer:
@@ -47,27 +72,6 @@ class Buffer:
     if self.device == Device.DEFAULT: GlobalCounters.mem_used -= self.size * self.dtype.itemsize
     self.allocator.free(self._buf, self.size, self.dtype)
   def __repr__(self): return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}>"
-  def copy_(self, src:Buffer):
-    assert self.size == src.size and self.dtype == src.dtype, "buffer copy size/dtype mismatch"
-    if hasattr(self.allocator, 'transfer') and type(self.allocator) is type(src.allocator):
-      # fast path, used on HIP between GPUs
-      self.allocator.transfer(self._buf, src._buf, self.size*self.dtype.itemsize)
-      return
-    if getenv("FROM_BUFFER") and hasattr(self.allocator, 'from_buffer') and hasattr(self.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
-      # fast path, used on Metal in OS X Sonoma
-      # NOTE: this is *only* faster if the pages from disk are already loaded into memory
-      fb = self.allocator.from_buffer(src.allocator.as_buffer(src._buf))
-      if fb:
-        self.allocator.transfer(self._buf, fb, self.size*self.dtype.itemsize)
-        return
-    if hasattr(self.allocator, 'as_buffer'):
-      # fast(ish) path, uses readinto in diskbuffers
-      src.allocator.copyout(self.allocator.as_buffer(self._buf), src._buf)
-    elif hasattr(src.allocator, 'as_buffer'):
-      self.allocator.copyin(self._buf, src.allocator.as_buffer(src._buf))
-    else:
-      # slow path, allocates a CPU buffer
-      self.copyin(src.toCPU().data)
   def copyin(self, mv:memoryview):
     mv = flat_mv(mv)
     assert len(mv) == self.size*self.dtype.itemsize, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
@@ -82,6 +86,33 @@ class Buffer:
     if self.size > 0: self.allocator.copyout(flat_mv(ret.data), self._buf)
     return ret
 
+class _BufferCopy(JITRunner):
+  # TODO: make wait work
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
+    dest, src = rawbufs
+    assert dest.size == src.size and dest.dtype == src.dtype, "buffer copy size/dtype mismatch"
+    if DEBUG >= 2: print(f"*** copy {dest.device} <- {src.device} size {dest.size:<16d} dtype {dest.dtype}")
+    if hasattr(dest.allocator, 'transfer') and type(dest.allocator) is type(src.allocator):
+      # fast path, used on HIP between GPUs
+      dest.allocator.transfer(dest._buf, src._buf, dest.size*dest.dtype.itemsize)
+      return
+    if getenv("FROM_BUFFER") and hasattr(dest.allocator, 'from_buffer') and hasattr(dest.allocator, 'transfer') and hasattr(src.allocator, 'as_buffer'):
+      # fast path, used on Metal in OS X Sonoma
+      # NOTE: this is *only* faster if the pages from disk are already loaded into memory
+      fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
+      if fb:
+        dest.allocator.transfer(dest._buf, fb, dest.size*dest.dtype.itemsize)
+        return
+    if hasattr(dest.allocator, 'as_buffer'):
+      # fast(ish) path, uses readinto in diskbuffers
+      src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
+    elif hasattr(src.allocator, 'as_buffer'):
+      dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
+    else:
+      # slow path, allocates a CPU buffer
+      dest.copyin(src.toCPU().data)
+BufferCopy = _BufferCopy()
+
 # TODO: size, dest, src are the same type. can we enforce this?
 class Allocator:
   def alloc(self, size:int, dtype:DType):
@@ -115,31 +146,6 @@ class _MallocAllocator(LRUAllocator):
   def copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
 MallocAllocator = _MallocAllocator()
 
-# **************** base Runner + helpers ****************
-
-class JITRunner:
-  def __init__(self):
-    self.op_estimate, self.mem_estimate = 0, 0
-  def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
-    var_vals = var_vals if var_vals is not None else {}
-    from tinygrad.jit import CacheCollector
-    et = self(rawbufs, var_vals)
-    CacheCollector.add(self, rawbufs, var_vals)
-    return et
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
-    raise NotImplementedError("override this")
-
-def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None):
-  if var_vals is None: var_vals = {}
-  op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
-  if DEBUG >= 2:
-    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-          (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-  GlobalCounters.kernel_count += num_kernels
-  GlobalCounters.global_ops += op_estimate
-  GlobalCounters.global_mem += mem_estimate
-  if et is not None: GlobalCounters.time_sum_s += et
-
 # **************** for Interpreted Devices ****************
 
 class InterpretedASTRunner(JITRunner):
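
Two moves happen in this file: anything launchable is now a JITRunner (moved above Buffer so Buffer-adjacent code can use it), and the old Buffer.copy_ method is reborn as the BufferCopy singleton, so device-to-device copies go through the same exec() path that feeds CacheCollector and the stats. A toy sketch of the call side, assuming the CPU backend; this is illustrative, not code from this diff:

    from tinygrad.device import Buffer, BufferCopy
    from tinygrad.helpers import dtypes

    # dest comes first in rawbufs, mirroring the kernel output convention
    src = Buffer("CPU", 4, dtypes.float32)
    dst = Buffer("CPU", 4, dtypes.float32)
    BufferCopy.exec([dst, src])  # recorded by CacheCollector like any kernel launch
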
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 0c991157..a670480d 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -86,7 +86,7 @@ def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dty
   if 0 in st.shape: return LazyBuffer(device, ShapeTracker.from_shape(st.shape), LoadOps, LazyOp(LoadOps.CONST, tuple(), 0.0), dtype)
 
   # fromcpu aren't cached
-  if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)
+  if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.CUSTOM, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)
 
   # wop is the deduping key. i feel this used to compare more deeply
   wop = (device, dtype, optype, ref(op), ref(base) if base else None)
@@ -166,8 +166,8 @@ class LazyBuffer:
 
     op, base_bufs = _replace_bufferops(op)
 
-    # add the store
     if op.op not in LoadOps:
+      # add the store
       info = get_lazyop_info(op)
       assert info.dtype == self.dtype or isinstance(self.dtype, ImageDType), f"dtype mismatch {info.dtype=} != {self.dtype=}"
@@ -179,6 +179,9 @@ class LazyBuffer:
       # TODO: why doesn't this match?
       #assert info.shape == self.shape, f"shape mismatch {info.shape=} != {self.shape=}"
       op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, self.dtype, ShapeTracker.from_shape(info.shape)))
+    else:
+      # check loadop validity of bufferops
+      for i,s in enumerate(op.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
 
     return ret + [ScheduleItem(op, self, tuple(base_bufs), {k:var_vals[k] for k in vars_from_ast(op)})]
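
The validity check that run_schedule used to do per-LoadOp now happens once at schedule time: a LoadOps item's sources must be bare, contiguous BufferOps.LOADs numbered 1..n (idx 0 is reserved for the output). A sketch of an AST shape that satisfies the new assert, assuming MemBuffer and ShapeTracker are constructed as elsewhere in this diff:

    from tinygrad.ops import LazyOp, LoadOps, BufferOps, MemBuffer
    from tinygrad.shape.shapetracker import ShapeTracker
    from tinygrad.helpers import dtypes

    st = ShapeTracker.from_shape((4,))  # contiguous by construction
    src = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float32, st))
    ast = LazyOp(LoadOps.FROM, (src,))  # src 0 is LOAD of buffer idx 1: passes the check
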
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 887a6fa0..ccfd454e 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -17,7 +17,7 @@ class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
 # Ops below this line are not allowed in ASTs
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
-class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
 
 Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
 OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
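
This is the authoritative enum; the docs copies above mirror it. A quick check of what downstream code can rely on after this change:

    from tinygrad.ops import LoadOps

    assert not hasattr(LoadOps, "RAND")  # RAND is gone
    assert [op.name for op in LoadOps] == ["EMPTY", "CONST", "FROM", "CONTIGUOUS", "CUSTOM"]
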
diff --git a/tinygrad/realize.py b/tinygrad/realize.py
index 6afbd1f2..67fced87 100644
--- a/tinygrad/realize.py
+++ b/tinygrad/realize.py
@@ -1,17 +1,32 @@
-from typing import List, cast, Dict, Callable
-import numpy as np
-from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, BufferOps
-from tinygrad.device import Device, Buffer
+from typing import List, Dict, Tuple, Generator, Optional
+from tinygrad.ops import ScheduleItem, LoadOps, BufferOps
+from tinygrad.device import Device, Buffer, BufferCopy, JITRunner
 from tinygrad.graph import log_schedule_item, print_tree
-from tinygrad.helpers import DEBUG, prod
+from tinygrad.helpers import prod
+from tinygrad.shape.symbolic import Variable
 
-def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
-  # NOTE: if you for loop the schedule it's slow because nothing frees
+class CustomOp(JITRunner):
+  def __init__(self, fxn):
+    self.fxn = fxn
+    super().__init__()
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
+
+def lower_schedule(schedule:List[ScheduleItem]) -> Generator[Tuple[ScheduleItem, Optional[JITRunner]], None, None]:
   while len(schedule):
     si = schedule.pop(0)
+    assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.FROM, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
+    if si.ast.op is LoadOps.EMPTY: yield si, None
+    elif si.ast.op is LoadOps.FROM: yield si, BufferCopy
+    elif si.ast.op is LoadOps.CUSTOM: yield si, CustomOp(si.ast.arg)
+    else: yield si, Device[si.out.device].get_runner(si.ast)
+    del si.out.op
+    for v in si.out.views: del v.op
+
+def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
+  for si, fxn in lower_schedule(schedule):
     if not disable_logging: log_schedule_item(si)
     assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
-    assert all(si.out.device == x.device for x in si.inputs) or si.ast.op is LoadOps.FROM, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
+
     # check if we can reuse the output buffer
     # if it's aliased, don't use it
     # TODO: this is pretty wrong actually, who knows where else this buffer is used?
@@ -24,35 +39,15 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
         if any(not x.arg.st.contiguous for x in si.ast.get_lazyops() if x.op == BufferOps.LOAD and x.arg.idx == i+1):
           si.out.output_buffer = None
           break
+
     # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
     si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
       Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype)
-    if si.ast.op in LoadOps:
-      if DEBUG >= 2: print(f"*** {si.ast.op:>15s} {f'{si.out.device} <- {si.inputs[0].device}' if si.ast.op is LoadOps.FROM else si.out.device:25s} sz {si.out.realized.size:5d} shape {si.out.shape} dtype {si.out.dtype} arg {si.ast.arg}")
-      # confirm the LoadOps are contiguous and in order
-      for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
-      kwargs = {"arg": si.ast.arg} if si.ast.arg is not None else {}
-      LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out.realized, *[x.realized for x in si.inputs], **kwargs)
-    else:
-      Device[si.out.device].get_runner(si.ast).exec([si.out.realized] + [x.realized for x in si.inputs], si.var_vals)
-    del si.out.op
-    for v in si.out.views: del v.op
-    #assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
-    assert si.out.realized.dtype == si.out.dtype, f"realized dtype is incorrect, {si.out.realized.dtype} != {si.out.dtype}"
 
-# *** LoadOps implementation ***
+    # get all the buffers
+    rawbufs = [si.out.realized] + [x.realized for x in si.inputs]
 
-# TODO: remove this and write the RNG in tinygrad
-def _realize_rand(buffer: Buffer, arg) -> None:
-  rng = np.random.default_rng(arg)
-  rng_np_buffer = rng.random(size=buffer.size, dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False)
-  buffer.copyin(rng_np_buffer.data)
-
-def _realize_custom(*buffers: Buffer, arg) -> None: arg(*buffers)
-
-LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
-  LoadOps.EMPTY: lambda x: None,
-  LoadOps.RAND: _realize_rand,
-  LoadOps.FROM: Buffer.copy_,
-  LoadOps.CUSTOM: _realize_custom
-}
+    # run the function and put in JIT
+    if fxn: fxn.exec(rawbufs, si.var_vals)
+    assert si.out.realized.device == si.out.device, f"realized device is incorrect, {si.out.realized.device=} != {si.out.device=}"
+    assert si.out.realized.dtype == si.out.dtype, f"realized dtype is incorrect, {si.out.realized.dtype=} != {si.out.dtype=}"
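
run_schedule is now a thin driver over lower_schedule, so the lowering can be inspected (or captured by the JIT) without executing anything. A sketch of that, assuming LazyBuffer.schedule() as the scheduling entry point (untouched by this diff); the exact runner names are device-dependent:

    from tinygrad.tensor import Tensor
    from tinygrad.realize import lower_schedule

    sched = (Tensor.rand(4) + 1).lazydata.schedule()
    for si, fxn in lower_schedule(sched):
      print(si.ast.op, type(fxn).__name__ if fxn is not None else None)
    # expect LoadOps.CUSTOM with a CustomOp for the rand fill, then the add
    # kernel (top op BufferOps.STORE) with its device runner
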
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 702c6ede..e3c4a1f8 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -10,7 +10,7 @@
 import numpy as np
 from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int, round_up
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import LoadOps
-from tinygrad.device import Device
+from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import sint
 from tinygrad.realize import run_schedule
 
@@ -156,8 +156,7 @@ class Tensor:
   @staticmethod
   def rand(*shape, **kwargs):
-    Tensor._seed += 1
-    return Tensor._loadop(LoadOps.RAND, prod((shape:=argfix(*shape))), arg=Tensor._seed, **kwargs).reshape(shape)
+    return Tensor._loadop(LoadOps.CUSTOM, prod((shape:=argfix(*shape))), arg=custom_random, **kwargs).reshape(shape)
 
   # ***** creation helper functions *****
 
@@ -828,3 +827,11 @@ if IMAGE:
   from tinygrad.features.image import image_conv2d, image_dot
   setattr(Tensor, "conv2d", image_conv2d)
   setattr(Tensor, "dot", image_dot)
+
+# TODO: remove the custom op and replace with threefry
+def custom_random(out:Buffer):
+  Tensor._seed += 1
+  if DEBUG >= 2: print(f"*** rand {out.device} seed {Tensor._seed} size {out.size:<16d} dtype {out.dtype}")
+  rng = np.random.default_rng(Tensor._seed)
+  rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
+  out.copyin(rng_np_buffer.data)
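
One behavioral consequence worth noting: Tensor.rand used to advance the seed eagerly at graph-build time; now custom_random bumps it at realize time, once per realized rand buffer, so realization order picks the random stream. A small sketch of the new timing, assuming Tensor.manual_seed as the usual way to reset Tensor._seed:

    from tinygrad.tensor import Tensor

    Tensor.manual_seed(0)            # Tensor._seed = 0
    a = Tensor.rand(2, 2)            # lazy: custom_random hasn't run, seed still 0
    a.realize()                      # custom_random fires here, seed becomes 1
    b = Tensor.rand(2, 2).realize()  # seed becomes 2 on this realize
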