slice -> pad, shrink

George Hotz 2023-02-28 19:57:45 -08:00
parent ea3fa07c2a
commit 7ff92550bb
4 changed files with 30 additions and 22 deletions

README.md

@@ -129,13 +129,13 @@ hlops are syntactic sugar around mlops. They support most things torch does.
 ### mlops
-mlops are mid level ops, there's 15 of them. They understand derivatives. They are very simple.
+mlops are mid level ops, there's 16 of them. They understand derivatives. They are very simple.
 ```
-Relu, Log, Exp, Reciprocal                   # unary ops
-Sum, Max                                     # reduce ops (with axis argument)
-Add, Sub, Mul, Pow                           # binary ops (no broadcasting, use expand)
-Expand, Reshape, Permute, Slice, Flip        # movement ops
+Log, Exp                                     # unary ops
+Sum, Max                                     # reduce ops (with axis argument)
+Maximum, Add, Sub, Mul, Pow, Div             # binary ops (no broadcasting, use expand)
+Expand, Reshape, Permute, Pad, Shrink, Flip  # movement ops
 ```
 You no longer need to write mlops for a new accelerator
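
As a quick illustration of "they understand derivatives": after this commit, the gradient of a movement op is just the inverse movement op. A minimal usage sketch (hypothetical values, assuming the standard Tensor API in this tree):

```
from tinygrad.tensor import Tensor

x = Tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
y = x.pad(((1, 1), (0, 0)))   # zero-pad the rows: shape (2,2) -> (4,2)
y.sum().backward()
print(x.grad.numpy())         # all ones: Pad's backward just shrinks the gradient
```
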
@@ -149,7 +149,7 @@ Buffer                                       # class of memory on
 unary_op  (NOOP, NEG, NOT, EXP, LOG)             # A -> A
 reduce_op (SUM, MAX)                             # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX)  # A + A -> A (all the same size)
-movement_op (RESHAPE, PERMUTE, EXPAND, FLIP, PAD, SHRINK)  # A -> B (different size)
+movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, FLIP)  # A -> B (different size)
 fused_op [[optional]] (MULACC)                   # A * A -> B
 ```
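
For readers new to the two movement ops this commit promotes, here is a minimal numpy sketch (illustrative only, not accelerator code) of the PAD and SHRINK semantics:

```
import numpy as np

buf = np.array([[1.0, 2.0], [3.0, 4.0]])

# PAD arg is (before, after) zero-padding per axis -> a bigger buffer
padded = np.pad(buf, ((1, 1), (0, 2)))   # shape (2,2) -> (4,4)

# SHRINK arg is a (start, end) window per axis -> a smaller buffer
shrunk = padded[1:3, 0:2]                # shape (4,4) -> (2,2)

assert (shrunk == buf).all()             # shrinking the padding back off round-trips
```
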

tinygrad/lazy.py

@@ -181,12 +181,6 @@ class LazyBuffer:
     # NOTE: this reshape can only move around 1s
     return LazyBuffer(x.device, new_tmp_shape, ReduceOps, LazyOp(op, (x,), new_tmp_shape)).movement_op(MovementOps.RESHAPE, new_shape)
 
-  # syntactic sugar around PAD and SHRINK
-  # TODO: turn RESHAPE into EXPAND and CONTRACT (current EXPAND should be REPEAT)
-  def slice(self:LazyBuffer, arg):
-    padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg))
-    return self.movement_op(MovementOps.PAD, padding).movement_op(MovementOps.SHRINK, tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)))
-
   def movement_op(self:LazyBuffer, op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer:
     # very instant nop
     if op == MovementOps.RESHAPE and self.shape == arg: return self

tinygrad/mlops.py

@@ -140,13 +140,21 @@ class Permute(Function):
   def backward(self, grad_output):
     return grad_output.movement_op(MovementOps.PERMUTE, tuple(argsort(self.input_order)))
 
-class Slice(Function):
-  def forward(self, x, arg=None):
-    self.narg = tuple((0-p[0], x.shape[i]-p[0]) for i,p in enumerate(arg))
-    return x.slice(tuple(arg))
+class Pad(Function):
+  def forward(self, x, arg):
+    self.narg = tuple((p[0], x.shape[i]+p[0]) for i,p in enumerate(arg))
+    return x.movement_op(MovementOps.PAD, arg)
   def backward(self, grad_output):
-    return grad_output.slice(self.narg)
+    return grad_output.movement_op(MovementOps.SHRINK, self.narg)
+
+class Shrink(Function):
+  def forward(self, x, arg):
+    self.narg = tuple((p[0], x.shape[i]-p[1]) for i,p in enumerate(arg))
+    return x.movement_op(MovementOps.SHRINK, arg)
+  def backward(self, grad_output):
+    return grad_output.movement_op(MovementOps.PAD, self.narg)
 
 class Flip(Function):
   def forward(self, x, axis):
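
The two narg computations above are easy to sanity-check by hand; a short plain-Python sketch (hypothetical shapes, mirroring the formulas in the diff) of why each backward exactly undoes its forward:

```
# Pad ((1,2),) on shape (4,): output shape is 1+4+2 = 7.
# narg = (p[0], shape+p[0]) = (1, 5), the SHRINK window that cuts the padding back off.
shape, pad_arg = (4,), ((1, 2),)
narg = tuple((p[0], shape[i] + p[0]) for i, p in enumerate(pad_arg))
assert narg == ((1, 5),)

# Shrink ((1,5),) on shape (7,): output shape is 5-1 = 4.
# narg = (p[0], shape-p[1]) = (1, 2), the PAD amounts that restore shape (7,).
shape, shrink_arg = (7,), ((1, 5),)
narg = tuple((p[0], shape[i] - p[1]) for i, p in enumerate(shrink_arg))
assert narg == ((1, 2),)
```
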

tinygrad/tensor.py

@@ -2,7 +2,7 @@
 from __future__ import annotations
 import math, functools, itertools
 import numpy as np
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
 from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten
 from tinygrad.lazy import Device, LazyBuffer
 from tinygrad.image import image_conv2d_decorator
@@ -25,8 +25,7 @@ class Function:
   def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
     ctx = fxn(x[0].device, *x)
     ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
-    if ctx.requires_grad and not Tensor.no_grad:
-      ret._ctx = ctx  # used by autograd engine
+    if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx  # used by autograd engine
     return ret
 
 import tinygrad.mlops as mlops
@@ -108,7 +107,6 @@ class Tensor:
     return ret
 
   # ***** creation helper functions *****
-  # TODO: remove use of numpy here and make lazy
 
   @staticmethod
   def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape).contiguous()
@@ -125,6 +123,7 @@ class Tensor:
   @staticmethod
   def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
 
+  # TODO: below line, remove use of numpy here and make lazy
   # TODO: requires cumsum to remove numpy
   @staticmethod
   def arange(stop, start=0, step=1, **kwargs): return Tensor(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)
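
The eye one-liner above leans entirely on movement ops; a numpy re-enactment (illustrative only, not tinygrad internals) of how the identity matrix falls out for dim=3:

```
import numpy as np

dim = 3
t = np.array([1.0])                                         # Tensor([1]), shape (1,)
t = np.pad(t, (0, dim))                                     # slice ((0, dim+1),): read past the end -> [1, 0, 0, 0]
t = np.broadcast_to(t.reshape(1, dim + 1), (dim, dim + 1))  # reshape(1, dim+1).expand(dim, dim+1)
t = t.reshape(dim * (dim + 1))[:dim * dim]                  # flatten, then slice ((0, dim*dim),) drops the tail zeros
print(t.reshape(dim, dim))                                  # 3x3 identity: the off-by-one row stride lands a 1 on each diagonal
```
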
@@ -195,10 +194,17 @@ class Tensor:
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=argfix(axis, *args))
-  def slice(self, arg) -> Tensor: return mlops.Slice.apply(self, arg=tuple(a if a is not None else (0,s) for s,a in zip(self.shape, arg)))
+  def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any(x != (0,0) for x in arg) else self
+  def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
 
   # ***** movement hlops *****
 
+  # NOTE: using slice is discouraged and things should migrate to pad and shrink
+  def slice(self, arg:Sequence[Optional[Tuple[int, int]]]) -> Tensor:
+    arg_ = tuple(a if a is not None else (0,s) for s,a in zip(self.shape, arg))
+    padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_))
+    return self.pad(padding).shrink(tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)))
+
   # Tensors mostly follow the normal python indexing / slicing behavior for sequences
   # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
   # - A slice i:j returns the elements with indices in [i, j)
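
To see the pad/shrink arithmetic in the new Tensor.slice, a plain-Python walkthrough (hypothetical values): slicing (-1, 5) out of a length-4 axis, i.e. one element past each end.

```
shape = (4,)
arg_ = ((-1, 5),)

# out-of-bounds amounts become zero padding: 1 before, 5-4 = 1 after
padding = tuple((max(0, -p[0]), max(0, p[1] - shape[i])) for i, p in enumerate(arg_))
assert padding == ((1, 1),)   # padded axis has length 6

# shift the requested window into the padded coordinate system
shrink = tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i, p in enumerate(arg_))
assert shrink == ((0, 6),)    # keep all 6 elements: [0, x0, x1, x2, x3, 0]
```
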