mirror of https://github.com/commaai/tinygrad.git
add retinanet with resnet backbone (#813)
* add retinanet with resnet backbone
* adds resnext to support loading retinanet pretrained on openimages
* object detection post processing with numpy
* data is downloaded and converted to coco format with fiftyone
* data loading and mAP evaluation with pycocotools
* remove fiftyone dep
* eval freq
* fix model timing
* del jit for last batch
* faster accumulate
This commit is contained in:
parent 46327f7420
commit 65d09031f2
@@ -0,0 +1,165 @@
import os
import math
import json
import sys
from extra.utils import OSX
import numpy as np
from PIL import Image
import pathlib
import boto3, botocore
from extra.utils import download_file
from tqdm import tqdm
import pandas as pd
import concurrent.futures

BASEDIR = pathlib.Path(__file__).parent.parent / "datasets/open-images-v6-mlperf"
BUCKET_NAME = "open-images-dataset"
BBOX_ANNOTATIONS_URL = "https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv"
MAP_CLASSES_URL = "https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv"
MLPERF_CLASSES = ['Airplane', 'Antelope', 'Apple', 'Backpack', 'Balloon', 'Banana',
  'Barrel', 'Baseball bat', 'Baseball glove', 'Bee', 'Beer', 'Bench', 'Bicycle',
  'Bicycle helmet', 'Bicycle wheel', 'Billboard', 'Book', 'Bookcase', 'Boot',
  'Bottle', 'Bowl', 'Bowling equipment', 'Box', 'Boy', 'Brassiere', 'Bread',
  'Broccoli', 'Bronze sculpture', 'Bull', 'Bus', 'Bust', 'Butterfly', 'Cabinetry',
  'Cake', 'Camel', 'Camera', 'Candle', 'Candy', 'Cannon', 'Canoe', 'Carrot', 'Cart',
  'Castle', 'Cat', 'Cattle', 'Cello', 'Chair', 'Cheese', 'Chest of drawers', 'Chicken',
  'Christmas tree', 'Coat', 'Cocktail', 'Coffee', 'Coffee cup', 'Coffee table', 'Coin',
  'Common sunflower', 'Computer keyboard', 'Computer monitor', 'Convenience store',
  'Cookie', 'Countertop', 'Cowboy hat', 'Crab', 'Crocodile', 'Cucumber', 'Cupboard',
  'Curtain', 'Deer', 'Desk', 'Dinosaur', 'Dog', 'Doll', 'Dolphin', 'Door', 'Dragonfly',
  'Drawer', 'Dress', 'Drum', 'Duck', 'Eagle', 'Earrings', 'Egg (Food)', 'Elephant',
  'Falcon', 'Fedora', 'Flag', 'Flowerpot', 'Football', 'Football helmet', 'Fork',
  'Fountain', 'French fries', 'French horn', 'Frog', 'Giraffe', 'Girl', 'Glasses',
  'Goat', 'Goggles', 'Goldfish', 'Gondola', 'Goose', 'Grape', 'Grapefruit', 'Guitar',
  'Hamburger', 'Handbag', 'Harbor seal', 'Headphones', 'Helicopter', 'High heels',
  'Hiking equipment', 'Horse', 'House', 'Houseplant', 'Human arm', 'Human beard',
  'Human body', 'Human ear', 'Human eye', 'Human face', 'Human foot', 'Human hair',
  'Human hand', 'Human head', 'Human leg', 'Human mouth', 'Human nose', 'Ice cream',
  'Jacket', 'Jeans', 'Jellyfish', 'Juice', 'Kitchen & dining room table', 'Kite',
  'Lamp', 'Lantern', 'Laptop', 'Lavender (Plant)', 'Lemon', 'Light bulb', 'Lighthouse',
  'Lily', 'Lion', 'Lipstick', 'Lizard', 'Man', 'Maple', 'Microphone', 'Mirror',
  'Mixing bowl', 'Mobile phone', 'Monkey', 'Motorcycle', 'Muffin', 'Mug', 'Mule',
  'Mushroom', 'Musical keyboard', 'Necklace', 'Nightstand', 'Office building',
  'Orange', 'Owl', 'Oyster', 'Paddle', 'Palm tree', 'Parachute', 'Parrot', 'Pen',
  'Penguin', 'Personal flotation device', 'Piano', 'Picture frame', 'Pig', 'Pillow',
  'Pizza', 'Plate', 'Platter', 'Porch', 'Poster', 'Pumpkin', 'Rabbit', 'Rifle',
  'Roller skates', 'Rose', 'Salad', 'Sandal', 'Saucer', 'Saxophone', 'Scarf', 'Sea lion',
  'Sea turtle', 'Sheep', 'Shelf', 'Shirt', 'Shorts', 'Shrimp', 'Sink', 'Skateboard',
  'Ski', 'Skull', 'Skyscraper', 'Snake', 'Sock', 'Sofa bed', 'Sparrow', 'Spider', 'Spoon',
  'Sports uniform', 'Squirrel', 'Stairs', 'Stool', 'Strawberry', 'Street light',
  'Studio couch', 'Suit', 'Sun hat', 'Sunglasses', 'Surfboard', 'Sushi', 'Swan',
  'Swimming pool', 'Swimwear', 'Tank', 'Tap', 'Taxi', 'Tea', 'Teddy bear', 'Television',
  'Tent', 'Tie', 'Tiger', 'Tin can', 'Tire', 'Toilet', 'Tomato', 'Tortoise', 'Tower',
  'Traffic light', 'Train', 'Tripod', 'Truck', 'Trumpet', 'Umbrella', 'Van', 'Vase',
  'Vehicle registration plate', 'Violin', 'Wall clock', 'Waste container', 'Watch',
  'Whale', 'Wheel', 'Wheelchair', 'Whiteboard', 'Window', 'Wine', 'Wine glass', 'Woman',
  'Zebra', 'Zucchini',
]

def openimages():
  ann_file = BASEDIR / "validation/labels/openimages-mlperf.json"
  if not ann_file.is_file():
    fetch_openimages(ann_file)
  return ann_file

# this slows down the conversion a lot!
# maybe use https://raw.githubusercontent.com/scardine/image_size/master/get_image_size.py
def extract_dims(path): return Image.open(path).size[::-1]

def export_to_coco(class_map, annotations, image_list, dataset_path, output_path, classes=MLPERF_CLASSES):
  output_path.parent.mkdir(parents=True, exist_ok=True)
  cats = [{"id": i, "name": c, "supercategory": None} for i, c in enumerate(classes)]
  categories_map = pd.DataFrame([(i, c) for i, c in enumerate(classes)], columns=["category_id", "category_name"])
  class_map = class_map.merge(categories_map, left_on="DisplayName", right_on="category_name", how="inner")
  annotations = annotations[np.isin(annotations["ImageID"], image_list)]
  annotations = annotations.merge(class_map, on="LabelName", how="inner")
  annotations["image_id"] = pd.factorize(annotations["ImageID"].tolist())[0]
  annotations[["height", "width"]] = annotations.apply(lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"), axis=1, result_type="expand")

  # Images
  imgs = [{"id": int(id + 1), "file_name": f"{image_id}.jpg", "height": row["height"], "width": row["width"], "license": None, "coco_url": None}
    for (id, image_id), row in (annotations.groupby(["image_id", "ImageID"]).first().iterrows())
  ]

  # Annotations
  annots = []
  for i, row in annotations.iterrows():
    xmin, ymin, xmax, ymax, img_w, img_h = [row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]]
    x, y, w, h = xmin * img_w, ymin * img_h, (xmax - xmin) * img_w, (ymax - ymin) * img_h
    coco_annot = {"id": int(i) + 1, "image_id": int(row["image_id"] + 1), "category_id": int(row["category_id"]), "bbox": [x, y, w, h], "area": w * h}
    coco_annot.update({k: row[k] for k in ["IsOccluded", "IsInside", "IsDepiction", "IsTruncated", "IsGroupOf"]})
    coco_annot["iscrowd"] = int(row["IsGroupOf"])
    annots.append(coco_annot)

  info = {"dataset": "openimages_mlperf", "version": "v6"}
  coco_annotations = {"info": info, "licenses": [], "categories": cats, "images": imgs, "annotations": annots}
  with open(output_path, "w") as fp:
    json.dump(coco_annotations, fp)

def get_image_list(class_map, annotations, classes=MLPERF_CLASSES):
  labels = class_map[np.isin(class_map["DisplayName"], classes)]["LabelName"]
  image_ids = annotations[np.isin(annotations["LabelName"], labels)]["ImageID"].unique()
  return image_ids

def download_image(bucket, image_id, data_dir):
  try:
    bucket.download_file(f"validation/{image_id}.jpg", f"{data_dir}/{image_id}.jpg")
  except botocore.exceptions.ClientError as exception:
    sys.exit(f"ERROR when downloading image `validation/{image_id}`: {str(exception)}")

def fetch_openimages(output_fn):
  bucket = boto3.resource("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)).Bucket(BUCKET_NAME)

  annotations_dir, data_dir = BASEDIR / "annotations", BASEDIR / "validation/data"
  annotations_dir.mkdir(parents=True, exist_ok=True)
  data_dir.mkdir(parents=True, exist_ok=True)

  annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split('/')[-1]
  download_file(BBOX_ANNOTATIONS_URL, annotations_fn)
  annotations = pd.read_csv(annotations_fn)

  classmap_fn = annotations_dir / MAP_CLASSES_URL.split('/')[-1]
  download_file(MAP_CLASSES_URL, classmap_fn)
  class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"])

  image_list = get_image_list(class_map, annotations)

  with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(download_image, bucket, image_id, data_dir) for image_id in image_list]
    for future in (t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))):
      t.set_description("Downloading images")
      future.result()

  print("Converting annotations to COCO format...")
  export_to_coco(class_map, annotations, image_list, data_dir, output_fn)

def image_load(fn):
  img_folder = BASEDIR / "validation/data"
  img = Image.open(img_folder / fn).convert('RGB')
  import torchvision.transforms.functional as F
  ret = F.resize(img, size=(800, 800))
  ret = np.array(ret)
  return ret, img.size[::-1]

def prepare_target(annotations, img_id, img_size):
  boxes = [annot["bbox"] for annot in annotations]
  boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
  boxes[:, 2:] += boxes[:, :2]
  boxes[:, 0::2] = boxes[:, 0::2].clip(0, img_size[1])
  boxes[:, 1::2] = boxes[:, 1::2].clip(0, img_size[0])
  keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
  boxes = boxes[keep]
  classes = [annot["category_id"] for annot in annotations]
  classes = np.array(classes, dtype=np.int64)
  classes = classes[keep]
  return {"boxes": boxes, "labels": classes, "image_id": img_id, "image_size": img_size}

def iterate(coco, bs=8):
  image_ids = sorted(coco.imgs.keys())
  for i in range(0, len(image_ids), bs):
    X, targets = [], []
    for img_id in image_ids[i:i+bs]:
      x, original_size = image_load(coco.loadImgs(img_id)[0]["file_name"])
      X.append(x)
      annotations = coco.loadAnns(coco.getAnnIds(img_id))
      targets.append(prepare_target(annotations, img_id, original_size))
    yield np.array(X), targets
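Not part of the diff: a quick usage sketch for the dataset helpers above. openimages() downloads the validation split on first call and returns the COCO-format annotation file, so the whole pipeline can be smoke-tested with pycocotools:

from pycocotools.coco import COCO
from datasets.openimages import openimages, iterate

coco = COCO(openimages())               # downloads + converts on the first run
x, targets = next(iterate(coco, bs=2))  # one batch of resized images
print(x.shape)                          # (2, 800, 800, 3)
print(targets[0]["boxes"].shape, targets[0]["image_size"])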
@@ -42,6 +42,64 @@ def eval_resnet():
      print(f"****** {n}/{d} {n*100.0/d:.2f}%")
      st = time.perf_counter()

+def eval_retinanet():
+  # RetinaNet with ResNeXt50_32X4D
+  from models.resnet import ResNeXt50_32X4D
+  from models.retinanet import RetinaNet
+  mdl = RetinaNet(ResNeXt50_32X4D())
+  mdl.load_from_pretrained()
+
+  input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
+  input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
+  def input_fixup(x):
+    x = x.permute([0,3,1,2]) / 255.0
+    x -= input_mean
+    x /= input_std
+    return x
+
+  from datasets.openimages import openimages, iterate
+  from pycocotools.coco import COCO
+  from pycocotools.cocoeval import COCOeval
+  from contextlib import redirect_stdout
+  coco = COCO(openimages())
+  coco_eval = COCOeval(coco, iouType="bbox")
+  coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
+
+  from tinygrad.jit import TinyJit
+  mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
+
+  n, bs = 0, 8
+  st = time.perf_counter()
+  for x, targets in iterate(coco, bs):
+    dat = Tensor(x.astype(np.float32))
+    mt = time.perf_counter()
+    if dat.shape[0] == bs:
+      outs = mdlrun(dat).numpy()
+    else:
+      # the last batch is smaller than bs, so drop the jit and run eagerly
+      mdlrun.jit_cache = None
+      outs = mdl(input_fixup(dat)).numpy()
+    et = time.perf_counter()
+    predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
+    ext = time.perf_counter()
+    n += len(targets)
+    print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing")
+    img_ids = [t["image_id"] for t in targets]
+    coco_results = [{"image_id": targets[i]["image_id"], "category_id": label, "bbox": box, "score": score}
+      for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())]
+    with redirect_stdout(None):
+      coco_eval.cocoDt = coco.loadRes(coco_results)
+      coco_eval.params.imgIds = img_ids
+      coco_eval.evaluate()
+    evaluated_imgs.extend(img_ids)
+    coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids)))
+    st = time.perf_counter()
+
+  coco_eval.params.imgIds = evaluated_imgs
+  coco_eval._paramsEval.imgIds = evaluated_imgs
+  coco_eval.evalImgs = list(np.concatenate(coco_evalimgs, -1).flatten())
+  coco_eval.accumulate()
+  coco_eval.summarize()
+
def eval_rnnt():
  # RNN-T
  from models.rnnt import RNNT
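A note on the "faster accumulate" trick above: pycocotools lays out COCOeval.evalImgs as [category][areaRng][image], so each batch's results can be reshaped to (ncats, narea, nimgs), concatenated along the image axis, and flattened back as if a single evaluate() call had covered every image. A shape-only sketch of that bookkeeping (dummy data, no pycocotools needed):

import numpy as np
ncats, narea = 3, 4
batches = [2, 5, 1]                                      # images evaluated per batch
chunks = [np.zeros((ncats, narea, n)) for n in batches]  # stand-ins for per-batch evalImgs
full = np.concatenate(chunks, axis=-1)                   # (ncats, narea, total_imgs)
assert full.shape == (ncats, narea, sum(batches))
flat = list(full.flatten())                              # the layout accumulate() expects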
@@ -17,8 +17,12 @@ def spec_resnet():
  test_model(mdl, img)

def spec_retinanet():
-  # TODO: Retinanet
-  pass
+  # Retinanet with ResNet backbone
+  from models.resnet import ResNet50
+  from models.retinanet import RetinaNet
+  mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9)
+  img = Tensor.randn(1, 3, 224, 224)
+  test_model(mdl, img)

def spec_unet3d():
  # 3D UNET
@@ -5,7 +5,8 @@ from extra.utils import get_child
class BasicBlock:
  expansion = 1

-  def __init__(self, in_planes, planes, stride=1):
+  def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
+    assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64"
    self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
@@ -29,12 +30,13 @@ class Bottleneck:
  # NOTE: the original implementation places stride at the first convolution (self.conv1), this is the v1.5 variant
  expansion = 4

-  def __init__(self, in_planes, planes, stride=1):
-    self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
-    self.bn1 = nn.BatchNorm2d(planes)
-    self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
-    self.bn2 = nn.BatchNorm2d(planes)
-    self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
+  def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
+    width = int(planes * (base_width / 64.0)) * groups
+    self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, bias=False)
+    self.bn1 = nn.BatchNorm2d(width)
+    self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=stride, groups=groups, bias=False)
+    self.bn2 = nn.BatchNorm2d(width)
+    self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
    self.bn3 = nn.BatchNorm2d(self.expansion*planes)
    self.downsample = []
    if stride != 1 or in_planes != self.expansion*planes:
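The width rule above is the whole ResNeXt trick: worked numbers for ResNeXt50_32x4d (groups=32, base_width=4), which double the bottleneck width at every stage:

for planes in [64, 128, 256, 512]:
  width = int(planes * (4 / 64.0)) * 32  # groups=32, base_width=4
  print(planes, "->", width)             # 64->128, 128->256, 256->512, 512->1024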
@@ -52,7 +54,7 @@ class Bottleneck:
    return out

class ResNet:
-  def __init__(self, num, num_classes):
+  def __init__(self, num, num_classes, groups=1, width_per_group=64):
    self.num = num

    self.block = {
@@ -73,6 +75,8 @@ class ResNet:

    self.in_planes = 64

+    self.groups = groups
+    self.base_width = width_per_group
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
    self.bn1 = nn.BatchNorm2d(64)
    self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1)
@@ -85,7 +89,7 @@ class ResNet:
    strides = [stride] + [1] * (num_blocks-1)
    layers = []
    for stride in strides:
-      layers.append(block(self.in_planes, planes, stride))
+      layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
      self.in_planes = planes * block.expansion
    return layers

@@ -107,14 +111,15 @@ class ResNet:
    # TODO replace with fake torch load

    model_urls = {
-      18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
-      34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
-      50: 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
-      101: 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
-      152: 'https://download.pytorch.org/models/resnet152-b121ed2d.pth'
+      (18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+      (34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+      (50, 1, 64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+      (50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+      (101, 1, 64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+      (152, 1, 64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    }

-    self.url = model_urls[self.num]
+    self.url = model_urls[(self.num, self.groups, self.base_width)]

    from torch.hub import load_state_dict_from_url
    state_dict = load_state_dict_from_url(self.url, progress=True)
@@ -126,7 +131,8 @@ class ResNet:
        print("skipping fully connected layer")
        continue # Skip FC if transfer learning

-      assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
+      # TODO: remove the `or` clause when #777 is merged
+      assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (k, obj.shape, dat.shape)
      obj.assign(dat)

@@ -134,3 +140,4 @@ ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
+ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
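Usage sketch (not part of the diff): the new constructor plus load_from_pretrained, which resolves the (50, 32, 4) key added to model_urls above and fetches the torchvision checkpoint:

from models.resnet import ResNeXt50_32X4D
mdl = ResNeXt50_32X4D()
mdl.load_from_pretrained()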
@@ -0,0 +1,237 @@
import math
from tinygrad.helpers import flatten
import tinygrad.nn as nn
from models.resnet import ResNet
from extra.utils import get_child
import numpy as np

# greedy non-maximum suppression: repeatedly keep the highest-scoring box and
# drop every remaining box whose IoU with it exceeds thresh
def nms(boxes, scores, thresh=0.5):
  x1, y1, x2, y2 = np.rollaxis(boxes, 1)
  areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  to_process, keep = scores.argsort()[::-1], []
  while to_process.size > 0:
    cur, to_process = to_process[0], to_process[1:]
    keep.append(cur)
    inter_x1 = np.maximum(x1[cur], x1[to_process])
    inter_y1 = np.maximum(y1[cur], y1[to_process])
    inter_x2 = np.minimum(x2[cur], x2[to_process])
    inter_y2 = np.minimum(y2[cur], y2[to_process])
    inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(0, inter_y2 - inter_y1 + 1)
    iou = inter_area / (areas[cur] + areas[to_process] - inter_area)
    to_process = to_process[np.where(iou <= thresh)[0]]
  return keep
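A tiny check of the greedy NMS above (pure numpy; note the +1 pixel convention in the area and intersection math):

import numpy as np
boxes = np.array([[0, 0, 10, 10], [1, 1, 10, 10], [50, 50, 60, 60]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
print(nms(boxes, scores, thresh=0.5))  # [0, 2]: box 1 overlaps box 0 (IoU ~0.83) and is dropped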
# turn regression offsets (dx, dy, dw, dh), relative to an anchor's center and
# size, into absolute (x1, y1, x2, y2) box coordinates
def decode_bbox(offsets, anchors):
  dx, dy, dw, dh = np.rollaxis(offsets, 1)
  widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1]
  cx, cy = anchors[:, 0] + 0.5 * widths, anchors[:, 1] + 0.5 * heights
  pred_cx, pred_cy = dx * widths + cx, dy * heights + cy
  pred_w, pred_h = np.exp(dw) * widths, np.exp(dh) * heights
  pred_x1, pred_y1 = pred_cx - 0.5 * pred_w, pred_cy - 0.5 * pred_h
  pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h
  return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32)
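Sanity check for decode_bbox: zero offsets must reproduce the anchors exactly, since the deltas are relative to each anchor's center and size:

import numpy as np
anchors = np.array([[10, 10, 50, 30]], dtype=np.float32)         # w=40, h=20
print(decode_bbox(np.zeros((1, 4), dtype=np.float32), anchors))  # [[10. 10. 50. 30.]]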
# build per-level base anchors (num_scales * num_ratios per cell) and shift them
# over each feature-map grid; returns one (grid_h*grid_w*A, 4) array per level
def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
  assert len(scales) == len(aspect_ratios) == len(grid_sizes)
  anchors = []
  for s, ar, gs in zip(scales, aspect_ratios, grid_sizes):
    s, ar = np.array(s), np.array(ar)
    h_ratios = np.sqrt(ar)
    w_ratios = 1 / h_ratios
    ws = (w_ratios[:, None] * s[None, :]).reshape(-1)
    hs = (h_ratios[:, None] * s[None, :]).reshape(-1)
    base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round()
    stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1]
    shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h)
    shifts_x = shifts_x.reshape(-1)
    shifts_y = shifts_y.reshape(-1)
    shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32)
    anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4))
  return anchors
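Anchor bookkeeping for the default 800x800 input (a sketch; the scales and aspect_ratios expressions mirror the defaults in RetinaNet.__init__ below):

import numpy as np
scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales)
grid_sizes = [(g, g) for g in np.ceil(800 / 2 ** np.arange(3, 8)).astype(int)]  # P3..P7: 100, 50, 25, 13, 7
anchors = generate_anchors((800, 800), grid_sizes, scales, aspect_ratios)
print([a.shape[0] for a in anchors])  # [90000, 22500, 5625, 1521, 441], 9 anchors per grid cell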
class RetinaNet:
  def __init__(self, backbone: ResNet, num_classes=264, num_anchors=9, scales=None, aspect_ratios=None):
    assert isinstance(backbone, ResNet)
    scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
    self.num_anchors, self.num_classes = num_anchors, num_classes
    assert len(scales) == len(aspect_ratios) and all([self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios)])

    self.backbone = ResNetFPN(backbone)
    self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
    self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios)

  def __call__(self, x):
    return self.forward(x)
  def forward(self, x):
    return self.head(self.backbone(x))

  def load_from_pretrained(self):
    model_urls = {
      (50, 1, 64): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
      (50, 32, 4): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip",
    }
    self.url = model_urls[(self.backbone.body.num, self.backbone.body.groups, self.backbone.body.base_width)]
    from torch.hub import load_state_dict_from_url
    state_dict = load_state_dict_from_url(self.url, progress=True, map_location='cpu')
    state_dict = state_dict['model'] if 'model' in state_dict.keys() else state_dict
    for k, v in state_dict.items():
      obj = get_child(self, k)
      dat = v.detach().numpy()
      assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
      obj.assign(dat)
  # predictions: (BS, (H1W1+...+HmWm)A, 4 + K)
  def postprocess_detections(self, predictions, input_size=(800, 800), image_sizes=None, orig_image_sizes=None, score_thresh=0.05, topk_candidates=1000, nms_thresh=0.5):
    anchors = self.anchor_gen(input_size)
    grid_sizes = self.backbone.compute_grid_sizes(input_size)
    split_idx = np.cumsum([int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]])
    detections = []
    for i, predictions_per_image in enumerate(predictions):
      h, w = input_size if image_sizes is None else image_sizes[i]

      predictions_per_image = np.split(predictions_per_image, split_idx)
      offsets_per_image = [br[:, :4] for br in predictions_per_image]
      scores_per_image = [cl[:, 4:] for cl in predictions_per_image]

      image_boxes, image_scores, image_labels = [], [], []
      for offsets_per_level, scores_per_level, anchors_per_level in zip(offsets_per_image, scores_per_image, anchors):
        # remove low scoring boxes
        scores_per_level = scores_per_level.flatten()
        keep_idxs = scores_per_level > score_thresh
        scores_per_level = scores_per_level[keep_idxs]

        # keep topk
        topk_idxs = np.where(keep_idxs)[0]
        num_topk = min(len(topk_idxs), topk_candidates)
        sort_idxs = scores_per_level.argsort()[-num_topk:][::-1]
        topk_idxs, scores_per_level = topk_idxs[sort_idxs], scores_per_level[sort_idxs]

        # bbox coords from offsets
        anchor_idxs = topk_idxs // self.num_classes
        labels_per_level = topk_idxs % self.num_classes
        boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs])
        # clip to image size
        clipped_x = boxes_per_level[:, 0::2].clip(0, w)
        clipped_y = boxes_per_level[:, 1::2].clip(0, h)
        boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(-1, 4)

        image_boxes.append(boxes_per_level)
        image_scores.append(scores_per_level)
        image_labels.append(labels_per_level)

      image_boxes = np.concatenate(image_boxes)
      image_scores = np.concatenate(image_scores)
      image_labels = np.concatenate(image_labels)

      # nms for each class
      keep_mask = np.zeros_like(image_scores, dtype=bool)
      for class_id in np.unique(image_labels):
        curr_indices = np.where(image_labels == class_id)[0]
        curr_keep_indices = nms(image_boxes[curr_indices], image_scores[curr_indices], nms_thresh)
        keep_mask[curr_indices[curr_keep_indices]] = True
      keep = np.where(keep_mask)[0]
      keep = keep[image_scores[keep].argsort()[::-1]]

      # resize bboxes back to original size
      image_boxes = image_boxes[keep]
      if orig_image_sizes is not None:
        resized_x = image_boxes[:, 0::2] * orig_image_sizes[i][1] / w
        resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h
        image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4)
      # xywh format
      image_boxes = np.concatenate([image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1)

      detections.append({"boxes": image_boxes, "scores": image_scores[keep], "labels": image_labels[keep]})
    return detections
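Shape contract, following the comment above postprocess_detections: the second axis of predictions is the total anchor count. For the default 800x800 input that is:

import numpy as np
grid_sizes = np.ceil(800 / 2 ** np.arange(3, 8))  # 100, 50, 25, 13, 7
print(int((9 * grid_sizes ** 2).sum()))           # 120087 == (H1W1+...+HmWm)A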
class ClassificationHead:
  def __init__(self, in_channels, num_anchors, num_classes):
    self.num_classes = num_classes
    self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
    self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
  def __call__(self, x):
    out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
    return out[0].cat(*out[1:], dim=1).sigmoid()

class RegressionHead:
  def __init__(self, in_channels, num_anchors):
    self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
    self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, padding=1)
  def __call__(self, x):
    out = [self.bbox_reg(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, 4) for feat in x]
    return out[0].cat(*out[1:], dim=1)

class RetinaHead:
  def __init__(self, in_channels, num_anchors, num_classes):
    self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
    self.regression_head = RegressionHead(in_channels, num_anchors)
  def __call__(self, x):
    pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
    out = pred_bbox.cat(pred_class, dim=-1)
    return out
class ResNetFPN:
  def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]):
    self.out_channels = out_channels
    self.body = resnet
    in_channels_list = [(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers]
    self.fpn = FPN(in_channels_list, out_channels)

  # this is needed to decouple inference from postprocessing (anchors generation)
  def compute_grid_sizes(self, input_size):
    return np.ceil(np.array(input_size)[None, :] / 2 ** np.arange(3, 8)[:, None])

  def __call__(self, x):
    out = self.body.bn1(self.body.conv1(x)).relu()
    out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2)
    out = out.sequential(self.body.layer1)
    p3 = out.sequential(self.body.layer2)
    p4 = p3.sequential(self.body.layer3)
    p5 = p4.sequential(self.body.layer4)
    return self.fpn([p3, p4, p5])

class ExtraFPNBlock:
  def __init__(self, in_channels, out_channels):
    self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
    self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
    self.use_P5 = in_channels == out_channels

  def __call__(self, p, c):
    p5, c5 = p[-1], c[-1]
    x = p5 if self.use_P5 else c5
    p6 = self.p6(x)
    p7 = self.p7(p6.relu())
    p.extend([p6, p7])
    return p

class FPN:
  def __init__(self, in_channels_list, out_channels, extra_blocks=None):
    self.inner_blocks, self.layer_blocks = [], []
    for in_channels in in_channels_list:
      self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
      self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
    self.extra_blocks = ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks

  def __call__(self, x):
    last_inner = self.inner_blocks[-1](x[-1])
    results = [self.layer_blocks[-1](last_inner)]
    for idx in range(len(x) - 2, -1, -1):
      inner_lateral = self.inner_blocks[idx](x[idx])

      # upsample to inner_lateral's shape
      (ih, iw), (oh, ow), prefix = last_inner.shape[-2:], inner_lateral.shape[-2:], last_inner.shape[:-2]
      eh, ew = math.ceil(oh / ih), math.ceil(ow / iw)
      inner_top_down = last_inner.reshape(*prefix, ih, 1, iw, 1).expand(*prefix, ih, eh, iw, ew).reshape(*prefix, ih*eh, iw*ew)[:, :, :oh, :ow]

      last_inner = inner_lateral + inner_top_down
      results.insert(0, self.layer_blocks[idx](last_inner))
    if self.extra_blocks is not None:
      results = self.extra_blocks(results, x)
    return results
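The reshape/expand in FPN.__call__ is nearest-neighbor upsampling built from movement ops only; a numpy analogue of the same trick:

import numpy as np
a = np.array([[1, 2], [3, 4]])
up = a.reshape(2, 1, 2, 1).repeat(2, axis=1).repeat(2, axis=3).reshape(4, 4)
print(up)  # every value becomes a 2x2 block; the [:oh, :ow] slice then crops any overshoot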
if __name__ == "__main__":
  from models.resnet import ResNeXt50_32X4D
  backbone = ResNeXt50_32X4D()
  retina = RetinaNet(backbone)
  retina.load_from_pretrained()