more readable and faster

This commit is contained in:
George Hotz 2022-06-05 14:13:08 -07:00
parent d1b6f9822c
commit 11d0cfec77
1 changed file with 6 additions and 5 deletions

View File

@@ -14,19 +14,19 @@ class UnaryOp(Function):
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return binary_op(ctx, ctx.bop, grad_output, input)
return binary_op(ctx, ctx.bop, input, grad_output)
class ReLU(UnaryOp):
fop = 'max(a, (float)0.)'
bop = 'a * (b >= 0)'
bop = 'b * (a >= 0)'
class Log(UnaryOp):
fop = 'log(a)'
bop = 'a / b'
bop = 'b / a'
class Exp(UnaryOp):
fop = 'exp(a)'
bop = 'a * exp(b)'
bop = 'b * exp(a)'
# ************* reduce ops *************
@@ -37,7 +37,8 @@ class Sum(Function):
def backward(ctx, grad_output):
shape_input, = ctx.saved_tensors
return binary_op(ctx, 'a+b', grad_output, buffer_new(ctx, shape_input, zero=True))
# NOTE: the b buffer_new isn't used, since this is just for broadcast
return binary_op(ctx, 'a', grad_output, buffer_new(ctx, shape_input))
class Max(Function):
def forward(ctx, input, axis=None):