mirror of https://github.com/commaai/tinygrad.git
enet runs
This commit is contained in:
parent 9166eb58bb
commit e84ad3e27d
@@ -5,6 +5,9 @@
 # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
 from tinygrad.tensor import Tensor
 
+def swish(x):
+  return x.mul(x.sigmoid())
+
 class BatchNorm2D:
   def __init__(self, sz):
     self.weight = Tensor.zeros(sz)
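
Note: swish is the SiLU activation, x * sigmoid(x), used throughout EfficientNet. As an illustrative aside (numpy, not from the repo), the new helper computes:

import numpy as np

def swish_np(x):
  # x * sigmoid(x), matching x.mul(x.sigmoid()) above
  return x * (1.0 / (1.0 + np.exp(-x)))

print(swish_np(np.array([-1.0, 0.0, 1.0])))  # ~[-0.269  0.     0.731]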

@@ -13,7 +16,9 @@ class BatchNorm2D:
 
   def __call__(self, x):
     # this work at inference?
-    return x * self.weight + self.bias
+    x = x.mul(self.weight.reshape(shape=[1, -1, 1, 1]))
+    x = x.add(self.bias.reshape(shape=[1, -1, 1, 1]))
+    return x
 
 class MBConvBlock:
   def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio):
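
Note: the reshape to shape=[1, -1, 1, 1] is what lets a length-C weight and bias broadcast across an NCHW batch; the old "x * self.weight" had no way to line the channel axis up against a 4-D tensor. Running mean and variance are still unhandled, which is what the "this work at inference?" comment is asking about. An illustrative numpy sketch of the broadcast, with made-up sizes:

import numpy as np

x = np.ones((2, 3, 4, 4))            # NCHW activations
weight = np.array([1.0, 2.0, 3.0])   # one scale per channel
bias = np.array([0.0, 0.1, 0.2])     # one shift per channel
out = x * weight.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)
print(out[0, :, 0, 0])  # [1.  2.1 3.2]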

@@ -21,6 +26,8 @@ class MBConvBlock:
     if expand_ratio != 1:
       self._expand_conv = Tensor.zeros(oup, input_filters, 1, 1)
       self._bn0 = BatchNorm2D(oup)
+    else:
+      self._expand_conv = None
 
     self.pad = (kernel_size-1)//2
     self.strides = strides
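
Note: blocks with expand_ratio == 1 (EfficientNet-B0's first stage) have no expansion convolution, so _expand_conv is stored as None and skipped in __call__ below. For the other blocks the channel math goes roughly as follows (oup is computed on lines outside this hunk; the values here are the standard B0 ones, used only as an illustration):

input_filters, expand_ratio = 16, 6
oup = input_filters * expand_ratio   # 96 expanded channels
# 1x1 expansion weight shape: (oup, input_filters, 1, 1) == (96, 16, 1, 1)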

@@ -38,18 +45,20 @@ class MBConvBlock:
     self._bn2 = BatchNorm2D(output_filters)
 
   def __call__(self, x):
-    x = self._bn0(x.conv2d(self._expand_conv)).swish()
-    x = x.pad(self.pad, self.pad, self.pad, self.pad)
-    x = self._bn1(x.conv2d(self._depthwise_conv, stride=self.stride)).swish() # TODO: repeat on axis 1
+    if self._expand_conv:
+      x = swish(self._bn0(x.conv2d(self._expand_conv)))
+    x = x.pad2d(padding=(self.pad, self.pad, self.pad, self.pad))
+    x = x.conv2d(self._depthwise_conv, stride=self.strides, groups=self._depthwise_conv.shape[0])
+    x = swish(self._bn1(x))
 
     # has_se
-    x_squeezed = x.avg_pool2d()
-    x_squeezed = (x_squeezed.conv2d(self._se_reduce) + self._se_reduce_bias).swish()
-    x_squeezed = x_squeezed.conv2d(self._se_expand) + self._se_expand_bias
-    x = x * x_squeezed.sigmoid()
+    x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
+    x_squeezed = swish(x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1])))
+    x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1]))
+    x = x.mul(x_squeezed.sigmoid())
 
     x = self._bn2(x.conv2d(self._project_conv))
-    return x.swish()
+    return swish(x)
 
 class EfficientNet:
   def __init__(self):
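
Note: three things change in this hunk. The expansion stage is now guarded on _expand_conv; the depthwise convolution is expressed as conv2d with groups equal to its own channel count (the Conv2D change further down supplies that keyword); and the squeeze-and-excitation branch pools over the full HxW extent so it emits one value per channel. An illustrative numpy sketch of the SE gating, with made-up sizes:

import numpy as np

x = np.random.randn(1, 96, 56, 56)   # block activations (NCHW)
# after avg_pool2d plus the 1x1 reduce/expand convs, SE yields one gate per channel:
gate = 1.0 / (1.0 + np.exp(-np.random.randn(1, 96, 1, 1)))
y = x * gate                         # broadcasts over H and W, rescaling channels
assert y.shape == x.shape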

@@ -67,21 +76,27 @@ class EfficientNet:
     self._blocks = []
     # num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio
     for b in blocks_args:
+      args = b[1:]
       for n in range(b[0]):
-        self._blocks.append(MBConvBlock(*b[1:]))
+        self._blocks.append(MBConvBlock(*args))
+        args[3] = args[4]
+        args[1] = (1,1)
     self._conv_head = Tensor.zeros(1280, 320, 1, 1)
     self._bn1 = BatchNorm2D(1280)
     self._fc = Tensor.zeros(1280, 1000)
 
-  def forward(x):
-    x = self._bn0(x.pad(0,1,0,1).conv2d(self._conv_stem, stride=2))
+  def forward(self, x):
+    x = x.pad2d(padding=(0,1,0,1))
+    x = self._bn0(x.conv2d(self._conv_stem, stride=2))
     for b in self._blocks:
       x = b(x)
     x = self._bn1(x.conv2d(self._conv_head))
-    x = x.avg_pool2d() # wrong?
-    x = x.dropout(0.2)
-    return x.dot(self_fc).swish()
+    x = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, 1280))
+    #x = x.dropout(0.2)
+    return swish(x.dot(self._fc))
 
 if __name__ == "__main__":
   model = EfficientNet()
+  out = model.forward(Tensor.zeros(1, 3, 224, 224))
+  print(out)
 
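
Note: the in-place edits to args implement the repeat rule of blocks_args (per the row layout in the comment: index 3 of args is input_filters, 4 is output_filters, 1 is strides): after the first block of a stage, repeats consume the stage's own output channels and run at stride 1. A standalone sketch with one made-up row:

b = [2, 3, (2,2), 6, 16, 24, 0.25]   # num_repeats, then the MBConvBlock args
args = b[1:]
for n in range(b[0]):
  print(args[1], args[3], args[4])   # strides, input_filters, output_filters
  args[3] = args[4]                  # the next repeat starts from 24 channels
  args[1] = (1,1)                    # only the first block in a stage downsamples
# prints (2, 2) 16 24, then (1, 1) 24 24

The remaining hunks add the two ops the example now leans on: pad2d and grouped conv2d.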
@@ -55,6 +55,16 @@ register('sum', Sum)
 
 # ************* nn ops *************
 
+class Pad2D(Function):
+  @staticmethod
+  def forward(ctx, x, padding=None):
+    return np.pad(x, ((0,0), (0,0), (padding[0], padding[1]), (padding[2], padding[3])))
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    raise Exception("write this")
+register('pad2d', Pad2D)
+
 class ReLU(Function):
   @staticmethod
   def forward(ctx, input):
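
Note: the np.pad tuple maps padding[0:2] onto the height axis (rows before/after) and padding[2:4] onto the width axis (columns before/after), which is what the (0,1,0,1) stem padding above relies on; backward is left unimplemented since making the net run only needs forward. An illustrative check, not from the repo:

import numpy as np

x = np.ones((1, 1, 2, 2))
y = np.pad(x, ((0,0), (0,0), (1,0), (0,1)))  # padding=(1,0,0,1): one row above, one column right
print(y.shape)  # (1, 1, 3, 3)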

@@ -116,11 +126,13 @@ register('logsoftmax', LogSoftmax)
 
 class Conv2D(Function):
   @staticmethod
-  def forward(ctx, x, w, stride=1):
+  def forward(ctx, x, w, stride=1, groups=1):
     if type(ctx.stride) == int:
       ctx.stride = (ctx.stride, ctx.stride)
 
     cout,cin,H,W = w.shape
+    if groups > 1:
+      w = np.repeat(w, groups, axis=1)
     tw = w.reshape(cout, -1).T
     ys,xs = ctx.stride
     bs,oy,ox = x.shape[0], (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
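
Note: for a depthwise weight of shape (cout, 1, H, W) called with groups == cout, np.repeat inflates it to (cout, groups, H, W) so the existing im2col path (w.reshape(cout, -1).T) lines up against all input channels. That makes the network run, which is all the commit message claims, but it is not a true grouped convolution: every output channel still sums over every input channel instead of only its own group. A shape-only sketch:

import numpy as np

cout, H, W = 8, 3, 3
groups = cout                           # the depthwise case
w = np.zeros((cout, 1, H, W))           # one input channel per filter
w_full = np.repeat(w, groups, axis=1)   # (8, 8, 3, 3), shaped like a dense conv
tw = w_full.reshape(cout, -1).T
print(w_full.shape, tw.shape)           # (8, 8, 3, 3) (72, 8)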