mirror of https://github.com/commaai/tinygrad.git
Fix batchnorm shapes for resnet.load_pretrained (#5167)
* Fix batchnorm shapes
* make it general reshape
This commit is contained in:
parent 396ce6cfc9
commit f1c7944c44
@@ -146,11 +146,8 @@ class ResNet:
         print("skipping fully connected layer")
         continue # Skip FC if transfer learning
 
-      if dat.shape == ():
-        assert obj.shape == (1,), obj.shape
-        dat = dat.reshape(1)
-      assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
-      obj.assign(dat)
+      if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
+      obj.assign(dat.reshape(obj.shape))
 
 ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
 ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
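For context, the change above drops the exact-shape assertion for batchnorm and downsample parameters and instead reshapes the pretrained tensor to the destination parameter's shape on assignment. A minimal standalone sketch of that idea in plain NumPy; the load_state function, the params/ckpt dicts, and the (1, C, 1, 1) layout are illustrative assumptions, not tinygrad's actual loader:

import numpy as np

def load_state(params: dict, state_dict: dict) -> None:
  # Copy pretrained arrays into model parameters. Conv/linear weights must
  # match exactly; bn/downsample entries may differ in layout (same element
  # count), so they are reshaped to the destination shape before assignment.
  for k, dat in state_dict.items():
    obj = params[k]
    if 'bn' not in k and 'downsample' not in k:
      assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
    assert obj.size == dat.size, (k, obj.shape, dat.shape)
    params[k] = dat.reshape(obj.shape)

# Example: a (C,) running_mean from a checkpoint loaded into a (1, C, 1, 1) buffer.
params = {"layer1.0.bn1.running_mean": np.zeros((1, 64, 1, 1), dtype=np.float32)}
ckpt = {"layer1.0.bn1.running_mean": np.arange(64, dtype=np.float32)}
load_state(params, ckpt)
print(params["layer1.0.bn1.running_mean"].shape)  # (1, 64, 1, 1)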