mirror of https://github.com/commaai/tinygrad.git
pow+div on GPU (#57)
This commit is contained in:
parent
b16fadc5c6
commit
22a5f9975d
|
@ -85,6 +85,22 @@ class Mul(Function):
|
|||
binary_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', x, grad_output)
|
||||
register('mul', Mul, gpu=True)
|
||||
|
||||
class Pow(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x, y)
|
||||
return binary_op(ctx, 'res_g[gid] = pow(a_g[gid], b_g[gid]);', x, y)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x,y = ctx.saved_tensors
|
||||
gradx = binary_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', grad_output,
|
||||
binary_op(ctx, 'res_g[gid] = b_g[gid] * (pow((float)a_g[gid], (float)(b_g[gid]-1.0)));', x, y))
|
||||
grady = binary_op(ctx, 'res_g[gid] = a_g[gid] * b_g[gid];', grad_output,
|
||||
binary_op(ctx, 'res_g[gid] = pow((float)a_g[gid], (float)b_g[gid]) * log(a_g[gid]);', x, y))
|
||||
return gradx, grady
|
||||
register('pow', Pow, gpu=True)
|
||||
|
||||
class Sum(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
|
@ -133,7 +149,7 @@ class Dot(Function):
|
|||
{
|
||||
int X = get_global_id(0); // isize
|
||||
int Y = get_global_id(1); // osize
|
||||
|
||||
|
||||
float ret = 0.0;
|
||||
for (int x = 0; x < msize; x++) {
|
||||
ret += input[X * is0 + x * is1] * weight[Y * ws0 + x * ws1];
|
||||
|
|
Loading…
Reference in New Issue