mirror of https://github.com/commaai/tinygrad.git
Fix batchnorm shapes for resnet.load_pretrained (#5167)
* Fix batchnorm shapes * make it general reshape
This commit is contained in:
parent
396ce6cfc9
commit
f1c7944c44
|
@ -146,15 +146,12 @@ class ResNet:
|
|||
print("skipping fully connected layer")
|
||||
continue # Skip FC if transfer learning
|
||||
|
||||
if dat.shape == ():
|
||||
assert obj.shape == (1,), obj.shape
|
||||
dat = dat.reshape(1)
|
||||
assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
|
||||
obj.assign(dat)
|
||||
if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
|
||||
obj.assign(dat.reshape(obj.shape))
|
||||
|
||||
# Convenience factories for the standard ResNet / ResNeXt variants.
# Each returns a freshly constructed ResNet with the canonical depth;
# pass num_classes to resize the final fully-connected classifier head.
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
# ResNeXt-50 (32x4d): ResNet-50 topology with grouped bottleneck convolutions
# (32 groups, 4 channels per group). NOTE(review): the original text assigned
# this exact same lambda to ResNeXt50_32X4D twice; the redundant duplicate
# definition has been removed (behavior is unchanged).
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
|
Loading…
Reference in New Issue