minor resnet cleanups (#6382)

* minor resnet cleanups

* that should have been long

* jit

* meh
George Hotz 2024-09-06 12:50:21 +08:00 committed by GitHub
parent 86d34daac9
commit 9d72119a0c
3 changed files with 14 additions and 7 deletions

extra/models/resnet.py

@@ -126,8 +126,6 @@ class ResNet:
     return self.forward(x)

   def load_from_pretrained(self):
-    # TODO replace with fake torch load
-
     model_urls = {
       (18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
       (34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
@@ -138,16 +136,15 @@ class ResNet:
     }

     self.url = model_urls[(self.num, self.groups, self.base_width)]
-    for k, v in torch_load(fetch(self.url)).items():
+    for k, dat in torch_load(fetch(self.url)).items():
       obj: Tensor = get_child(self, k)
-      dat = v.numpy()

       if 'fc.' in k and obj.shape != dat.shape:
         print("skipping fully connected layer")
         continue # Skip FC if transfer learning

       if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
-      obj.assign(dat.reshape(obj.shape))
+      obj.assign(dat.to(None).reshape(obj.shape))

 ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
 ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
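
A note on the loop change above: tinygrad's torch_load hands back tensors that live on a disk-backed device, and the old path forced each one through numpy (v.numpy()) before assignment. dat.to(None) instead moves the tensor to the default device (to(None) canonicalizes to Device.DEFAULT), staying inside tinygrad's lazy graph. A minimal sketch of that semantics, using a stand-in tensor rather than a real checkpoint:

from tinygrad import Tensor, Device

dat = Tensor([1.0, 2.0, 3.0])     # stand-in for a weight from torch_load
moved = dat.to(None)              # same as dat.to(Device.DEFAULT)
assert moved.device == Device.DEFAULT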
@@ -155,3 +152,13 @@ ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
 ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
 ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
 ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
+
+if __name__ == "__main__":
+  model = ResNet18()
+  model.load_from_pretrained()
+  from tinygrad import Context, GlobalCounters, TinyJit
+  jmodel = TinyJit(model)
+  jmodel(Tensor.rand(1, 3, 224, 224)).realize()
+  GlobalCounters.reset()
+  with Context(GRAPH=1): jmodel(Tensor.rand(1, 3, 224, 224)).realize()
+  for i in range(10): jmodel(Tensor.rand(1, 3, 224, 224))
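
The new __main__ block doubles as a JIT smoke test: TinyJit captures kernels from an early call and replays them on later calls with same-shaped inputs, so the first jmodel(...) acts as warmup before the counters are reset and the GRAPH=1 run is logged. A minimal sketch of the same warmup-then-replay pattern (the step function here is hypothetical, not from this commit):

from tinygrad import Tensor, TinyJit, GlobalCounters

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2).sum().realize()   # shapes/dtypes must stay fixed across calls

step(Tensor.rand(4, 4))            # warmup: early calls run eagerly while the JIT captures
GlobalCounters.reset()
for _ in range(10):                # later calls replay the captured kernels
  out = step(Tensor.rand(4, 4))
print(out.item())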

tinygrad/engine/graph.py

@@ -63,7 +63,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
     label = '"' + \
       (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
       (f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {MetaOps.CONST, UnaryOps.CAST} else "") + \
-      (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + f'\n{lb.metadata}"'
+      (f"\n{lb.device[:15]}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + f'\n{lb.metadata}"'
     G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
     if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
   else:
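
Truncating to lb.device[:15] keeps the GRAPH=1 node labels readable: most device names are short ("CLANG", "CUDA"), but a disk-backed buffer embeds a full file path in its device string. A hedged illustration with a made-up device name:

device = "DISK:/tmp/weights/resnet18-5c106cde.pth"   # hypothetical disk device string
print(device[:15])                                    # -> DISK:/tmp/weigh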

tinygrad/nn/__init__.py

@@ -35,7 +35,7 @@ class BatchNorm:
     self.weight: Optional[Tensor] = Tensor.ones(sz) if affine else None
     self.bias: Optional[Tensor] = Tensor.zeros(sz) if affine else None

-    self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
+    self.num_batches_tracked = Tensor.zeros(1, dtype='long', requires_grad=False)
     if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)

   def calc_stats(self, x:Tensor) -> Tuple[Tensor, Tensor]:
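
This is the "that should have been long" fix from the commit message: PyTorch stores BatchNorm's num_batches_tracked as an int64 ("long") buffer, so giving the tinygrad counter the same dtype keeps it consistent with torch checkpoints. A small sketch of the result, assuming the string dtype alias works as used in the diff:

from tinygrad import Tensor, dtypes

counter = Tensor.zeros(1, dtype='long', requires_grad=False)
assert counter.dtype == dtypes.long   # int64, matching torch's buffer
counter = (counter + 1).realize()     # bumped once per training batch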