add support for sign. technically relu can be second class now

George Hotz 2021-01-03 08:29:57 -08:00
parent 6842ad9ec8
commit c2eeb6950b
5 changed files with 34 additions and 4 deletions
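
The commit message's "relu can be second class" remark follows from the fact that, once sign is a first class op, relu is expressible as a composition of existing ops. A quick numpy sketch of that identity (an illustration only, not code from this commit):

```python
import numpy as np

# relu(x) == x * (sign(x) + 1) / 2 under np.sign's -1/0/+1 convention,
# and simply x * sign(x) under the GPU backend's 'a >= 0' (0/1) convention below.
x = np.array([-3.0, 0.0, 5.0], dtype=np.float32)
relu = np.maximum(x, 0)
assert np.allclose(relu, x * (np.sign(x) + 1) / 2)
assert np.allclose(relu, x * (x >= 0))
```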

View File

@@ -105,17 +105,17 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
### Adding an accelerator
You need to support 14 first class ops:
You need to support 15 first class ops:
```
Relu, Log, Exp # unary ops
Relu, Log, Exp, Sign # unary ops
Sum, Max # reduce ops (with axis argument)
Add, Sub, Mul, Pow # binary ops (with broadcasting)
Reshape, Transpose, Slice # movement ops
Matmul, Conv2D # processing ops
```
While more ops may be added (like Sign), I think these base 14 are stable.
While more ops may be added, I think these base 15 are stable.
## ImageNet inference
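
For a sense of what supporting one of the first class ops listed above means in practice, here is a rough, self-contained sketch in the style of the numpy backend further down in this diff. `Ctx` is only a stand-in for tinygrad's autograd context (which exposes `saved_tensors`, as the `Exp.backward` lines below show), and the class is illustrative rather than the actual implementation:

```python
import numpy as np

class Ctx:  # stand-in for the real autograd context object
  def __init__(self): self.saved_tensors = []
  def save_for_backward(self, *x): self.saved_tensors.extend(x)

class Relu:  # shape of a first class unary op for a backend
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)
    return np.maximum(input, 0)
  @staticmethod
  def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    return grad_output * (input >= 0)

ctx = Ctx()
out = Relu.forward(ctx, np.array([-1.0, 2.0], dtype=np.float32))   # [0. 2.]
grad = Relu.backward(ctx, np.ones(2, dtype=np.float32))            # [0. 1.]
```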

View File

@@ -62,10 +62,16 @@ class TestOps(unittest.TestCase):
    helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
  def test_exp(self):
    helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp)
  def test_sign(self):
    helper_test_op([(45,65)], lambda x: torch.sign(x), Tensor.sign)
  def test_sigmoid(self):
    helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid)
  def test_softplus(self):
    helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
  def test_relu6(self):
    helper_test_op([(45,65)], lambda x: torch.nn.functional.relu6(x), Tensor.relu6)
  def test_hardswish(self):
    helper_test_op([(45,65)], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
  def test_mish(self):
    def _mish_pytorch(x):
      return x*torch.tanh(torch.nn.functional.softplus(x))

View File

@@ -37,6 +37,15 @@ class Exp(Function):
    ret, = ctx.saved_tensors
    return grad_output * ret
class Sign(Function):
  @staticmethod
  def forward(ctx, input):
    return np.sign(input)
  @staticmethod
  def backward(ctx, grad_output):
    return grad_output * 0
# ************* reduce ops *************
class Sum(Function):
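
As a quick illustration of the semantics above (not part of the diff): the forward maps each element to -1, 0, or +1, and since sign is piecewise constant, its derivative is zero wherever it is defined, which is why backward returns `grad_output * 0`.

```python
import numpy as np

x = np.array([-2.0, 0.0, 3.0], dtype=np.float32)
print(np.sign(x))        # [-1.  0.  1.]  -- the CPU forward
grad_output = np.ones_like(x)
print(grad_output * 0)   # [0. 0. 0.]     -- the gradient that flows back
```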

View File

@@ -64,6 +64,15 @@ class Exp(Function):
    ret, = ctx.saved_tensors
    return binary_op(ctx, 'a * b', grad_output, ret)
class Sign(Function):
  @staticmethod
  def forward(ctx, input):
    return unary_op(ctx, 'a >= 0', input)
  @staticmethod
  def backward(ctx, grad_output):
    return unary_op(ctx, 'a * 0', grad_output)
# ************* reduce ops *************
def reduce_op(ctx, code, code2, inp, axis=None):
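
Note that the two backends do not agree exactly: `'a >= 0'` evaluates to 1 for non-negative inputs and 0 otherwise, while `np.sign` also returns -1 for negatives. A small numpy sketch of the difference, assuming the generated kernel evaluates the expression element-wise and stores the comparison result as a float:

```python
import numpy as np

x = np.array([-2.0, 0.0, 3.0], dtype=np.float32)
cpu_sign = np.sign(x)                    # [-1.  0.  1.]  (np.sign path)
gpu_sign = (x >= 0).astype(np.float32)   # [ 0.  1.  1.]  ('a >= 0' path)
print(cpu_sign, gpu_sign)
```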

View File

@@ -222,6 +222,7 @@ class Tensor:
  def div(self, y):
    return self * (y ** -1.0)
  __truediv__ = div
  def sigmoid(self):
    e = self.exp()
@@ -230,6 +231,12 @@
  def swish(self):
    return self * self.sigmoid()
  def relu6(self):
    return self.relu() * (6-self).sign()
  def hardswish(self):
    return self * (self+3).relu6()/6
  def tanh(self):
    return 2.0 * ((2.0 * self).sigmoid()) - 1.0
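
A quick numpy check (illustrative, not from the diff) that these compositions match the usual definitions relu6(x) = min(max(x, 0), 6) and hardswish(x) = x * relu6(x + 3) / 6, at least on inputs below 6 where sign(6 - x) is 1:

```python
import numpy as np

x = np.linspace(-4.0, 4.0, 9, dtype=np.float32)
relu6_composed = np.maximum(x, 0) * np.sign(6 - x)        # relu(x) * sign(6 - x)
relu6_reference = np.minimum(np.maximum(x, 0), 6)
hardswish = x * np.minimum(np.maximum(x + 3, 0), 6) / 6   # x * relu6(x + 3) / 6
assert np.allclose(relu6_composed, relu6_reference)
print(hardswish)   # ~0 for x <= -3, ~x for x >= 3
```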
@@ -314,7 +321,6 @@ def register(name, fxn, device=Device.CPU):
    f.cl_ctx, f.cl_queue, f.ane, f.device = cl_ctx, cl_queue, ane, tt.device
    return f.apply(f, *x, **kwargs)
  setattr(Tensor, name, dispatch)
  # TODO: div is a second class op, so it doesn't work here
  if name in ['add', 'sub', 'mul', 'pow', 'matmul']:
    setattr(Tensor, f"__{name}__", dispatch)
    setattr(Tensor, f"__i{name}__", lambda self,x: self.assign(dispatch(self,x)))