mirror of https://github.com/commaai/tinygrad.git

minor resnet cleanups (#6382)

* minor resnet cleanups
* that should have been long
* jit
* meh

parent 86d34daac9
commit 9d72119a0c
@@ -126,8 +126,6 @@ class ResNet:
     return self.forward(x)
 
   def load_from_pretrained(self):
-    # TODO replace with fake torch load
-
     model_urls = {
       (18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
       (34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
@@ -138,16 +136,15 @@ class ResNet:
     }
 
     self.url = model_urls[(self.num, self.groups, self.base_width)]
-    for k, v in torch_load(fetch(self.url)).items():
+    for k, dat in torch_load(fetch(self.url)).items():
       obj: Tensor = get_child(self, k)
-      dat = v.numpy()
 
       if 'fc.' in k and obj.shape != dat.shape:
         print("skipping fully connected layer")
         continue  # Skip FC if transfer learning
 
       if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
-      obj.assign(dat.reshape(obj.shape))
+      obj.assign(dat.to(None).reshape(obj.shape))
 
 ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
 ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
@@ -155,3 +152,13 @@ ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
 ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
 ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
 ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
+
+if __name__ == "__main__":
+  model = ResNet18()
+  model.load_from_pretrained()
+  from tinygrad import Context, GlobalCounters, TinyJit
+  jmodel = TinyJit(model)
+  jmodel(Tensor.rand(1, 3, 224, 224)).realize()
+  GlobalCounters.reset()
+  with Context(GRAPH=1): jmodel(Tensor.rand(1, 3, 224, 224)).realize()
+  for i in range(10): jmodel(Tensor.rand(1, 3, 224, 224))
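The ResNet cleanup above does two things: pretrained weight loading now binds the torch_load value directly as `dat` and moves it to the default device with `.to(None)` instead of taking a detour through numpy, and a small `__main__` block runs the pretrained ResNet18 through TinyJit, once under GRAPH=1 and then ten more times. Below is a minimal sketch of that same JIT warm-up pattern on a toy model; TinyMLP and its layer sizes are made up for illustration and are not part of the commit.

# A sketch, not from the commit: warm up a TinyJit-wrapped callable, then replay it.
from tinygrad import Tensor, TinyJit, GlobalCounters, nn

class TinyMLP:
  def __init__(self):
    self.l1, self.l2 = nn.Linear(16, 32), nn.Linear(32, 4)
  def __call__(self, x: Tensor) -> Tensor:
    return self.l2(self.l1(x).relu())

model = TinyMLP()
jmodel = TinyJit(model)                  # wrap the callable model, as the new __main__ does
jmodel(Tensor.rand(1, 16)).realize()     # the first couple of calls trace and capture kernels
jmodel(Tensor.rand(1, 16)).realize()
GlobalCounters.reset()
jmodel(Tensor.rand(1, 16)).realize()     # later calls with the same shapes replay the cached kernels
print("ops in a jitted call:", GlobalCounters.global_ops)

After the first couple of calls capture the kernels, subsequent calls with identical input shapes replay the cached graph, which is why the diff ends by looping `jmodel(Tensor.rand(1, 3, 224, 224))` ten times.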
@@ -63,7 +63,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
     label = '"' + \
       (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
       (f"\n{lb.dtype.name}" if lb.dtype.name != "float" else "")+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {MetaOps.CONST, UnaryOps.CAST} else "") + \
-      (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + f'\n{lb.metadata}"'
+      (f"\n{lb.device[:15]}" if lb.device != Device.DEFAULT else "") + ''.join(label_append) + f'\n{lb.metadata}"'
     G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
     if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
   else:
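The graph change is cosmetic: with GRAPH=1 set, each scheduled LazyBuffer gets a graphviz node, and the device string in the node label is now truncated to 15 characters so long device names don't blow up the layout. A minimal sketch of exercising that logging path follows; it assumes the optional graph dependencies (networkx/graphviz) are installed, and the computation itself is made up.

# A sketch, not from the commit: run a small computation with graph logging enabled.
from tinygrad import Tensor, Context

with Context(GRAPH=1):                   # same switch the new ResNet __main__ flips
  x = Tensor.rand(4, 4)
  y = (x @ x.T).relu().sum()
  y.realize()                            # scheduling here logs nodes with the truncated device label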
@@ -35,7 +35,7 @@ class BatchNorm:
     self.weight: Optional[Tensor] = Tensor.ones(sz) if affine else None
     self.bias: Optional[Tensor] = Tensor.zeros(sz) if affine else None
 
-    self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
+    self.num_batches_tracked = Tensor.zeros(1, dtype='long', requires_grad=False)
     if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
 
   def calc_stats(self, x:Tensor) -> Tuple[Tensor, Tensor]:
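The BatchNorm change is the "that should have been long" part of the commit message: `num_batches_tracked` is now created as an int64 ("long") tensor instead of the default float, which lines up with the dtype of the corresponding buffer in PyTorch checkpoints. A quick sanity check, as a sketch against a tinygrad build that includes this change:

# A sketch, not from the commit: confirm the buffer dtype is long after this change.
from tinygrad import dtypes
from tinygrad.nn import BatchNorm2d

bn = BatchNorm2d(8)                          # any feature count works; 8 is arbitrary
print(bn.num_batches_tracked.dtype)          # dtypes.long (int64)
assert bn.num_batches_tracked.dtype == dtypes.long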