14 ops to write for GPU

This commit is contained in:
George Hotz 2020-10-31 10:59:30 -07:00
parent 06928cf3cc
commit e01e35e545
3 changed files with 35 additions and 28 deletions

View File

@ -88,8 +88,10 @@ python -m pytest
### TODO ### TODO
* Train an EfficientNet
* EfficientNet backward pass
* Tensors on GPU (GPU support, must support Mac)
* Reduce code * Reduce code
* Increase speed * Increase speed
* Add features * Add features
* In that order

View File

@ -1,5 +1,4 @@
# TODO: implement BatchNorm2d and Swish # load weights from
# aka batch_norm, pad, swish, dropout
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth # https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
# a rough copy of # a rough copy of
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py

View File

@ -48,6 +48,21 @@ class Pow(Function):
return y * (x**(y-1.0)) * grad_output, (x**y) * np.log(x) * grad_output return y * (x**(y-1.0)) * grad_output, (x**y) * np.log(x) * grad_output
register('pow', Pow) register('pow', Pow)
class Sum(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return np.array([input.sum()])
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output * np.ones_like(input)
register('sum', Sum)
# ************* GEMM *************
class Dot(Function): class Dot(Function):
@staticmethod @staticmethod
def forward(ctx, input, weight): def forward(ctx, input, weight):
@ -63,20 +78,8 @@ class Dot(Function):
register('dot', Dot) register('dot', Dot)
register('matmul', Dot) register('matmul', Dot)
class Sum(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return np.array([input.sum()])
@staticmethod # ************* simple ops *************
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output * np.ones_like(input)
register('sum', Sum)
# ************* nn ops *************
class Pad2D(Function): class Pad2D(Function):
@staticmethod @staticmethod
@ -88,6 +91,21 @@ class Pad2D(Function):
raise Exception("write this") raise Exception("write this")
register('pad2d', Pad2D) register('pad2d', Pad2D)
class Reshape(Function):
@staticmethod
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
return x.reshape(shape)
@staticmethod
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return grad_output.reshape(in_shape)
register('reshape', Reshape)
# ************* activation ops *************
class ReLU(Function): class ReLU(Function):
@staticmethod @staticmethod
def forward(ctx, input): def forward(ctx, input):
@ -118,18 +136,6 @@ class Sigmoid(Function):
return grad_input return grad_input
register('sigmoid', Sigmoid) register('sigmoid', Sigmoid)
class Reshape(Function):
@staticmethod
def forward(ctx, x, shape):
ctx.save_for_backward(x.shape)
return x.reshape(shape)
@staticmethod
def backward(ctx, grad_output):
in_shape, = ctx.saved_tensors
return grad_output.reshape(in_shape)
register('reshape', Reshape)
class LogSoftmax(Function): class LogSoftmax(Function):
@staticmethod @staticmethod
def forward(ctx, input): def forward(ctx, input):