mirror of https://github.com/commaai/tinygrad.git
14 ops to write for GPU
This commit is contained in:
parent
06928cf3cc
commit
e01e35e545
|
@ -88,8 +88,10 @@ python -m pytest
|
||||||
|
|
||||||
### TODO
|
### TODO
|
||||||
|
|
||||||
|
* Train an EfficientNet
|
||||||
|
* EfficientNet backward pass
|
||||||
|
* Tensors on GPU (GPU support, must support Mac)
|
||||||
* Reduce code
|
* Reduce code
|
||||||
* Increase speed
|
* Increase speed
|
||||||
* Add features
|
* Add features
|
||||||
* In that order
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
# TODO: implement BatchNorm2d and Swish
|
# load weights from
|
||||||
# aka batch_norm, pad, swish, dropout
|
|
||||||
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
|
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
|
||||||
# a rough copy of
|
# a rough copy of
|
||||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
||||||
|
|
|
@ -48,6 +48,21 @@ class Pow(Function):
|
||||||
return y * (x**(y-1.0)) * grad_output, (x**y) * np.log(x) * grad_output
|
return y * (x**(y-1.0)) * grad_output, (x**y) * np.log(x) * grad_output
|
||||||
register('pow', Pow)
|
register('pow', Pow)
|
||||||
|
|
||||||
|
class Sum(Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input):
|
||||||
|
ctx.save_for_backward(input)
|
||||||
|
return np.array([input.sum()])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
input, = ctx.saved_tensors
|
||||||
|
return grad_output * np.ones_like(input)
|
||||||
|
register('sum', Sum)
|
||||||
|
|
||||||
|
|
||||||
|
# ************* GEMM *************
|
||||||
|
|
||||||
class Dot(Function):
|
class Dot(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, input, weight):
|
def forward(ctx, input, weight):
|
||||||
|
@ -63,20 +78,8 @@ class Dot(Function):
|
||||||
register('dot', Dot)
|
register('dot', Dot)
|
||||||
register('matmul', Dot)
|
register('matmul', Dot)
|
||||||
|
|
||||||
class Sum(Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, input):
|
|
||||||
ctx.save_for_backward(input)
|
|
||||||
return np.array([input.sum()])
|
|
||||||
|
|
||||||
@staticmethod
|
# ************* simple ops *************
|
||||||
def backward(ctx, grad_output):
|
|
||||||
input, = ctx.saved_tensors
|
|
||||||
return grad_output * np.ones_like(input)
|
|
||||||
register('sum', Sum)
|
|
||||||
|
|
||||||
|
|
||||||
# ************* nn ops *************
|
|
||||||
|
|
||||||
class Pad2D(Function):
|
class Pad2D(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -88,6 +91,21 @@ class Pad2D(Function):
|
||||||
raise Exception("write this")
|
raise Exception("write this")
|
||||||
register('pad2d', Pad2D)
|
register('pad2d', Pad2D)
|
||||||
|
|
||||||
|
class Reshape(Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, shape):
|
||||||
|
ctx.save_for_backward(x.shape)
|
||||||
|
return x.reshape(shape)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
in_shape, = ctx.saved_tensors
|
||||||
|
return grad_output.reshape(in_shape)
|
||||||
|
register('reshape', Reshape)
|
||||||
|
|
||||||
|
|
||||||
|
# ************* activation ops *************
|
||||||
|
|
||||||
class ReLU(Function):
|
class ReLU(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, input):
|
def forward(ctx, input):
|
||||||
|
@ -118,18 +136,6 @@ class Sigmoid(Function):
|
||||||
return grad_input
|
return grad_input
|
||||||
register('sigmoid', Sigmoid)
|
register('sigmoid', Sigmoid)
|
||||||
|
|
||||||
class Reshape(Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, x, shape):
|
|
||||||
ctx.save_for_backward(x.shape)
|
|
||||||
return x.reshape(shape)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
in_shape, = ctx.saved_tensors
|
|
||||||
return grad_output.reshape(in_shape)
|
|
||||||
register('reshape', Reshape)
|
|
||||||
|
|
||||||
class LogSoftmax(Function):
|
class LogSoftmax(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, input):
|
def forward(ctx, input):
|
||||||
|
|
Loading…
Reference in New Issue