diff --git a/README.md b/README.md
index bd48493b..a8da3260 100644
--- a/README.md
+++ b/README.md
@@ -116,14 +116,14 @@ hlops are syntactic sugar around mlops. They support most things torch does.
 
 ### mlops
 
-mlops are mid level ops, there's 13 of them. They understand memory allocation and derivatives
+mlops are mid level ops, there's 15 of them. They understand memory allocation and derivatives
 
 ```
-Relu, Log, Exp                  # unary ops
-Sum, Max                        # reduce ops (with axis argument)
-Add, Sub, Mul, Pow              # binary ops (no broadcasting, use expand)
-Reshape, Permute, Slice, Expand # movement ops
-Conv2D(NCHW)                    # processing op (Matmul is also Conv2D)
+Relu, Log, Exp                        # unary ops
+Sum, Max                              # reduce ops (with axis argument)
+Add, Sub, Mul, Pow                    # binary ops (no broadcasting, use expand)
+Reshape, Permute, Slice, Expand, Flip # movement ops
+Conv2D(NCHW)                          # processing op (Matmul is also Conv2D)
 ```
 
 You no longer need to write mlops for a new accelerator
diff --git a/test/test_ops.py b/test/test_ops.py
index e840b2e8..1c5e1e18 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -161,6 +161,12 @@ class TestOps(unittest.TestCase):
     helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))
     helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6)))
 
+  def test_flip(self):
+    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,)))
+    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1)), lambda x: x.flip(axis=(0,1)))
+    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)), lambda x: x.flip(axis=(0,1,3)))
+    helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,)))
+
   def test_flatten(self):
     for axis in range(3):
       helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(axis))
@@ -196,7 +202,8 @@ class TestOps(unittest.TestCase):
     W = 2
     helper_test_op([(bs,cin,64,64), (6,cin//groups,H,W)],
       lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
+      # needed to relax tolerance on NVIDIA
+      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-3, grad_rtol=1e-5)
 
   def test_simple_grouped_conv2d(self):
     bs = 1
diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py
index 7c20b078..95222461 100644
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -8,6 +8,7 @@ class CPUBuffer(np.ndarray):
   def exp(x): return np.exp(x)
   def log(x): return np.log(x)
   def sign(x): return np.sign(x)
+  def flip(x, axis): return np.flip(x, axis)
   def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)
   def permute(x, order): return x.transpose(order)
   def custompad(x, padding): return np.pad(x, padding)
@@ -50,6 +51,7 @@ def reduce_op(op, inp, ret):
 def movement_op(op, x, ret, arg=None):
   if op == MovementOps.RESHAPE: ret[:] = x.reshape(arg)
   elif op == MovementOps.PERMUTE: ret[:] = x.permute(arg)
+  elif op == MovementOps.FLIP: ret[:] = x.flip(arg)
   elif op == MovementOps.SLICE:
     padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
     x = x.custompad(padding)
diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py
index 6d3e19e1..16870c72 100644
--- a/tinygrad/llops/ops_torch.py
+++ b/tinygrad/llops/ops_torch.py
@@ -7,7 +7,7 @@ class TorchBuffer(torch.Tensor):
     if isinstance(shape, torch.Tensor):
       return super().__new__(cls, shape)
     else:
-      return TorchBuffer(torch.zeros(shape))
+      return TorchBuffer(torch.zeros(shape)).to(device)
   custompad = lambda x,padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
   @staticmethod
   def fromCPU(data):
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 17993c8d..457fd05e 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -125,6 +125,15 @@ class Expand(Function):
     in_shape, = ctx.saved_tensors
     return ctx.reduce_op(ReduceOps.SUM, grad_output, in_shape)
 
+class Flip(Function):
+  def forward(ctx, x, axis):
+    ctx.save_for_backward(axis)
+    return ctx.movement_op(MovementOps.FLIP, x, axis)
+
+  def backward(ctx, grad_output):
+    axis, = ctx.saved_tensors
+    return ctx.movement_op(MovementOps.FLIP, grad_output, axis)
+
 class Reshape(Function):
   def forward(ctx, x, shape):
     ctx.save_for_backward(x.shape)
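
Note on the `Flip` mlop added above: because a flip only permutes elements and is its own inverse, `Flip.backward` can simply re-apply `MovementOps.FLIP` to `grad_output` along the same axes. A minimal standalone sketch of that identity, checked against `torch.autograd` (the `flip_backward` helper here is illustrative, not part of the diff):

```python
# Sketch only: verifies that the gradient of y = flip(x, axis) is just
# flip(dL/dy, axis), which is the rule Flip.backward in tinygrad/mlops.py uses.
import numpy as np
import torch

def flip_backward(grad_output, axis):
  # mirror of Flip.backward: route each element's gradient back by flipping again
  return np.flip(grad_output, axis)

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4).requires_grad_()
axis = (0, 2)
y = torch.flip(x, axis)
gy = torch.randn_like(y)      # arbitrary upstream gradient
y.backward(gy)                # reference gradient from torch.autograd

assert np.allclose(x.grad.numpy(), flip_backward(gy.numpy(), axis))
print("flip backward matches torch.autograd")
```

Since `helper_test_op` in the tests above also compares gradients against torch (see the `grad_rtol` argument), `test_flip` exercises this backward path as well.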