tinygrad (mirror of https://github.com/commaai/tinygrad.git)

commit 521098cc2f (parent 609d11e699)

    se optional, track time better

This commit makes the squeeze-and-excitation (SE) blocks in EfficientNet optional and tracks per-phase training time more precisely.
The first two hunks are in the EfficientNet training example. Training the full model now disables the SE blocks:

@@ -41,7 +41,8 @@ if __name__ == "__main__":
   if TINY:
     model = TinyConvNet(classes)
   else:
-    model = EfficientNet(int(os.getenv("NUM", "0")), classes)
+    model = EfficientNet(int(os.getenv("NUM", "0")), classes, has_se=False)
+
   parameters = get_parameters(model)
   print("parameters", len(parameters))
   optimizer = optim.Adam(parameters, lr=0.001)
The training loop now also times the post-step work (pulling the loss and predictions back to the host and computing accuracy) and prints all four phase times plus their sum:

@@ -74,12 +75,19 @@ if __name__ == "__main__":
     optimizer.step()
     opt_time = (time.time()-st)*1000.0
 
+    st = time.time()
+    loss = loss.cpu().data
     cat = np.argmax(out.cpu().data, axis=1)
     accuracy = (cat == Y).mean()
+    finish_time = (time.time()-st)*1000.0
+
 
     # printing
-    t.set_description("loss %.2f accuracy %.2f -- %.2f %.2f %.2f -- %d" %
-      (loss.cpu().data, accuracy, fp_time, bp_time, opt_time, Tensor.allocated))
+    t.set_description("loss %.2f accuracy %.2f -- %.2f + %.2f + %.2f + %.2f = %.2f -- %d" %
+      (loss, accuracy,
+      fp_time, bp_time, opt_time, finish_time,
+      fp_time + bp_time + opt_time + finish_time,
+      Tensor.allocated))
 
     del out, y, loss
 
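As a rough sketch of the timing pattern the loop follows (timed_ms and the lambda workloads are hypothetical stand-ins for the real forward/backward/step/fetch phases, not tinygrad code):

import time

def timed_ms(fn):
  # run fn, return its result and the elapsed wall time in milliseconds
  st = time.time()
  out = fn()
  return out, (time.time() - st) * 1000.0

_, fp_time = timed_ms(lambda: sum(range(10**5)))      # stand-in for the forward pass
_, bp_time = timed_ms(lambda: sum(range(10**5)))      # stand-in for the backward pass
_, opt_time = timed_ms(lambda: sum(range(10**5)))     # stand-in for optimizer.step()
_, finish_time = timed_ms(lambda: sum(range(10**4)))  # stand-in for device-to-host copies + metrics
print("%.2f + %.2f + %.2f + %.2f = %.2f ms" %
      (fp_time, bp_time, opt_time, finish_time,
       fp_time + bp_time + opt_time + finish_time))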
The remaining hunks are in the EfficientNet model definition. MBConvBlock takes a new has_se flag:

@@ -70,7 +70,7 @@ def fake_torch_load(b0):
   return ret
 
 class MBConvBlock:
-  def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio):
+  def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se):
     oup = expand_ratio * input_filters
     if expand_ratio != 1:
       self._expand_conv = Tensor.zeros(oup, input_filters, 1, 1)
The SE weights are only allocated when the flag is set, so they never show up in get_parameters for a model built with has_se=False:

@@ -87,11 +87,13 @@ class MBConvBlock:
     self._depthwise_conv = Tensor.zeros(oup, 1, kernel_size, kernel_size)
     self._bn1 = BatchNorm2D(oup)
 
-    num_squeezed_channels = max(1, int(input_filters * se_ratio))
-    self._se_reduce = Tensor.zeros(num_squeezed_channels, oup, 1, 1)
-    self._se_reduce_bias = Tensor.zeros(num_squeezed_channels)
-    self._se_expand = Tensor.zeros(oup, num_squeezed_channels, 1, 1)
-    self._se_expand_bias = Tensor.zeros(oup)
+    self.has_se = has_se
+    if self.has_se:
+      num_squeezed_channels = max(1, int(input_filters * se_ratio))
+      self._se_reduce = Tensor.zeros(num_squeezed_channels, oup, 1, 1)
+      self._se_reduce_bias = Tensor.zeros(num_squeezed_channels)
+      self._se_expand = Tensor.zeros(oup, num_squeezed_channels, 1, 1)
+      self._se_expand_bias = Tensor.zeros(oup)
 
     self._project_conv = Tensor.zeros(output_filters, oup, 1, 1)
     self._bn2 = BatchNorm2D(output_filters)
The forward pass skips the SE branch under the same flag:

@@ -105,10 +107,11 @@ class MBConvBlock:
     x = self._bn1(x).swish()
 
     # has_se
-    x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
-    x_squeezed = x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1])).swish()
-    x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1]))
-    x = x.mul(x_squeezed.sigmoid())
+    if self.has_se:
+      x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
+      x_squeezed = x_squeezed.conv2d(self._se_reduce).add(self._se_reduce_bias.reshape(shape=[1, -1, 1, 1])).swish()
+      x_squeezed = x_squeezed.conv2d(self._se_expand).add(self._se_expand_bias.reshape(shape=[1, -1, 1, 1]))
+      x = x.mul(x_squeezed.sigmoid())
 
     x = self._bn2(x.conv2d(self._project_conv))
     if x.shape == inputs.shape:
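For reference, the SE branch computes a per-channel gate: global-average-pool the feature map, squeeze it through a 1x1 reduce convolution with swish, expand it back with a second 1x1 convolution, and multiply the sigmoid of the result into x. A minimal NumPy sketch of the same computation (an illustration with made-up shapes, not the tinygrad implementation; 1x1 convolutions on a 1x1 map reduce to matrix multiplies):

import numpy as np

def sigmoid(v): return 1.0 / (1.0 + np.exp(-v))
def swish(v): return v * sigmoid(v)

def se_gate(x, w_reduce, b_reduce, w_expand, b_expand):
  # x: (N, C, H, W)
  squeezed = x.mean(axis=(2, 3))                      # squeeze: global average pool -> (N, C)
  squeezed = swish(squeezed @ w_reduce.T + b_reduce)  # reduce to num_squeezed_channels
  squeezed = squeezed @ w_expand.T + b_expand         # expand back to C channels
  return x * sigmoid(squeezed)[:, :, None, None]      # excite: per-channel gate on x

x = np.random.randn(1, 8, 4, 4)
out = se_gate(x, np.random.randn(2, 8), np.zeros(2), np.random.randn(8, 2), np.zeros(8))
assert out.shape == x.shape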
EfficientNet defaults to has_se=True, preserving the original behavior:

@@ -116,7 +119,7 @@ class MBConvBlock:
     return x
 
 class EfficientNet:
-  def __init__(self, number=0, classes=1000):
+  def __init__(self, number=0, classes=1000, has_se=True):
     self.number = number
     global_params = [
       # width, depth
and threads the flag through to every block it builds:

@@ -163,7 +166,7 @@ class EfficientNet:
       args[3] = round_filters(args[3])
       args[4] = round_filters(args[4])
       for n in range(round_repeats(b[0])):
-        self._blocks.append(MBConvBlock(*args))
+        self._blocks.append(MBConvBlock(*args, has_se=has_se))
         args[3] = args[4]
         args[1] = (1,1)
 
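Putting it together, disabling SE is a one-argument change at construction time (a usage sketch; the import path is an assumption based on the repo layout):

# assumed import path for the model definition shown above
from extra.efficientnet import EfficientNet

model = EfficientNet(0, classes=10, has_se=False)  # EfficientNet-B0 without SE blocks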