From c2eeb6950be748d681420d43a4c1ecc06038b93a Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 3 Jan 2021 08:29:57 -0800
Subject: [PATCH] add support for sign. technically relu can be second class now

---
 README.md           | 6 +++---
 test/test_ops.py    | 6 ++++++
 tinygrad/ops_cpu.py | 9 +++++++++
 tinygrad/ops_gpu.py | 9 +++++++++
 tinygrad/tensor.py  | 8 +++++++-
 5 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index e7b333c5..3a847afd 100644
--- a/README.md
+++ b/README.md
@@ -105,17 +105,17 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
 
 ### Adding an accelerator
 
-You need to support 14 first class ops:
+You need to support 15 first class ops:
 
 ```
-Relu, Log, Exp               # unary ops
+Relu, Log, Exp, Sign         # unary ops
 Sum, Max                     # reduce ops (with axis argument)
 Add, Sub, Mul, Pow           # binary ops (with broadcasting)
 Reshape, Transpose, Slice    # movement ops
 Matmul, Conv2D               # processing ops
 ```
 
-While more ops may be added (like Sign), I think these base 14 are stable.
+While more ops may be added, I think these base 15 are stable.
 
 ## ImageNet inference
 
diff --git a/test/test_ops.py b/test/test_ops.py
index 68eb82e0..69d4daf6 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -62,10 +62,16 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
   def test_exp(self):
     helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp)
+  def test_sign(self):
+    helper_test_op([(45,65)], lambda x: torch.sign(x), Tensor.sign)
   def test_sigmoid(self):
     helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
   def test_softplus(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
+  def test_relu6(self):
+    helper_test_op([(45,65)], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
+  def test_hardswish(self):
+    helper_test_op([(45,65)], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
   def test_mish(self):
     def _mish_pytorch(x):
       return x*torch.tanh(torch.nn.functional.softplus(x))
diff --git a/tinygrad/ops_cpu.py b/tinygrad/ops_cpu.py
index 879d225a..2479434b 100644
--- a/tinygrad/ops_cpu.py
+++ b/tinygrad/ops_cpu.py
@@ -37,6 +37,15 @@ class Exp(Function):
     ret, = ctx.saved_tensors
     return grad_output * ret
 
+class Sign(Function):
+  @staticmethod
+  def forward(ctx, input):
+    return np.sign(input)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    return grad_output * 0
+
 # ************* reduce ops *************
 
 class Sum(Function):
diff --git a/tinygrad/ops_gpu.py b/tinygrad/ops_gpu.py
index 2f757c8f..c3b165b4 100644
--- a/tinygrad/ops_gpu.py
+++ b/tinygrad/ops_gpu.py
@@ -64,6 +64,15 @@ class Exp(Function):
     ret, = ctx.saved_tensors
     return binary_op(ctx, 'a * b', grad_output, ret)
 
+class Sign(Function):
+  @staticmethod
+  def forward(ctx, input):
+    return unary_op(ctx, 'sign(a)', input)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    return unary_op(ctx, 'a * 0', grad_output)
+
 # ************* reduce ops *************
 
 def reduce_op(ctx, code, code2, inp, axis=None):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 82272e75..332dd2dc 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -222,6 +222,7 @@ class Tensor:
 
   def div(self, y):
     return self * (y ** -1.0)
+  __truediv__ = div
 
   def sigmoid(self):
     e = self.exp()
@@ -230,6 +231,12 @@ class Tensor:
   def swish(self):
     return self * self.sigmoid()
 
+  def relu6(self):
+    return self.relu() * (6-self).sign()
+
+  def hardswish(self):
+    return self * (self+3).relu6()/6
+
   def tanh(self):
     return 2.0 * ((2.0 * self).sigmoid()) - 1.0
 
@@ -314,7 +321,6 @@ def register(name, fxn, device=Device.CPU):
     f.cl_ctx, f.cl_queue, f.ane, f.device = cl_ctx, cl_queue, ane, tt.device
     return f.apply(f, *x, **kwargs)
   setattr(Tensor, name, dispatch)
-  # TODO: div is a second class op, so it doesn't work here
   if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
     setattr(Tensor, f"__{name}__", dispatch)
     setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))
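For reference, a minimal standalone NumPy sketch of the compositions added to tensor.py above: relu6 as relu(x) * sign(6 - x) and hardswish as x * relu6(x + 3) / 6. The helpers below are plain NumPy illustrations, not tinygrad APIs; note the sign-based relu6 agrees with the usual relu6 only for inputs below 6.

# Plain NumPy sketch of the relu6/hardswish compositions from tensor.py above.
# Standalone illustration, not tinygrad code.
import numpy as np

def relu(x):
  return np.maximum(x, 0)

def relu6(x):
  # relu(x) * sign(6 - x): agrees with the usual relu6 for x < 6
  return relu(x) * np.sign(6 - x)

def hardswish(x):
  # x * relu6(x + 3) / 6
  return x * relu6(x + 3) / 6

x = np.linspace(-4, 2, 7)
print(relu6(x))      # 0 for x <= 0, then x; matches torch.nn.functional.relu6 on this range
print(hardswish(x))  # matches torch.nn.functional.hardswish on this range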