George Hotz 2022-06-14 17:28:43 -07:00
parent a8aeebfb0c
commit e057ca23bb
5 changed files with 26 additions and 8 deletions

@@ -116,13 +116,13 @@ hlops are syntactic sugar around mlops. They support most things torch does.
 ### mlops
-mlops are mid level ops, there's 13 of them. They understand memory allocation and derivatives
+mlops are mid level ops, there's 15 of them. They understand memory allocation and derivatives
 ```
 Relu, Log, Exp # unary ops
 Sum, Max # reduce ops (with axis argument)
 Add, Sub, Mul, Pow # binary ops (no broadcasting, use expand)
-Reshape, Permute, Slice, Expand # movement ops
+Reshape, Permute, Slice, Expand, Flip # movement ops
 Conv2D(NCHW) # processing op (Matmul is also Conv2D)
 ```
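The README line above notes that a matmul is just a Conv2D. As an outside sanity check (plain PyTorch, not tinygrad's mlops), the sketch below expresses an (M,K) x (K,N) matmul as a 1x1 convolution in NCHW layout; the shapes and names are illustrative only.

```
# Illustrative only: matmul written as a 1x1 Conv2D, the identity the README
# alludes to. Uses PyTorch as the reference, not tinygrad's processing op.
import torch

A = torch.randn(8, 16)   # (M, K)
B = torch.randn(16, 32)  # (K, N)

# Treat A as 8 "images" with 16 channels and 1x1 spatial extent,
# and B.T as 32 output channels, each a (16, 1, 1) kernel.
x = A.reshape(8, 16, 1, 1)          # NCHW input
w = B.t().reshape(32, 16, 1, 1)     # (out_ch, in_ch, kH, kW) weights
out = torch.nn.functional.conv2d(x, w).reshape(8, 32)

assert torch.allclose(out, A @ B, atol=1e-5)
```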

@@ -161,6 +161,12 @@ class TestOps(unittest.TestCase):
     helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
     helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)))
+  def test_flip(self):
+    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,)))
+    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1)), lambda x: x.flip(axis=(0,1)))
+    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)), lambda x: x.flip(axis=(0,1,3)))
+    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)))
   def test_flatten(self):
     for axis in range(3):
       helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(axis))
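The torch side of the new test_flip cases simply reverses the tensor along the listed axes; helper_test_op (not shown in this diff) compares that reference against tinygrad's result. A standalone illustration of the reference behaviour in plain PyTorch:

```
# What the torch reference in test_flip computes, on a tiny tensor.
import torch

x = torch.arange(6).reshape(2, 3)   # [[0,1,2],[3,4,5]]
print(torch.flip(x, (0,)))          # rows reversed:    [[3,4,5],[0,1,2]]
print(torch.flip(x, (1,)))          # columns reversed: [[2,1,0],[5,4,3]]
print(torch.flip(x, (0, 1)))        # both reversed:    [[5,4,3],[2,1,0]]
```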
@@ -196,7 +202,8 @@ class TestOps(unittest.TestCase):
     W = 2
     helper_test_op([(bs,cin,64,64), (6,cin//groups,H,W)],
       lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
+      # needed to relax tolerance on NVIDIA
+      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-3, grad_rtol=1e-5)
   def test_simple_grouped_conv2d(self):
     bs = 1

@@ -8,6 +8,7 @@ class CPUBuffer(np.ndarray):
   def exp(x): return np.exp(x)
   def log(x): return np.log(x)
   def sign(x): return np.sign(x)
+  def flip(x, axis): return np.flip(x, axis)
   def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
   def permute(x, order): return x.transpose(order)
   def custompad(x, padding): return np.pad(x, padding)
@@ -50,6 +51,7 @@ def reduce_op(op, inp, ret):
 def movement_op(op, x, ret, arg=None):
   if op == MovementOps.RESHAPE: ret[:] = x.reshape(arg)
   elif op == MovementOps.PERMUTE: ret[:] = x.permute(arg)
+  elif op == MovementOps.FLIP: ret[:] = x.flip(arg)
   elif op == MovementOps.SLICE:
     padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
     x = x.custompad(padding)
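On the CPU backend the new MovementOps.FLIP case bottoms out in np.flip. A quick NumPy-only sanity check of that primitive, including the involution property (flipping twice is the identity) that the Flip mlop's backward pass further down relies on:

```
# Plain NumPy check of the primitive behind MovementOps.FLIP on the CPU backend.
import numpy as np

x = np.arange(12).reshape(3, 4)
axis = (0, 1)

flipped = np.flip(x, axis)           # reverse along both axes
assert flipped[0, 0] == x[-1, -1]    # opposite corners swap

# flip is its own inverse, which is why the backward pass can reuse the same op
assert np.array_equal(np.flip(flipped, axis), x)
```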

@@ -7,7 +7,7 @@ class TorchBuffer(torch.Tensor):
     if isinstance(shape, torch.Tensor):
       return super().__new__(cls, shape)
     else:
-      return TorchBuffer(torch.zeros(shape))
+      return TorchBuffer(torch.zeros(shape)).to(device)
   custompad = lambda x,padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
   @staticmethod
   def fromCPU(data):
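The `.to(device)` change only ensures that freshly allocated zero buffers end up on the backend's device instead of the CPU. A minimal sketch of that pattern in plain PyTorch (the device string below is a stand-in, not TorchBuffer's actual device handling):

```
# Pattern illustrated by the TorchBuffer change: allocate, then move to the
# target device. "cuda" is a placeholder; fall back to CPU if unavailable.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
buf = torch.zeros((4, 3, 6, 6)).to(device)
print(buf.device)  # cuda:0 on a GPU machine, otherwise cpu
```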

@@ -125,6 +125,15 @@ class Expand(Function):
     in_shape, = ctx.saved_tensors
     return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape)
+class Flip(Function):
+  def forward(ctx, x, axis):
+    ctx.save_for_backward(axis)
+    return ctx.movement_op(MovementOps.FLIP, x, axis)
+  def backward(ctx, grad_output):
+    axis, = ctx.saved_tensors
+    return ctx.movement_op(MovementOps.FLIP, grad_output, axis)
 class Reshape(Function):
   def forward(ctx, x, shape):
     ctx.save_for_backward(x.shape)
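Flip.backward just flips the incoming gradient along the same axes. A numerical spot check of that identity using PyTorch autograd as an outside reference (illustrative; it does not touch tinygrad's Function/ctx machinery):

```
# Checks the identity Flip.backward encodes: the gradient of a flip is the
# flip of the gradient. PyTorch autograd is used as an independent reference.
import torch

x = torch.randn(2, 3, 4, requires_grad=True)
axis = (0, 2)

y = torch.flip(x, axis)
loss = (y ** 2).sum()          # any scalar function of the flipped tensor
loss.backward()

# dloss/dy is 2*y; pulling it back through flip is just flipping it again.
expected = torch.flip(2 * y, axis).detach()
assert torch.allclose(x.grad, expected)
```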