mirror of https://github.com/commaai/tinygrad.git
14 ops to write for GPU
This commit is contained in:
parent
06928cf3cc
commit
e01e35e545
|
@ -88,8 +88,10 @@ python -m pytest
|
|||
|
||||
### TODO
|
||||
|
||||
* Train an EfficientNet
|
||||
* EfficientNet backward pass
|
||||
* Tensors on GPU (GPU support, must support Mac)
|
||||
* Reduce code
|
||||
* Increase speed
|
||||
* Add features
|
||||
* In that order
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
# TODO: implement BatchNorm2d and Swish
|
||||
# aka batch_norm, pad, swish, dropout
|
||||
# load weights from
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
|
||||
# a rough copy of
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
||||
|
|
|
@ -48,6 +48,21 @@ class Pow(Function):
|
|||
return y * (x**(y-1.0)) * grad_output, (x**y) * np.log(x) * grad_output
|
||||
register('pow', Pow)
|
||||
|
||||
class Sum(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
ctx.save_for_backward(input)
|
||||
return np.array([input.sum()])
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
return grad_output * np.ones_like(input)
|
||||
register('sum', Sum)
|
||||
|
||||
|
||||
# ************* GEMM *************
|
||||
|
||||
class Dot(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight):
|
||||
|
@ -63,20 +78,8 @@ class Dot(Function):
|
|||
register('dot', Dot)
|
||||
register('matmul', Dot)
|
||||
|
||||
class Sum(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
ctx.save_for_backward(input)
|
||||
return np.array([input.sum()])
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
return grad_output * np.ones_like(input)
|
||||
register('sum', Sum)
|
||||
|
||||
|
||||
# ************* nn ops *************
|
||||
# ************* simple ops *************
|
||||
|
||||
class Pad2D(Function):
|
||||
@staticmethod
|
||||
|
@ -88,6 +91,21 @@ class Pad2D(Function):
|
|||
raise Exception("write this")
|
||||
register('pad2d', Pad2D)
|
||||
|
||||
class Reshape(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, shape):
|
||||
ctx.save_for_backward(x.shape)
|
||||
return x.reshape(shape)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
in_shape, = ctx.saved_tensors
|
||||
return grad_output.reshape(in_shape)
|
||||
register('reshape', Reshape)
|
||||
|
||||
|
||||
# ************* activation ops *************
|
||||
|
||||
class ReLU(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
|
@ -118,18 +136,6 @@ class Sigmoid(Function):
|
|||
return grad_input
|
||||
register('sigmoid', Sigmoid)
|
||||
|
||||
class Reshape(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, shape):
|
||||
ctx.save_for_backward(x.shape)
|
||||
return x.reshape(shape)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
in_shape, = ctx.saved_tensors
|
||||
return grad_output.reshape(in_shape)
|
||||
register('reshape', Reshape)
|
||||
|
||||
class LogSoftmax(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
|
|
Loading…
Reference in New Issue