mirror of https://github.com/commaai/tinygrad.git
slice -> pad, shrink
This commit is contained in:
parent ea3fa07c2a · commit 7ff92550bb

README.md | 12
@@ -129,13 +129,13 @@ hlops are syntactic sugar around mlops. They support most things torch does.
 ### mlops
 
-mlops are mid level ops, there's 15 of them. They understand derivatives. They are very simple.
+mlops are mid level ops, there's 16 of them. They understand derivatives. They are very simple.
 
 ```
-Relu, Log, Exp, Reciprocal                  # unary ops
-Sum, Max                                    # reduce ops (with axis argument)
-Add, Sub, Mul, Pow                          # binary ops (no broadcasting, use expand)
-Expand, Reshape, Permute, Slice, Flip       # movement ops
+Log, Exp                                    # unary ops
+Sum, Max                                    # reduce ops (with axis argument)
+Maximum, Add, Sub, Mul, Pow, Div            # binary ops (no broadcasting, use expand)
+Expand, Reshape, Permute, Pad, Shrink, Flip # movement ops
 ```
 
 You no longer need to write mlops for a new accelerator
 
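To make the two new movement primitives concrete, here is a small numpy sketch (illustrative only, not tinygrad code; it assumes Pad is zero-filled, which is how tinygrad's pad behaves): Pad grows each axis by (before, after) margins, Shrink keeps a half-open [start, end) window per axis.

```python
# Sketch of the Pad / Shrink movement ops in numpy terms (not tinygrad code).
import numpy as np

x = np.arange(1, 7, dtype=np.float32).reshape(2, 3)   # shape (2, 3)

# Pad: ((before, after), ...) per axis, zero-filled -> shape grows to (4, 5)
padded = np.pad(x, ((1, 1), (0, 2)))

# Shrink: ((start, end), ...) per axis, a half-open window -> back to (2, 3)
shrunk = padded[1:3, 0:3]

assert np.array_equal(shrunk, x)
print(padded.shape, shrunk.shape)   # (4, 5) (2, 3)
```

Slice, which the old list treated as a primitive, is exactly this pad-then-shrink composition, which is why it can be dropped from the mlops.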
@@ -149,7 +149,7 @@ Buffer # class of memory on
 unary_op  (NOOP, NEG, NOT, EXP, LOG)                       # A -> A
 reduce_op (SUM, MAX)                                       # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX)            # A + A -> A (all the same size)
-movement_op (RESHAPE, PERMUTE, EXPAND, FLIP, PAD, SHRINK)  # A -> B (different size)
+movement_op (EXPAND, RESHAPE, PERMUTE, PAD, SHRINK, FLIP)  # A -> B (different size)
 fused_op [[optional]] (MULACC)                             # A * A -> B
 ```
 
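The `A -> B` signatures above are terse; the following numpy sketch is an analogy (assumed shapes, not the actual llop interface) for what "B has 1 in shape" and "no broadcasting, use expand" mean in practice.

```python
# numpy analogy for the llop signatures (example shapes for illustration only).
import numpy as np

A = np.arange(12, dtype=np.float32).reshape(4, 3)

# reduce_op SUM: A -> B, smaller, with a 1 kept in the reduced axis
B = A.sum(axis=1, keepdims=True)          # shape (4, 1)

# movement_op EXPAND: B -> same shape as A, so a binary op sees equal shapes
B_expanded = np.broadcast_to(B, A.shape)  # shape (4, 3)

# binary_op SUB: both operands already the same size, no implicit broadcasting
centered = A - B_expanded

# fused_op MULACC (optional): a MUL feeding an accumulate, i.e. (A*A) then SUM
mulacc = (A * A).sum(axis=1, keepdims=True)
print(B.shape, centered.shape, mulacc.shape)
```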
@@ -181,12 +181,6 @@ class LazyBuffer:
     # NOTE: this reshape can only move around 1s
     return LazyBuffer(x.device, new_tmp_shape, ReduceOps, LazyOp(op, (x,), new_tmp_shape)).movement_op(MovementOps.RESHAPE, new_shape)
 
-  # syntactic sugar around PAD and SHRINK
-  # TODO: turn RESHAPE into EXPAND and CONTRACT (current EXPAND should be REPEAT)
-  def slice(self:LazyBuffer, arg):
-    padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg))
-    return self.movement_op(MovementOps.PAD, padding).movement_op(MovementOps.SHRINK, tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg)))
-
   def movement_op(self:LazyBuffer, op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer:
     # very instant nop
     if op == MovementOps.RESHAPE and self.shape == arg: return self
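The pad/shrink arithmetic deleted here reappears as `Tensor.slice` further down in this diff. A plain-Python trace of the two comprehensions, with made-up shape and slice values, shows how an out-of-bounds slice turns into a pad followed by a shrink:

```python
# Trace of the slice -> pad, shrink arithmetic (plain Python, example values only).
shape = (4,)
arg   = ((-1, 5),)   # slice one element past each end of a length-4 axis

padding = tuple((max(0, -p[0]), max(0, p[1] - shape[i])) for i, p in enumerate(arg))
shrink  = tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i, p in enumerate(arg))

print(padding)  # ((1, 1),)  PAD grows the axis from 4 to 6
print(shrink)   # ((0, 6),)  SHRINK then keeps the whole padded axis -> length-6 result
```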
@@ -140,13 +140,21 @@ class Permute(Function):
   def backward(self, grad_output):
     return grad_output.movement_op(MovementOps.PERMUTE, tuple(argsort(self.input_order)))
 
-class Slice(Function):
-  def forward(self, x, arg=None):
-    self.narg = tuple((0-p[0], x.shape[i]-p[0]) for i,p in enumerate(arg))
-    return x.slice(tuple(arg))
+class Pad(Function):
+  def forward(self, x, arg):
+    self.narg = tuple((p[0], x.shape[i]+p[0]) for i,p in enumerate(arg))
+    return x.movement_op(MovementOps.PAD, arg)
 
   def backward(self, grad_output):
-    return grad_output.slice(self.narg)
+    return grad_output.movement_op(MovementOps.SHRINK, self.narg)
+
+class Shrink(Function):
+  def forward(self, x, arg):
+    self.narg = tuple((p[0], x.shape[i]-p[1]) for i,p in enumerate(arg))
+    return x.movement_op(MovementOps.SHRINK, arg)
+
+  def backward(self, grad_output):
+    return grad_output.movement_op(MovementOps.PAD, self.narg)
 
 class Flip(Function):
   def forward(self, x, axis):
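The two `narg` formulas are the whole trick: each forward stores the argument its backward needs to undo the shape change. A plain-Python trace with made-up shapes shows that Pad's backward is a matching shrink and Shrink's backward is a matching pad:

```python
# Pad.forward: x.shape = (4,), arg = ((1, 2),)  ->  output shape (7,)
x_shape, arg = (4,), ((1, 2),)
narg = tuple((p[0], x_shape[i] + p[0]) for i, p in enumerate(arg))
print(narg)  # ((1, 5),): backward SHRINKs the (7,) gradient to indices [1, 5), i.e. shape (4,)

# Shrink.forward: x.shape = (7,), arg = ((1, 5),)  ->  output shape (4,)
x_shape, arg = (7,), ((1, 5),)
narg = tuple((p[0], x_shape[i] - p[1]) for i, p in enumerate(arg))
print(narg)  # ((1, 2),): backward PADs the (4,) gradient by 1 before / 2 after, back to (7,)
```

Note the asymmetry: Pad stores a shrink-style (start, end) window, while Shrink stores a pad-style (before, after) pair.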
@@ -2,7 +2,7 @@
 from __future__ import annotations
 import math, functools, itertools
 import numpy as np
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
 from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten
 from tinygrad.lazy import Device, LazyBuffer
 from tinygrad.image import image_conv2d_decorator
@@ -25,8 +25,7 @@ class Function:
   def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
     ctx = fxn(x[0].device, *x)
     ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
-    if ctx.requires_grad and not Tensor.no_grad:
-      ret._ctx = ctx # used by autograd engine
+    if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx # used by autograd engine
     return ret
 
 import tinygrad.mlops as mlops
@@ -108,7 +107,6 @@ class Tensor:
     return ret
 
   # ***** creation helper functions *****
-  # TODO: remove use of numpy here and make lazy
 
   @staticmethod
   def zeros(*shape, **kwargs): return Tensor([0], **kwargs).reshape([1]*len(shape)).expand(shape).contiguous()
@@ -125,6 +123,7 @@ class Tensor:
   @staticmethod
   def eye(dim, **kwargs): return Tensor([1], **kwargs).slice(((0,dim+1),)).reshape(1, dim+1).expand(dim, dim+1).reshape(dim*(dim+1)).slice(((0,dim*dim),)).reshape(dim, dim)
 
+  # TODO: below line, remove use of numpy here and make lazy
   # TODO: requires cumsum to remove numpy
   @staticmethod
   def arange(stop, start=0, step=1, **kwargs): return Tensor(np.arange(start=start, stop=stop, step=step, dtype=np.float32), **kwargs)
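The `eye` one-liner above is dense; the following numpy trace (an analogy, using numpy in place of the Tensor ops) shows why the slice/expand/reshape dance puts the 1s on the diagonal.

```python
# numpy trace of the eye() construction (illustrative; the real code uses Tensor ops).
import numpy as np
dim = 3

row   = np.pad(np.array([1.0]), (0, dim))          # slice ((0, dim+1),): [1, 0, 0, 0]
tiled = np.broadcast_to(row.reshape(1, dim + 1), (dim, dim + 1))   # expand(dim, dim+1)
flat  = tiled.reshape(dim * (dim + 1))             # a 1 every dim+1 entries
eye   = flat[:dim * dim].reshape(dim, dim)         # slice ((0, dim*dim),) then reshape

assert np.array_equal(eye, np.eye(dim))
```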
@@ -195,10 +194,17 @@ class Tensor:
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=argfix(axis, *args))
-  def slice(self, arg) -> Tensor: return mlops.Slice.apply(self, arg=tuple(a if a is not None else (0,s) for s,a in zip(self.shape, arg)))
+  def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any(x != (0,0) for x in arg) else self
+  def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
 
   # ***** movement hlops *****
 
+  # NOTE: using slice is discouraged and things should migrate to pad and shrink
+  def slice(self, arg:Sequence[Optional[Tuple[int, int]]]) -> Tensor:
+    arg_ = tuple(a if a is not None else (0,s) for s,a in zip(self.shape, arg))
+    padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_))
+    return self.pad(padding).shrink(tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)))
+
   # Tensors mostly follow the normal python indexing / slicing behavior for sequences
   # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
   # - A slice i:j returns the elements with indices in [i, j)
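A quick usage sketch of the three methods as defined above (the shapes are example values; `Tensor.zeros` is the constructor shown earlier in this diff):

```python
from tinygrad.tensor import Tensor

t = Tensor.zeros(2, 3)                 # shape (2, 3)
p = t.pad(((1, 1), (0, 2)))            # mlops Pad:    shape (4, 5)
s = p.shrink(((1, 3), (0, 3)))         # mlops Shrink: shape (2, 3) again
v = t.slice(((-1, 3), (0, 5)))         # hlop slice = pad then shrink: shape (4, 5)
print(p.shape, s.shape, v.shape)
```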