Fix batchnorm shapes for resnet.load_pretrained (#5167)

* Fix batchnorm shapes

* make it general reshape
This commit is contained in:
reddyn12 2024-06-26 18:44:10 -04:00 committed by GitHub
parent 396ce6cfc9
commit f1c7944c44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 3 additions and 6 deletions

View File

@ -146,15 +146,12 @@ class ResNet:
print("skipping fully connected layer")
continue # Skip FC if transfer learning
-      if dat.shape == ():
-        assert obj.shape == (1,), obj.shape
-        dat = dat.reshape(1)
-      assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
-      obj.assign(dat)
+      if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
+      obj.assign(dat.reshape(obj.shape))
# Convenience constructors for the standard ResNet depths.
# Each returns a freshly configured ResNet; num_classes defaults to the
# 1000-way ImageNet head.
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
# ResNeXt-50 (32x4d): ResNet-50 topology with grouped convolutions
# (32 groups, 4 channels per group). The original text assigned this
# name twice with identical bodies; the duplicate is removed.
ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)