mirror of https://github.com/commaai/tinygrad.git
add flip
This commit is contained in:
parent
a8aeebfb0c
commit
e057ca23bb
|
@ -116,13 +116,13 @@ hlops are syntactic sugar around mlops. They support most things torch does.
|
|||
|
||||
### mlops
|
||||
|
||||
mlops are mid level ops, there's 13 of them. They understand memory allocation and derivatives
|
||||
mlops are mid level ops, there's 15 of them. They understand memory allocation and derivatives
|
||||
|
||||
```
|
||||
Relu, Log, Exp # unary ops
|
||||
Sum, Max # reduce ops (with axis argument)
|
||||
Add, Sub, Mul, Pow # binary ops (no broadcasting, use expand)
|
||||
Reshape, Permute, Slice, Expand # movement ops
|
||||
Reshape, Permute, Slice, Expand, Flip # movement ops
|
||||
Conv2D(NCHW) # processing op (Matmul is also Conv2D)
|
||||
```
|
||||
|
||||
|
|
|
@ -161,6 +161,12 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)))
|
||||
|
||||
def test_flip(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1)), lambda x: x.flip(axis=(0,1)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)), lambda x: x.flip(axis=(0,1,3)))
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)))
|
||||
|
||||
def test_flatten(self):
|
||||
for axis in range(3):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(axis))
|
||||
|
@ -196,7 +202,8 @@ class TestOps(unittest.TestCase):
|
|||
W = 2
|
||||
helper_test_op([(bs,cin,64,64), (6,cin//groups,H,W)],
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
|
||||
# needed to relax tolerance on NVIDIA
|
||||
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-3, grad_rtol=1e-5)
|
||||
|
||||
def test_simple_grouped_conv2d(self):
|
||||
bs = 1
|
||||
|
|
|
@ -8,6 +8,7 @@ class CPUBuffer(np.ndarray):
|
|||
def exp(x): return np.exp(x)
|
||||
def log(x): return np.log(x)
|
||||
def sign(x): return np.sign(x)
|
||||
def flip(x, axis): return np.flip(x, axis)
|
||||
def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
|
||||
def permute(x, order): return x.transpose(order)
|
||||
def custompad(x, padding): return np.pad(x, padding)
|
||||
|
@ -50,6 +51,7 @@ def reduce_op(op, inp, ret):
|
|||
def movement_op(op, x, ret, arg=None):
|
||||
if op == MovementOps.RESHAPE: ret[:] = x.reshape(arg)
|
||||
elif op == MovementOps.PERMUTE: ret[:] = x.permute(arg)
|
||||
elif op == MovementOps.FLIP: ret[:] = x.flip(arg)
|
||||
elif op == MovementOps.SLICE:
|
||||
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
|
||||
x = x.custompad(padding)
|
||||
|
|
|
@ -7,7 +7,7 @@ class TorchBuffer(torch.Tensor):
|
|||
if isinstance(shape, torch.Tensor):
|
||||
return super().__new__(cls, shape)
|
||||
else:
|
||||
return TorchBuffer(torch.zeros(shape))
|
||||
return TorchBuffer(torch.zeros(shape)).to(device)
|
||||
custompad = lambda x,padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
|
||||
@staticmethod
|
||||
def fromCPU(data):
|
||||
|
|
|
@ -125,6 +125,15 @@ class Expand(Function):
|
|||
in_shape, = ctx.saved_tensors
|
||||
return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape)
|
||||
|
||||
class Flip(Function):
|
||||
def forward(ctx, x, axis):
|
||||
ctx.save_for_backward(axis)
|
||||
return ctx.movement_op(MovementOps.FLIP, x, axis)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
axis, = ctx.saved_tensors
|
||||
return ctx.movement_op(MovementOps.FLIP, grad_output, axis)
|
||||
|
||||
class Reshape(Function):
|
||||
def forward(ctx, x, shape):
|
||||
ctx.save_for_backward(x.shape)
|
||||
|
|
Loading…
Reference in New Issue