move gradcheck to extra, clean up unbroadcast

George Hotz 2020-11-16 08:03:31 -08:00
parent ed4c35e2e9
commit 13d34373d1
4 changed files with 16 additions and 15 deletions

View File

@@ -1,7 +1,10 @@
import numpy as np
-from .utils import mask_like
-from .tensor import Tensor
+from tinygrad.tensor import Tensor
+def mask_like(like, mask_inx, mask_value = 1.0):
+  mask = np.zeros_like(like).reshape(-1)
+  mask[mask_inx] = mask_value
+  return mask.reshape(like.shape)
def jacobian(func, input):
  output = func(input)
View File

@@ -2,7 +2,7 @@ import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor
-from tinygrad.gradcheck import numerical_jacobian, jacobian, gradcheck
+from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
x_init = np.random.randn(1,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)
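
(Illustration only, not from the diff.) With the import now pointing at extra.gradcheck, a typical check compares the autograd Jacobian against a finite-difference one. The exact signatures, return types, and tolerances below are assumptions based on the names imported here, not on anything shown in this hunk.

import numpy as np
from tinygrad.tensor import Tensor
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck

x_init = np.random.randn(1,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)

W = Tensor(W_init)
def func(x):
  return x.dot(W).relu()   # small differentiable function of the input tensor

x = Tensor(x_init)
J  = jacobian(func, x)            # analytic Jacobian, via backward()
NJ = numerical_jacobian(func, x)  # finite-difference Jacobian (assumed same arguments)
print(np.allclose(J, NJ, atol=1e-5, rtol=1e-3))   # assumed: both return numpy arrays
print(gradcheck(func, x))                         # assumed: wraps the same comparison, returns a bool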

View File

@@ -4,8 +4,10 @@ import numpy as np
from .tensor import Function, register
# ************* basic ops *************
-def adBC(out, in_sh): #adjoint operation to broadcast is sum. Need to sum all axis with 1 = in_sh[i] < out.shape[i]
-  return out.sum(axis=tuple([i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1])).reshape(in_sh)
+def unbroadcast(out, in_sh):
+  # adjoint operation to broadcast is sum. Need to sum all axis with 1 = in_sh[i] < out.shape[i]
+  sum_axis = [i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1]
+  return out.sum(axis=tuple(sum_axis)).reshape(in_sh)
class Add(Function):
  @staticmethod
@@ -16,7 +18,7 @@ class Add(Function):
  @staticmethod
  def backward(ctx, grad_output):
    shape_x, shape_y = ctx.saved_tensors
-    return adBC(grad_output, shape_x), adBC(grad_output, shape_y)
+    return unbroadcast(grad_output, shape_x), unbroadcast(grad_output, shape_y)
register('add', Add)
class Sub(Function):
@@ -28,7 +30,7 @@
  @staticmethod
  def backward(ctx, grad_output):
    shape_x, shape_y = ctx.saved_tensors
-    return adBC(grad_output, shape_x), adBC( -grad_output, shape_y)
+    return unbroadcast(grad_output, shape_x), unbroadcast(-grad_output, shape_y)
register('sub', Sub)
class Mul(Function):
@@ -40,7 +42,7 @@
  @staticmethod
  def backward(ctx, grad_output):
    x,y = ctx.saved_tensors
-    return adBC(y*grad_output, x.shape), adBC(x*grad_output, y.shape)
+    return unbroadcast(y*grad_output, x.shape), unbroadcast(x*grad_output, y.shape)
register('mul', Mul)
class Div(Function):
@@ -52,7 +54,7 @@
  @staticmethod
  def backward(ctx, grad_output):
    x,y = ctx.saved_tensors
-    return adBC(grad_output / y, x.shape), adBC(-x * grad_output / y**2, y.shape)
+    return unbroadcast(grad_output / y, x.shape), unbroadcast(-x * grad_output / y**2, y.shape)
# TODO: registering this breaks the default div on the GPU
#register('div', Div)
@@ -65,7 +67,8 @@
  @staticmethod
  def backward(ctx, grad_output):
    x,y = ctx.saved_tensors
-    return adBC(y * (x**(y-1.0)) * grad_output,x.shape), adBC((x**y) * np.log(x) * grad_output,y.shape)
+    return unbroadcast(y * (x**(y-1.0)) * grad_output, x.shape), \
+      unbroadcast((x**y) * np.log(x) * grad_output, y.shape)
register('pow', Pow)
class Sum(Function):
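
(Illustration only.) A numpy-only sketch of what unbroadcast does in these backward passes: a gradient produced at the broadcast output shape is summed back down to the operand's original shape. The shapes below are made up for the example.

import numpy as np

def unbroadcast(out, in_sh):
  # sum over every axis that was expanded from size 1 by broadcasting
  sum_axis = [i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1]
  return out.sum(axis=tuple(sum_axis)).reshape(in_sh)

grad_output = np.ones((4, 3), dtype=np.float32)  # gradient at the broadcast (output) shape
bias_shape = (1, 3)                              # operand that was broadcast over the batch axis
print(unbroadcast(grad_output, bias_shape))      # [[4. 4. 4.]]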

View File

@@ -1,10 +1,5 @@
import numpy as np
-def mask_like(like, mask_inx, mask_value = 1.0):
-  mask = np.zeros_like(like).reshape(-1)
-  mask[mask_inx] = mask_value
-  return mask.reshape(like.shape)
def layer_init_uniform(*x):
  ret = np.random.uniform(-1., 1., size=x)/np.sqrt(np.prod(x))
  return ret.astype(np.float32)
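
(Illustration only.) layer_init_uniform, left behind in this file, draws weights uniformly from [-1, 1) and divides by the square root of the total number of entries, keeping the initial weights small. The layer size below is just an example.

import numpy as np

def layer_init_uniform(*x):
  ret = np.random.uniform(-1., 1., size=x)/np.sqrt(np.prod(x))
  return ret.astype(np.float32)

W1 = layer_init_uniform(784, 128)                 # hypothetical hidden-layer shape
print(W1.shape, W1.dtype)                         # (784, 128) float32
print(np.abs(W1).max() <= 1.0/np.sqrt(784*128))   # True: each entry is bounded by 1/sqrt(n_entries)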