diff --git a/examples/efficientnet.py b/examples/efficientnet.py
index 94031c40..0284d2d0 100644
--- a/examples/efficientnet.py
+++ b/examples/efficientnet.py
@@ -5,6 +5,9 @@
 # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
 from tinygrad.tensor import Tensor
 
+def swish(x):
+  return x.mul(x.sigmoid())
+
 class BatchNorm2D:
   def __init__(self, sz):
     self.weight = Tensor.zeros(sz)
@@ -13,7 +16,9 @@ class BatchNorm2D:
 
   def __call__(self, x):
     # this work at inference?
-    return x * self.weight + self.bias
+    x = x.mul(self.weight.reshape(shape=[1, -1, 1, 1]))
+    x = x.add(self.bias.reshape(shape=[1, -1, 1, 1]))
+    return x
 
 class MBConvBlock:
   def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio):
@@ -21,6 +26,8 @@ class MBConvBlock:
     if expand_ratio != 1:
       self._expand_conv = Tensor.zeros(oup, input_filters, 1, 1)
       self._bn0 = BatchNorm2D(oup)
+    else:
+      self._expand_conv = None
 
     self.pad = (kernel_size-1)//2
     self.strides = strides
@@ -38,18 +45,20 @@ class MBConvBlock:
     self._bn2 = BatchNorm2D(output_filters)
 
   def __call__(self, x):
-    x = self._bn0(x.conv2d(self._expand_conv)).swish()
-    x = x.pad(self.pad, self.pad, self.pad, self.pad)
-    x = self._bn1(x.conv2d(self._depthwise_conv, stride=self.stride)).swish()  # TODO: repeat on axis 1
+    if self._expand_conv:
+      x = swish(self._bn0(x.conv2d(self._expand_conv)))
+    x = x.pad2d(padding=(self.pad, self.pad, self.pad, self.pad))
+    x = x.conv2d(self._depthwise_conv, stride=self.strides, groups=self._depthwise_conv.shape[0])
+    x = swish(self._bn1(x))
 
     # has_se
-    x_squeezed = x.avg_pool2d()
-    x_squeezed = (x_squeezed.conv2d(self._se_reduce) + self._se_reduce_bias).swish()
-    x_squeezed = x_squeezed.conv2d(self._se_expand) + self._se_expand_bias
-    x = x * x_squeezed.sigmoid()
+    x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
+    x_squeezed = swish(x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1])))
+    x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1]))
+    x = x.mul(x_squeezed.sigmoid())
 
     x = self._bn2(x.conv2d(self._project_conv))
-    return x.swish()
+    return swish(x)
 
 class EfficientNet:
   def __init__(self):
@@ -67,21 +76,27 @@ class EfficientNet:
     self._blocks = []
     # num_repeats, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio
     for b in blocks_args:
+      args = b[1:]
       for n in range(b[0]):
-        self._blocks.append(MBConvBlock(*b[1:]))
+        self._blocks.append(MBConvBlock(*args))
+        args[3] = args[4]
+        args[1] = (1,1)
 
     self._conv_head = Tensor.zeros(1280, 320, 1, 1)
     self._bn1 = BatchNorm2D(1280)
     self._fc = Tensor.zeros(1280, 1000)
 
-  def forward(x):
-    x = self._bn0(x.pad(0,1,0,1).conv2d(self._conv_stem, stride=2))
+  def forward(self, x):
+    x = x.pad2d(padding=(0,1,0,1))
+    x = self._bn0(x.conv2d(self._conv_stem, stride=2))
     for b in self._blocks:
       x = b(x)
     x = self._bn1(x.conv2d(self._conv_head))
-    x = x.avg_pool2d() # wrong?
-    x = x.dropout(0.2)
-    return x.dot(self_fc).swish()
+    x = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, 1280))
+    #x = x.dropout(0.2)
+    return swish(x.dot(self._fc))
 
 if __name__ == "__main__":
   model = EfficientNet()
+  out = model.forward(Tensor.zeros(1, 3, 224, 224))
+  print(out)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 2b7c0c57..57fa1b53 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -55,6 +55,16 @@ register('sum', Sum)
 
 # ************* nn ops *************
 
+class Pad2D(Function):
+  @staticmethod
+  def forward(ctx, x, padding=None):
+    return np.pad(x, ((0,0), (0,0), (padding[0], padding[1]), (padding[2], padding[3])))
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    raise Exception("write this")
+register('pad2d', Pad2D)
+
 class ReLU(Function):
   @staticmethod
   def forward(ctx, input):
@@ -116,11 +126,13 @@ register('logsoftmax', LogSoftmax)
 
 class Conv2D(Function):
   @staticmethod
-  def forward(ctx, x, w, stride=1):
+  def forward(ctx, x, w, stride=1, groups=1):
     if type(ctx.stride) == int:
       ctx.stride = (ctx.stride, ctx.stride)
     cout,cin,H,W = w.shape
+    if groups > 1:
+      w = np.repeat(w, groups, axis=1)
     tw = w.reshape(cout, -1).T
     ys,xs = ctx.stride
     bs,oy,ox = x.shape[0], (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
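
A few notes on the patch. First, the new Pad2D op ships with backward still raising "write this". Since forward is plain zero-padding, the gradient is just the incoming gradient with the padded border cropped off. A minimal standalone sketch in NumPy (pad2d_backward is my name, not part of the patch), following the forward's layout where padding[0:2] pads axis 2 (height) and padding[2:4] pads axis 3 (width):

import numpy as np

def pad2d_backward(grad_output, padding):
  # forward pads axis 2 with (padding[0], padding[1]) and axis 3 with
  # (padding[2], padding[3]); zero-padding contributes nothing to the
  # gradient, so backward just crops those rows/columns back off
  h, w = grad_output.shape[2], grad_output.shape[3]
  return grad_output[:, :, padding[0]:h-padding[1], padding[2]:w-padding[3]]

# shape check: pad (5,5) -> (7,8) with ((1,1),(1,2)), then crop back
x = np.random.randn(1, 3, 5, 5)
g = np.ones((1, 3, 7, 8))
assert pad2d_backward(g, (1, 1, 1, 2)).shape == x.shape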
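
Second, the groups support in Conv2D looks like a shape hack rather than a real grouped convolution: np.repeat(w, groups, axis=1) turns a (cout, 1, H, W) depthwise weight into (cout, cout, H, W), so each output channel sums the same kernel's response over all input channels, whereas a true depthwise conv convolves channel k with kernel k only. For checking against a correct result, here is a slow reference grouped conv in plain NumPy ('valid' padding, same output-size formula as forward); grouped_conv2d is my name, not tinygrad API:

import numpy as np

def grouped_conv2d(x, w, groups=1, stride=(1,1)):
  # x: (bs, cin, iy, ix), w: (cout, cin//groups, H, W)
  bs, cin, iy, ix = x.shape
  cout, cg, H, W = w.shape
  assert cin == cg*groups and cout % groups == 0
  ys, xs = stride
  oy, ox = (iy-(H-ys))//ys, (ix-(W-xs))//xs
  cpg = cout//groups  # output channels per group
  out = np.zeros((bs, cout, oy, ox), dtype=x.dtype)
  for g in range(groups):
    xg = x[:, g*cg:(g+1)*cg]                      # this group's input channels
    twg = w[g*cpg:(g+1)*cpg].reshape(cpg, -1).T   # this group's filters, as a matmul
    for Y in range(oy):
      for X in range(ox):
        patch = xg[:, :, Y*ys:Y*ys+H, X*xs:X*xs+W].reshape(bs, -1)
        out[:, g*cpg:(g+1)*cpg, Y, X] = patch @ twg
  return out

With groups=cin and cout=cin this reduces to depthwise conv, which is what MBConvBlock asks for via groups=self._depthwise_conv.shape[0].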
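
Finally, on the "# this work at inference?" question in BatchNorm2D: the reshape/mul/add broadcast is right, but an inference-time batchnorm for a pretrained EfficientNet also needs the stored running_mean/running_var (this patch never loads them; every parameter is still zeros). The standard trick is to fold the statistics into exactly the per-channel scale and shift this class already applies. A sketch, with eps and the running stats as assumed inputs:

import numpy as np

def fold_batchnorm(weight, bias, running_mean, running_var, eps=1e-5):
  # y = (x - mean)/sqrt(var + eps)*weight + bias  ==  x*scale + shift
  scale = weight / np.sqrt(running_var + eps)
  shift = bias - running_mean*scale
  return scale, shift

The folded scale/shift drop straight into BatchNorm2D's existing x.mul(scale.reshape(...)).add(shift.reshape(...)) form, so __call__ would not need to change once weight loading lands.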