mirror of https://github.com/commaai/tinygrad.git
more readable and faster
This commit is contained in:
parent
d1b6f9822c
commit
11d0cfec77
|
@ -14,19 +14,19 @@ class UnaryOp(Function):
|
|||
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
return binary_op(ctx, ctx.bop, grad_output, input)
|
||||
return binary_op(ctx, ctx.bop, input, grad_output)
|
||||
|
||||
class ReLU(UnaryOp):
|
||||
fop = 'max(a, (float)0.)'
|
||||
bop = 'a * (b >= 0)'
|
||||
bop = 'b * (a >= 0)'
|
||||
|
||||
class Log(UnaryOp):
|
||||
fop = 'log(a)'
|
||||
bop = 'a / b'
|
||||
bop = 'b / a'
|
||||
|
||||
class Exp(UnaryOp):
|
||||
fop = 'exp(a)'
|
||||
bop = 'a * exp(b)'
|
||||
bop = 'b * exp(a)'
|
||||
|
||||
# ************* reduce ops *************
|
||||
|
||||
|
@ -37,7 +37,8 @@ class Sum(Function):
|
|||
|
||||
def backward(ctx, grad_output):
|
||||
shape_input, = ctx.saved_tensors
|
||||
return binary_op(ctx, 'a+b', grad_output, buffer_new(ctx, shape_input, zero=True))
|
||||
# NOTE: the b buffer_new isn't used, since this is just for broadcast
|
||||
return binary_op(ctx, 'a', grad_output, buffer_new(ctx, shape_input))
|
||||
|
||||
class Max(Function):
|
||||
def forward(ctx, input, axis=None):
|
||||
|
|
Loading…
Reference in New Issue