mirror of https://github.com/commaai/tinygrad.git
cleaner relu
This commit is contained in:
parent
5c179d18ad
commit
d38367b561
|
@ -45,8 +45,7 @@ class ReLU(Function):
|
|||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
grad_input = grad_output.copy()
|
||||
grad_input[input < 0] = 0
|
||||
grad_input = grad_output * (input >= 0)
|
||||
return grad_input
|
||||
register('relu', ReLU)
|
||||
|
||||
|
|
Loading…
Reference in New Issue