pow+div on GPU (#57)

This commit is contained in:
Ryan Neph 2020-11-05 07:49:45 -08:00 committed by GitHub
parent b16fadc5c6
commit 22a5f9975d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 17 additions and 1 deletions

View File

@ -85,6 +85,22 @@ class Mul(Function):
binary_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', x, grad_output)
register('mul', Mul, gpu=True)
class Pow(Function):
@staticmethod
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
return binary_op(ctx, 'res_g[gid] = pow(a_g[gid], b_g[gid]);', x, y)
@staticmethod
def backward(ctx, grad_output):
x,y = ctx.saved_tensors
gradx = binary_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', grad_output,
binary_op(ctx, 'res_g[gid] = b_g[gid] * (pow((float)a_g[gid], (float)(b_g[gid]-1.0)));', x, y))
grady = binary_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', grad_output,
binary_op(ctx, 'res_g[gid] = pow((float)a_g[gid], (float)b_g[gid]) * log(a_g[gid]);', x, y))
return gradx, grady
register('pow', Pow, gpu=True)
class Sum(Function):
@staticmethod
def forward(ctx, input):
@ -133,7 +149,7 @@ class Dot(Function):
{
int X = get_global_id(0); // isize
int Y = get_global_id(1); // osize
float ret = 0.0;
for (int x = 0; x < msize; x++) {
ret += input[X * is0 + x * is1] * weight[Y * ws0 + x * ws1];