getting 77% on imagenet eval

This commit is contained in:
George Hotz 2023-05-13 07:46:27 -07:00
parent 810f03dafa
commit b705510d5c
4 changed files with 66 additions and 17 deletions

View File

@ -1,28 +1,53 @@
import os
# for imagenet, download prepare.sh and run it
import os, glob, random
import json
import numpy as np
from PIL import Image
import functools
import torchvision.transforms as transforms
# Root directory of the local ImageNet copy; every path below is built from it.
# NOTE(review): hard-coded absolute path — confirm it is meant to be edited per machine.
BASEDIR = "/Users/kafka/fun/imagenet"
# Newline-separated entries of BASEDIR/train_files and BASEDIR/val_files, read at import time.
# NOTE(review): get_train_files() below reads the same train list lazily — confirm these eager copies are still needed.
train_files = open(os.path.join(BASEDIR, "train_files")).read().strip().split("\n")
val_files = open(os.path.join(BASEDIR, "val_files")).read().strip().split("\n")
# ci: parsed contents of imagenet_class_index.json.
ci = json.load(open(os.path.join(BASEDIR, "imagenet_class_index.json")))
# cir: first element of each ci value -> its key as an int
# (presumably class directory name -> class index; verify against the json).
cir = {v[0]: int(k) for k,v in ci.items()}
# Random-resized-crop transform constructed with size 224.
rrc = transforms.RandomResizedCrop(224)
@functools.lru_cache(None)
def get_train_files():
  """Return the paths of all training images.

  Reads the newline-separated list stored in BASEDIR/train_files and joins
  every entry onto BASEDIR/train. The result is memoized by lru_cache, so the
  list file is read once per process and all callers share one list object.
  """
  # context manager so the list file's handle is closed as soon as it is read
  with open(os.path.join(BASEDIR, "train_files")) as f:
    train_files = f.read().strip().split("\n")
  return [os.path.join(BASEDIR, "train", x) for x in train_files]
@functools.lru_cache(None)
def get_val_files():
  """Return the paths of all validation images.

  The list is discovered by globbing BASEDIR/val/*/* (no list file is read)
  and memoized by lru_cache, so the directory scan happens once per process.
  NOTE(review): glob does not promise any particular ordering of the result.
  """
  pattern = os.path.join(BASEDIR, "val", "*", "*")
  return glob.glob(pattern)
#rrc = transforms.RandomResizedCrop(224)
import torchvision.transforms.functional as F
def image_load(fn):
  """Load one image file and apply the fixed evaluation preprocessing.

  Steps: convert to RGB, bilinear resize with size argument 256, center crop
  to 224, F.to_tensor, then per-channel normalization with the mean/std
  constants below.

  Args:
    fn: path (or file object) accepted by Image.open.

  Returns:
    float32 numpy array of the normalized image (presumably channel-first,
    as produced by torchvision's to_tensor — confirm).
  """
  # convert() returns a new, fully loaded image, so the file can be closed
  # right away instead of waiting for garbage collection
  with Image.open(fn) as raw:
    img = raw.convert('RGB')
  # the former random-resized-crop result was overwritten before use, so that
  # dead computation is gone; preprocessing is deterministic
  img = F.resize(img, 256, Image.BILINEAR)
  img = F.center_crop(img, 224)
  img = F.to_tensor(img)
  img = F.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
  return np.array(img, dtype='float32')
def iterate(bs, val=False, shuffle=True):
  """Yield (X, Y) batches covering the whole split exactly once.

  Args:
    bs: batch size; the final batch is smaller when the number of files is
      not a multiple of bs.
    val: iterate the validation files when true, the training files otherwise.
    shuffle: visit the files in a random order (random.shuffle) when true.

  Yields:
    Tuples (np.array of loaded images, np.array of integer labels). A label
    is looked up in cir from the second-to-last "/"-separated path component.
  """
  files = get_val_files() if val else get_train_files()
  order = list(range(len(files)))
  if shuffle:
    random.shuffle(order)
  for start in range(0, len(files), bs):
    batch = [files[idx] for idx in order[start:start+bs]]
    X = [image_load(fn) for fn in batch]
    Y = [cir[fn.split("/")[-2]] for fn in batch]
    yield (np.array(X), np.array(Y))
