From 7ff92550bb2addf2f49a5f239c84adbab762265c Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 28 Feb 2023 19:57:45 -0800 Subject: [PATCH] slice -> pad, shrink --- README.md | 12 ++++++------ tinygrad/lazy.py | 6 ------ tinygrad/mlops.py | 18 +++++++++++++----- tinygrad/tensor.py | 16 +++++++++++----- 4 files changed, 30 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 22d7243b..f744ffaf 100644 --- a/README.md +++ b/README.md @@ -129,13 +129,13 @@ hlops are syntactic sugar around mlops. They support most things torch does. ### mlops -mlops are mid level ops, there's 15 of them. They understand derivatives. They are very simple. +mlops are mid level ops, there's 16 of them. They understand derivatives. They are very simple. ``` -Relu, Log, Exp, Reciprocal # unary ops -Sum, Max # reduce ops (with axis argument) -Add, Sub, Mul, Pow # binary ops (no broadcasting, use expand) -Expand, Reshape, Permute, Slice, Flip # movement ops +Log, Exp # unary ops +Sum, Max # reduce ops (with axis argument) +Maximum, Add, Sub, Mul, Pow, Div # binary ops (no broadcasting, use expand) +Expand, Reshape, Permute, Pad, Shrink, Flip # movement ops ``` You no longer need to write mlops for a new accelerator @@ -149,7 +149,7 @@ Buffer # class of memory on unary_op (NOOP, NEG, NOT, EXP, LOG) # A -> A reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape) binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size) -movement_op (RESHAPE, PERMUTE, EXPAND, FLIP, PAD, SHRINK) # A -> B (different size) +movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, FLIP) # A -> B (different size) fused_op [[optional]] (MULACC) # A * A -> B ``` diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index a5fb4335..c4659610 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -181,12 +181,6 @@ class LazyBuffer: # NOTE: this reshape can only move around 1s return LazyBuffer(x.device, new_tmp_shape, ReduceOps, LazyOp(op, (x,), new_tmp_shape)).movement_op(MovementOps.RESHAPE, new_shape) - # syntactic sugar around PAD and SHRINK - # TODO: turn RESHAPE into EXPAND and CONTRACT (current EXPAND should be REPEAT) - def slice(self:LazyBuffer, arg): - padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg)) - return self.movement_op(MovementOps.PAD, padding).movement_op(MovementOps.SHRINK, tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg))) - def movement_op(self:LazyBuffer, op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer: # very instant nop if op == MovementOps.RESHAPE and self.shape == arg: return self diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index cc7f8466..47755845 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -140,13 +140,21 @@ class Permute(Function): def backward(self, grad_output): return grad_output.movement_op(MovementOps.PERMUTE, tuple(argsort(self.input_order))) -class Slice(Function): - def forward(self, x, arg=None): - self.narg = tuple((0-p[0], x.shape[i]-p[0]) for i,p in enumerate(arg)) - return x.slice(tuple(arg)) +class Pad(Function): + def forward(self, x, arg): + self.narg = tuple((p[0], x.shape[i]+p[0]) for i,p in enumerate(arg)) + return x.movement_op(MovementOps.PAD, arg) def backward(self, grad_output): - return grad_output.slice(self.narg) + return grad_output.movement_op(MovementOps.SHRINK, self.narg) + +class Shrink(Function): + def forward(self, x, arg): + self.narg = tuple((p[0], x.shape[i]-p[1]) for i,p in enumerate(arg)) + return x.movement_op(MovementOps.SHRINK, arg) + + def backward(self, grad_output): + return grad_output.movement_op(MovementOps.PAD, self.narg) class Flip(Function): def forward(self, x, axis): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 9f6c5011..29955324 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2,7 +2,7 @@ from __future__ import annotations import math, functools, itertools import numpy as np -from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union +from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten from tinygrad.lazy import Device, LazyBuffer from tinygrad.image import image_conv2d_decorator @@ -25,8 +25,7 @@ class Function: def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor: ctx = fxn(x[0].device, *x) ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad) - if ctx.requires_grad and not Tensor.no_grad: - ret._ctx = ctx # used by autograd engine + if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx # used by autograd engine return ret import tinygrad.mlops as mlops @@ -108,7 +107,6 @@ class Tensor: return ret # ***** creation helper functions ***** - # TODO: remove use of numpy here and make lazy @staticmethod def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape).contiguous() @@ -125,6 +123,7 @@ class Tensor: @staticmethod def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim) + # TODO: below line, remove use of numpy here and make lazy # TODO: requires cumsum to remove numpy @staticmethod def arange(stop, start=0, step=1, **kwargs): return Tensor(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs) @@ -195,10 +194,17 @@ class Tensor: def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args)))) def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args)) def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=argfix(axis, *args)) - def slice(self, arg) -> Tensor: return mlops.Slice.apply(self, arg=tuple(a if a is not None else (0,s) for s,a in zip(self.shape, arg))) + def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any(x != (0,0) for x in arg) else self + def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self # ***** movement hlops ***** + # NOTE: using slice is discouraged and things should migrate to pad and shrink + def slice(self, arg:Sequence[Optional[Tuple[int, int]]]) -> Tensor: + arg_ = tuple(a if a is not None else (0,s) for s,a in zip(self.shape, arg)) + padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)) + return self.pad(padding).shrink(tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_))) + # Tensors mostly follow the normal python indexing / slicing behavior for sequences # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element # - A slice i:j returns the elements with indices in [i, j)