cleaner relu

This commit is contained in:
George Hotz 2020-10-25 14:35:04 -07:00
parent 5c179d18ad
commit d38367b561
1 changed files with 1 additions and 2 deletions

View File

@ -45,8 +45,7 @@ class ReLU(Function):
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.copy()
grad_input[input < 0] = 0
grad_input = grad_output * (input >= 0)
return grad_input
register('relu', ReLU)