def fetch_batch(bs, val=False):
  """Return one randomly sampled batch (X, Y).

  Indices are drawn independently with np.random.randint, so the same file
  can appear more than once in a batch.

  Args:
    bs: number of samples to draw.
    val: sample from the validation files when true, training files otherwise.

  Returns:
    Tuple (np.array of loaded images, np.array of integer labels).
  """
  files = get_val_files() if val else get_train_files()
  samp = np.random.randint(0, len(files), size=(bs))
  files = [files[i] for i in samp]
  # both file getters already return complete paths, so load them directly
  # (and only once)
  X = [image_load(x) for x in files]
  # the class is the parent directory name, same lookup as iterate();
  # component [0] of an absolute path is the empty string
  Y = [cir[x.split("/")[-2]] for x in files]
  # no transpose: image_load's output already went through F.to_tensor
  return np.array(X), np.array(Y)
if __name__ == "__main__":
  # smoke test: draw a single random batch of 64 from the training files
  X, Y = fetch_batch(bs=64)

View File

@ -0,0 +1,27 @@
import numpy as np
from tinygrad.tensor import Tensor
if __name__ == "__main__":
  # inference only
  Tensor.training = False
  Tensor.no_grad = True

  # Resnet50-v1.5, with its pretrained weights loaded
  from models.resnet import ResNet50
  mdl = ResNet50()
  mdl.load_from_pretrained()

  # evaluation on the mlperf classes of the validation set from imagenet
  from datasets.imagenet import iterate

  # n: correct predictions so far, d: samples seen so far
  n = d = 0
  for images, labels in iterate(32, val=True, shuffle=True):
    batch = Tensor(images.astype(np.float32))
    # predicted class = index of the largest output along axis 1
    t = mdl(batch).numpy().argmax(axis=1)
    print(t)
    print(labels)
    n += (t == labels).sum()
    d += len(t)
    # running accuracy after every batch
    print(f"****** {n}/{d} {n*100.0/d:.2f}%")

View File

@ -14,19 +14,17 @@ if __name__ == "__main__":
Tensor.no_grad = True
# Resnet50-v1.5
"""
from models.resnet import ResNet50
mdl = ResNet50()
img = Tensor.randn(1, 3, 224, 224)
test_model(mdl, img)
"""
# Retinanet
# 3D UNET
from models.unet3d import UNet3D
mdl = UNet3D()
mdl.load_from_pretrained()
#mdl.load_from_pretrained()
img = Tensor.randn(1, 1, 5, 224, 224)
test_model(mdl, img)

View File

@ -1,7 +1,6 @@
from tinygrad.tensor import Tensor
import tinygrad.nn as nn
from extra.utils import get_child
import numpy as np
class BasicBlock:
expansion = 1
@ -80,8 +79,7 @@ class ResNet:
self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2)
self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2)
self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2)
# TODO: replace with nn.Linear
self.fc = {"weight": Tensor.scaled_uniform(512 * self.block.expansion, num_classes), "bias": Tensor.zeros(num_classes)}
self.fc = nn.Linear(512 * self.block.expansion, num_classes)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks-1)
@ -93,12 +91,13 @@ class ResNet:
def forward(self, x):
  """Run the network on a batch and return log-softmax outputs.

  Pipeline: conv1 -> bn1 -> relu, pad by 1 on each side and 3x3 max-pool
  with stride 2, the four residual stages, a mean over axes 2 and 3, then
  the fully connected layer followed by log_softmax.

  Args:
    x: input tensor (assumes layout (batch, channels, height, width) —
      TODO confirm against callers).

  Returns:
    Tensor of log-softmax scores from the fully connected layer.
  """
  out = self.bn1(self.conv1(x)).relu()
  out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2)
  out = out.sequential(self.layer1)
  out = out.sequential(self.layer2)
  out = out.sequential(self.layer3)
  out = out.sequential(self.layer4)
  # pool once over axes 2 and 3; the older chained mean(3).mean(2) would
  # reduce the tensor a second time
  out = out.mean([2,3])
  # self.fc is an nn.Linear and is called directly; it can no longer be
  # unpacked as keyword arguments for Tensor.linear
  out = self.fc(out).log_softmax()
  return out
def __call__(self, x):
@ -121,7 +120,7 @@ class ResNet:
state_dict = load_state_dict_from_url(self.url, progress=True)
for k, v in state_dict.items():
obj = get_child(self, k)
dat = v.detach().numpy().T if "fc.weight" in k else v.detach().numpy()
dat = v.detach().numpy()
if 'fc.' in k and obj.shape != dat.shape:
print("skipping fully connected layer")