add retinanet with resnet backbone (#813)

* add retinanet with resnet backbone

* adds resnext to support loading retinanet pretrained on openimages
* object detection post processing with numpy
* data is downloaded and converted to coco format with fiftyone
* data loading and mAP evaluation with pycocotools

* remove fiftyone dep

* eval freq

* fix model timing

* del jit for last batch

* faster accumulate
Sohaib 2023-05-29 03:20:16 +00:00 committed by GitHub
parent 46327f7420
commit 65d09031f2
5 changed files with 489 additions and 18 deletions
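At a high level, the new dataset module is used like this (a minimal sketch, not part of the diff, using only the functions added below in datasets/openimages.py):

  from datasets.openimages import openimages, iterate
  from pycocotools.coco import COCO

  ann_file = openimages()                 # first call downloads the validation images from S3 and writes COCO-format annotations
  coco = COCO(ann_file)
  x, targets = next(iterate(coco, bs=2))  # (2, 800, 800, 3) image batch plus COCO-style targets
  print(x.shape, targets[0]["boxes"].shape)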

datasets/openimages.py (new file)

@@ -0,0 +1,165 @@
import os
import math
import sys
import json
import pathlib
import concurrent.futures
import numpy as np
import pandas as pd
import boto3, botocore
from PIL import Image
from tqdm import tqdm
from extra.utils import OSX, download_file

BASEDIR = pathlib.Path(__file__).parent.parent / "datasets/open-images-v6-mlperf"
BUCKET_NAME = "open-images-dataset"
BBOX_ANNOTATIONS_URL = "https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv"
MAP_CLASSES_URL = "https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv"
MLPERF_CLASSES = ['Airplane', 'Antelope', 'Apple', 'Backpack', 'Balloon', 'Banana',
'Barrel', 'Baseball bat', 'Baseball glove', 'Bee', 'Beer', 'Bench', 'Bicycle',
'Bicycle helmet', 'Bicycle wheel', 'Billboard', 'Book', 'Bookcase', 'Boot',
'Bottle', 'Bowl', 'Bowling equipment', 'Box', 'Boy', 'Brassiere', 'Bread',
'Broccoli', 'Bronze sculpture', 'Bull', 'Bus', 'Bust', 'Butterfly', 'Cabinetry',
'Cake', 'Camel', 'Camera', 'Candle', 'Candy', 'Cannon', 'Canoe', 'Carrot', 'Cart',
'Castle', 'Cat', 'Cattle', 'Cello', 'Chair', 'Cheese', 'Chest of drawers', 'Chicken',
'Christmas tree', 'Coat', 'Cocktail', 'Coffee', 'Coffee cup', 'Coffee table', 'Coin',
'Common sunflower', 'Computer keyboard', 'Computer monitor', 'Convenience store',
'Cookie', 'Countertop', 'Cowboy hat', 'Crab', 'Crocodile', 'Cucumber', 'Cupboard',
'Curtain', 'Deer', 'Desk', 'Dinosaur', 'Dog', 'Doll', 'Dolphin', 'Door', 'Dragonfly',
'Drawer', 'Dress', 'Drum', 'Duck', 'Eagle', 'Earrings', 'Egg (Food)', 'Elephant',
'Falcon', 'Fedora', 'Flag', 'Flowerpot', 'Football', 'Football helmet', 'Fork',
'Fountain', 'French fries', 'French horn', 'Frog', 'Giraffe', 'Girl', 'Glasses',
'Goat', 'Goggles', 'Goldfish', 'Gondola', 'Goose', 'Grape', 'Grapefruit', 'Guitar',
'Hamburger', 'Handbag', 'Harbor seal', 'Headphones', 'Helicopter', 'High heels',
'Hiking equipment', 'Horse', 'House', 'Houseplant', 'Human arm', 'Human beard',
'Human body', 'Human ear', 'Human eye', 'Human face', 'Human foot', 'Human hair',
'Human hand', 'Human head', 'Human leg', 'Human mouth', 'Human nose', 'Ice cream',
'Jacket', 'Jeans', 'Jellyfish', 'Juice', 'Kitchen & dining room table', 'Kite',
'Lamp', 'Lantern', 'Laptop', 'Lavender (Plant)', 'Lemon', 'Light bulb', 'Lighthouse',
'Lily', 'Lion', 'Lipstick', 'Lizard', 'Man', 'Maple', 'Microphone', 'Mirror',
'Mixing bowl', 'Mobile phone', 'Monkey', 'Motorcycle', 'Muffin', 'Mug', 'Mule',
'Mushroom', 'Musical keyboard', 'Necklace', 'Nightstand', 'Office building',
'Orange', 'Owl', 'Oyster', 'Paddle', 'Palm tree', 'Parachute', 'Parrot', 'Pen',
'Penguin', 'Personal flotation device', 'Piano', 'Picture frame', 'Pig', 'Pillow',
'Pizza', 'Plate', 'Platter', 'Porch', 'Poster', 'Pumpkin', 'Rabbit', 'Rifle',
'Roller skates', 'Rose', 'Salad', 'Sandal', 'Saucer', 'Saxophone', 'Scarf', 'Sea lion',
'Sea turtle', 'Sheep', 'Shelf', 'Shirt', 'Shorts', 'Shrimp', 'Sink', 'Skateboard',
'Ski', 'Skull', 'Skyscraper', 'Snake', 'Sock', 'Sofa bed', 'Sparrow', 'Spider', 'Spoon',
'Sports uniform', 'Squirrel', 'Stairs', 'Stool', 'Strawberry', 'Street light',
'Studio couch', 'Suit', 'Sun hat', 'Sunglasses', 'Surfboard', 'Sushi', 'Swan',
'Swimming pool', 'Swimwear', 'Tank', 'Tap', 'Taxi', 'Tea', 'Teddy bear', 'Television',
'Tent', 'Tie', 'Tiger', 'Tin can', 'Tire', 'Toilet', 'Tomato', 'Tortoise', 'Tower',
'Traffic light', 'Train', 'Tripod', 'Truck', 'Trumpet', 'Umbrella', 'Van', 'Vase',
'Vehicle registration plate', 'Violin', 'Wall clock', 'Waste container', 'Watch',
'Whale', 'Wheel', 'Wheelchair', 'Whiteboard', 'Window', 'Wine', 'Wine glass', 'Woman',
'Zebra', 'Zucchini',
]
def openimages():
  ann_file = BASEDIR / "validation/labels/openimages-mlperf.json"
  if not ann_file.is_file():
    fetch_openimages(ann_file)
  return ann_file

# this slows down the conversion a lot!
# maybe use https://raw.githubusercontent.com/scardine/image_size/master/get_image_size.py
def extract_dims(path): return Image.open(path).size[::-1]
def export_to_coco(class_map, annotations, image_list, dataset_path, output_path, classes=MLPERF_CLASSES):
  output_path.parent.mkdir(parents=True, exist_ok=True)
  cats = [{"id": i, "name": c, "supercategory": None} for i, c in enumerate(classes)]
  categories_map = pd.DataFrame([(i, c) for i, c in enumerate(classes)], columns=["category_id", "category_name"])
  class_map = class_map.merge(categories_map, left_on="DisplayName", right_on="category_name", how="inner")
  annotations = annotations[np.isin(annotations["ImageID"], image_list)]
  annotations = annotations.merge(class_map, on="LabelName", how="inner")
  annotations["image_id"] = pd.factorize(annotations["ImageID"].tolist())[0]
  annotations[["height", "width"]] = annotations.apply(lambda x: extract_dims(dataset_path / f"{x['ImageID']}.jpg"), axis=1, result_type="expand")

  # Images
  imgs = [{"id": int(id + 1), "file_name": f"{image_id}.jpg", "height": row["height"], "width": row["width"], "license": None, "coco_url": None}
          for (id, image_id), row in annotations.groupby(["image_id", "ImageID"]).first().iterrows()]

  # Annotations
  annots = []
  for i, row in annotations.iterrows():
    xmin, ymin, xmax, ymax, img_w, img_h = [row[k] for k in ["XMin", "YMin", "XMax", "YMax", "width", "height"]]
    x, y, w, h = xmin * img_w, ymin * img_h, (xmax - xmin) * img_w, (ymax - ymin) * img_h
    coco_annot = {"id": int(i) + 1, "image_id": int(row["image_id"] + 1), "category_id": int(row["category_id"]), "bbox": [x, y, w, h], "area": w * h}
    coco_annot.update({k: row[k] for k in ["IsOccluded", "IsInside", "IsDepiction", "IsTruncated", "IsGroupOf"]})
    coco_annot["iscrowd"] = int(row["IsGroupOf"])
    annots.append(coco_annot)

  info = {"dataset": "openimages_mlperf", "version": "v6"}
  coco_annotations = {"info": info, "licenses": [], "categories": cats, "images": imgs, "annotations": annots}
  with open(output_path, "w") as fp:
    json.dump(coco_annotations, fp)

def get_image_list(class_map, annotations, classes=MLPERF_CLASSES):
  labels = class_map[np.isin(class_map["DisplayName"], classes)]["LabelName"]
  image_ids = annotations[np.isin(annotations["LabelName"], labels)]["ImageID"].unique()
  return image_ids

def download_image(bucket, image_id, data_dir):
  try:
    bucket.download_file(f"validation/{image_id}.jpg", f"{data_dir}/{image_id}.jpg")
  except botocore.exceptions.ClientError as exception:
    sys.exit(f"ERROR when downloading image `validation/{image_id}`: {str(exception)}")

def fetch_openimages(output_fn):
  bucket = boto3.resource("s3", config=botocore.config.Config(signature_version=botocore.UNSIGNED)).Bucket(BUCKET_NAME)

  annotations_dir, data_dir = BASEDIR / "annotations", BASEDIR / "validation/data"
  annotations_dir.mkdir(parents=True, exist_ok=True)
  data_dir.mkdir(parents=True, exist_ok=True)

  annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split('/')[-1]
  download_file(BBOX_ANNOTATIONS_URL, annotations_fn)
  annotations = pd.read_csv(annotations_fn)

  classmap_fn = annotations_dir / MAP_CLASSES_URL.split('/')[-1]
  download_file(MAP_CLASSES_URL, classmap_fn)
  class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"])

  image_list = get_image_list(class_map, annotations)

  with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(download_image, bucket, image_id, data_dir) for image_id in image_list]
    for future in (t := tqdm(concurrent.futures.as_completed(futures), total=len(image_list))):
      t.set_description(f"Downloading images")
      future.result()

  print("Converting annotations to COCO format...")
  export_to_coco(class_map, annotations, image_list, data_dir, output_fn)

def image_load(fn):
  img_folder = BASEDIR / "validation/data"
  img = Image.open(img_folder / fn).convert('RGB')
  import torchvision.transforms.functional as F
  ret = F.resize(img, size=(800, 800))
  ret = np.array(ret)
  return ret, img.size[::-1]

def prepare_target(annotations, img_id, img_size):
  boxes = [annot["bbox"] for annot in annotations]
  boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
  boxes[:, 2:] += boxes[:, :2]
  boxes[:, 0::2] = boxes[:, 0::2].clip(0, img_size[1])
  boxes[:, 1::2] = boxes[:, 1::2].clip(0, img_size[0])
  keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
  boxes = boxes[keep]
  classes = [annot["category_id"] for annot in annotations]
  classes = np.array(classes, dtype=np.int64)
  classes = classes[keep]
  return {"boxes": boxes, "labels": classes, "image_id": img_id, "image_size": img_size}

def iterate(coco, bs=8):
  image_ids = sorted(coco.imgs.keys())
  for i in range(0, len(image_ids), bs):
    X, targets = [], []
    for img_id in image_ids[i:i+bs]:
      x, original_size = image_load(coco.loadImgs(img_id)[0]["file_name"])
      X.append(x)
      annotations = coco.loadAnns(coco.getAnnIds(img_id))
      targets.append(prepare_target(annotations, img_id, original_size))
    yield np.array(X), targets


@@ -42,6 +42,64 @@ def eval_resnet():
     print(f"****** {n}/{d} {n*100.0/d:.2f}%")
     st = time.perf_counter()
 
+def eval_retinanet():
+  # RetinaNet with ResNeXt50_32X4D
+  from models.resnet import ResNeXt50_32X4D
+  from models.retinanet import RetinaNet
+  mdl = RetinaNet(ResNeXt50_32X4D())
+  mdl.load_from_pretrained()
+
+  input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
+  input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
+  def input_fixup(x):
+    x = x.permute([0,3,1,2]) / 255.0
+    x -= input_mean
+    x /= input_std
+    return x
+
+  from datasets.openimages import openimages, iterate
+  from pycocotools.coco import COCO
+  from pycocotools.cocoeval import COCOeval
+  from contextlib import redirect_stdout
+  coco = COCO(openimages())
+  coco_eval = COCOeval(coco, iouType="bbox")
+  coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
+
+  from tinygrad.jit import TinyJit
+  mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
+
+  n, bs = 0, 8
+  st = time.perf_counter()
+  for x, targets in iterate(coco, bs):
+    dat = Tensor(x.astype(np.float32))
+    mt = time.perf_counter()
+    if dat.shape[0] == bs:
+      outs = mdlrun(dat).numpy()
+    else:
+      # the last batch can be smaller than bs, so skip the JIT-compiled path for it
+      mdlrun.jit_cache = None
+      outs = mdl(input_fixup(dat)).numpy()
+    et = time.perf_counter()
+    predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
+    ext = time.perf_counter()
+    n += len(targets)
+    print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing")
+    img_ids = [t["image_id"] for t in targets]
+    coco_results = [{"image_id": targets[i]["image_id"], "category_id": label, "bbox": box, "score": score}
+      for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())]
+    with redirect_stdout(None):
+      coco_eval.cocoDt = coco.loadRes(coco_results)
+      coco_eval.params.imgIds = img_ids
+      coco_eval.evaluate()
+    evaluated_imgs.extend(img_ids)
+    coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids)))
+    st = time.perf_counter()
+
+  # faster accumulate: merge the cached per-batch evalImgs instead of re-running evaluate over the whole set
+  coco_eval.params.imgIds = evaluated_imgs
+  coco_eval._paramsEval.imgIds = evaluated_imgs
+  coco_eval.evalImgs = list(np.concatenate(coco_evalimgs, -1).flatten())
+  coco_eval.accumulate()
+  coco_eval.summarize()
+
 def eval_rnnt():
   # RNN-T
   from models.rnnt import RNNT


@@ -17,8 +17,12 @@ def spec_resnet():
   test_model(mdl, img)
 
 def spec_retinanet():
-  # TODO: Retinanet
-  pass
+  # Retinanet with ResNet backbone
+  from models.resnet import ResNet50
+  from models.retinanet import RetinaNet
+  mdl = RetinaNet(ResNet50(), num_classes=91, num_anchors=9)
+  img = Tensor.randn(1, 3, 224, 224)
+  test_model(mdl, img)
 
 def spec_unet3d():
   # 3D UNET


@@ -5,7 +5,8 @@ from extra.utils import get_child
 class BasicBlock:
   expansion = 1
 
-  def __init__(self, in_planes, planes, stride=1):
+  def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
+    assert groups == 1 and base_width == 64, "BasicBlock only supports groups=1 and base_width=64"
     self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
     self.bn1 = nn.BatchNorm2d(planes)
     self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
@@ -29,12 +30,13 @@ class Bottleneck:
   # NOTE: the original implementation places stride at the first convolution (self.conv1), this is the v1.5 variant
   expansion = 4
 
-  def __init__(self, in_planes, planes, stride=1):
-    self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
-    self.bn1 = nn.BatchNorm2d(planes)
-    self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
-    self.bn2 = nn.BatchNorm2d(planes)
-    self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
+  def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
+    width = int(planes * (base_width / 64.0)) * groups
+    self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, bias=False)
+    self.bn1 = nn.BatchNorm2d(width)
+    self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=stride, groups=groups, bias=False)
+    self.bn2 = nn.BatchNorm2d(width)
+    self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
     self.bn3 = nn.BatchNorm2d(self.expansion*planes)
     self.downsample = []
     if stride != 1 or in_planes != self.expansion*planes:
@@ -52,7 +54,7 @@ class Bottleneck:
     return out
 
 class ResNet:
-  def __init__(self, num, num_classes):
+  def __init__(self, num, num_classes, groups=1, width_per_group=64):
     self.num = num
     self.block = {
@@ -73,6 +75,8 @@ class ResNet:
     self.in_planes = 64
 
+    self.groups = groups
+    self.base_width = width_per_group
     self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
     self.bn1 = nn.BatchNorm2d(64)
     self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1)
@@ -85,7 +89,7 @@ class ResNet:
     strides = [stride] + [1] * (num_blocks-1)
     layers = []
     for stride in strides:
-      layers.append(block(self.in_planes, planes, stride))
+      layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
       self.in_planes = planes * block.expansion
     return layers
@@ -107,14 +111,15 @@ class ResNet:
     # TODO replace with fake torch load
     model_urls = {
-      18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
-      34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
-      50: 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
-      101: 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
-      152: 'https://download.pytorch.org/models/resnet152-b121ed2d.pth'
+      (18, 1, 64): 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+      (34, 1, 64): 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+      (50, 1, 64): 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+      (50, 32, 4): 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+      (101, 1, 64): 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+      (152, 1, 64): 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
     }
-    self.url = model_urls[self.num]
+    self.url = model_urls[(self.num, self.groups, self.base_width)]
 
     from torch.hub import load_state_dict_from_url
     state_dict = load_state_dict_from_url(self.url, progress=True)
@@ -126,7 +131,8 @@ class ResNet:
         print("skipping fully connected layer")
         continue # Skip FC if transfer learning
-      assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
+      # TODO: remove or when #777 is merged
+      assert obj.shape == dat.shape or (obj.shape == (1,) and dat.shape == ()), (k, obj.shape, dat.shape)
       obj.assign(dat)
 
 ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
@@ -134,3 +140,4 @@ ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
 ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
 ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
 ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
+ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)

models/retinanet.py (new file)

@@ -0,0 +1,237 @@
import math
from tinygrad.helpers import flatten
import tinygrad.nn as nn
from models.resnet import ResNet
from extra.utils import get_child
import numpy as np
def nms(boxes, scores, thresh=0.5):
  x1, y1, x2, y2 = np.rollaxis(boxes, 1)
  areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  to_process, keep = scores.argsort()[::-1], []
  while to_process.size > 0:
    cur, to_process = to_process[0], to_process[1:]
    keep.append(cur)
    inter_x1 = np.maximum(x1[cur], x1[to_process])
    inter_y1 = np.maximum(y1[cur], y1[to_process])
    inter_x2 = np.minimum(x2[cur], x2[to_process])
    inter_y2 = np.minimum(y2[cur], y2[to_process])
    inter_area = np.maximum(0, inter_x2 - inter_x1 + 1) * np.maximum(0, inter_y2 - inter_y1 + 1)
    iou = inter_area / (areas[cur] + areas[to_process] - inter_area)
    to_process = to_process[np.where(iou <= thresh)[0]]
  return keep

def decode_bbox(offsets, anchors):
  dx, dy, dw, dh = np.rollaxis(offsets, 1)
  widths, heights = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1]
  cx, cy = anchors[:, 0] + 0.5 * widths, anchors[:, 1] + 0.5 * heights
  pred_cx, pred_cy = dx * widths + cx, dy * heights + cy
  pred_w, pred_h = np.exp(dw) * widths, np.exp(dh) * heights
  pred_x1, pred_y1 = pred_cx - 0.5 * pred_w, pred_cy - 0.5 * pred_h
  pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h
  return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32)

def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
  assert len(scales) == len(aspect_ratios) == len(grid_sizes)
  anchors = []
  for s, ar, gs in zip(scales, aspect_ratios, grid_sizes):
    s, ar = np.array(s), np.array(ar)
    h_ratios = np.sqrt(ar)
    w_ratios = 1 / h_ratios
    ws = (w_ratios[:, None] * s[None, :]).reshape(-1)
    hs = (h_ratios[:, None] * s[None, :]).reshape(-1)
    base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round()
    stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1]
    shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h)
    shifts_x = shifts_x.reshape(-1)
    shifts_y = shifts_y.reshape(-1)
    shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32)
    anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4))
  return anchors

class RetinaNet:
  def __init__(self, backbone: ResNet, num_classes=264, num_anchors=9, scales=None, aspect_ratios=None):
    assert isinstance(backbone, ResNet)
    scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
    self.num_anchors, self.num_classes = num_anchors, num_classes
    assert len(scales) == len(aspect_ratios) and all([self.num_anchors == len(s) * len(ar) for s, ar in zip(scales, aspect_ratios)])
    self.backbone = ResNetFPN(backbone)
    self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
    self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios)

  def __call__(self, x):
    return self.forward(x)
  def forward(self, x):
    return self.head(self.backbone(x))

  def load_from_pretrained(self):
    model_urls = {
      (50, 1, 64): "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
      (50, 32, 4): "https://zenodo.org/record/6605272/files/retinanet_model_10.zip",
    }
    self.url = model_urls[(self.backbone.body.num, self.backbone.body.groups, self.backbone.body.base_width)]
    from torch.hub import load_state_dict_from_url
    state_dict = load_state_dict_from_url(self.url, progress=True, map_location='cpu')
    state_dict = state_dict['model'] if 'model' in state_dict.keys() else state_dict
    for k, v in state_dict.items():
      obj = get_child(self, k)
      dat = v.detach().numpy()
      assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
      obj.assign(dat)

  # predictions: (BS, (H1W1+...+HmWm)A, 4 + K)
  def postprocess_detections(self, predictions, input_size=(800, 800), image_sizes=None, orig_image_sizes=None, score_thresh=0.05, topk_candidates=1000, nms_thresh=0.5):
    anchors = self.anchor_gen(input_size)
    grid_sizes = self.backbone.compute_grid_sizes(input_size)
    split_idx = np.cumsum([int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]])
    detections = []
    for i, predictions_per_image in enumerate(predictions):
      h, w = input_size if image_sizes is None else image_sizes[i]

      predictions_per_image = np.split(predictions_per_image, split_idx)
      offsets_per_image = [br[:, :4] for br in predictions_per_image]
      scores_per_image = [cl[:, 4:] for cl in predictions_per_image]

      image_boxes, image_scores, image_labels = [], [], []
      for offsets_per_level, scores_per_level, anchors_per_level in zip(offsets_per_image, scores_per_image, anchors):
        # remove low scoring boxes
        scores_per_level = scores_per_level.flatten()
        keep_idxs = scores_per_level > score_thresh
        scores_per_level = scores_per_level[keep_idxs]

        # keep topk
        topk_idxs = np.where(keep_idxs)[0]
        num_topk = min(len(topk_idxs), topk_candidates)
        sort_idxs = scores_per_level.argsort()[-num_topk:][::-1]
        topk_idxs, scores_per_level = topk_idxs[sort_idxs], scores_per_level[sort_idxs]

        # bbox coords from offsets
        anchor_idxs = topk_idxs // self.num_classes
        labels_per_level = topk_idxs % self.num_classes
        boxes_per_level = decode_bbox(offsets_per_level[anchor_idxs], anchors_per_level[anchor_idxs])

        # clip to image size
        clipped_x = boxes_per_level[:, 0::2].clip(0, w)
        clipped_y = boxes_per_level[:, 1::2].clip(0, h)
        boxes_per_level = np.stack([clipped_x, clipped_y], axis=2).reshape(-1, 4)

        image_boxes.append(boxes_per_level)
        image_scores.append(scores_per_level)
        image_labels.append(labels_per_level)

      image_boxes = np.concatenate(image_boxes)
      image_scores = np.concatenate(image_scores)
      image_labels = np.concatenate(image_labels)

      # nms for each class
      keep_mask = np.zeros_like(image_scores, dtype=bool)
      for class_id in np.unique(image_labels):
        curr_indices = np.where(image_labels == class_id)[0]
        curr_keep_indices = nms(image_boxes[curr_indices], image_scores[curr_indices], nms_thresh)
        keep_mask[curr_indices[curr_keep_indices]] = True
      keep = np.where(keep_mask)[0]
      keep = keep[image_scores[keep].argsort()[::-1]]

      # resize bboxes back to original size
      image_boxes = image_boxes[keep]
      if orig_image_sizes is not None:
        resized_x = image_boxes[:, 0::2] * orig_image_sizes[i][1] / w
        resized_y = image_boxes[:, 1::2] * orig_image_sizes[i][0] / h
        image_boxes = np.stack([resized_x, resized_y], axis=2).reshape(-1, 4)

      # xywh format
      image_boxes = np.concatenate([image_boxes[:, :2], image_boxes[:, 2:] - image_boxes[:, :2]], axis=1)

      detections.append({"boxes": image_boxes, "scores": image_scores[keep], "labels": image_labels[keep]})
    return detections

class ClassificationHead:
  def __init__(self, in_channels, num_anchors, num_classes):
    self.num_classes = num_classes
    self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
    self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
  def __call__(self, x):
    out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
    return out[0].cat(*out[1:], dim=1).sigmoid()

class RegressionHead:
  def __init__(self, in_channels, num_anchors):
    self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
    self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, padding=1)
  def __call__(self, x):
    out = [self.bbox_reg(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, 4) for feat in x]
    return out[0].cat(*out[1:], dim=1)

class RetinaHead:
  def __init__(self, in_channels, num_anchors, num_classes):
    self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
    self.regression_head = RegressionHead(in_channels, num_anchors)
  def __call__(self, x):
    pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
    out = pred_bbox.cat(pred_class, dim=-1)
    return out

class ResNetFPN:
  def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]):
    self.out_channels = out_channels
    self.body = resnet
    in_channels_list = [(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers]
    self.fpn = FPN(in_channels_list, out_channels)

  # this is needed to decouple inference from postprocessing (anchors generation)
  def compute_grid_sizes(self, input_size):
    return np.ceil(np.array(input_size)[None, :] / 2 ** np.arange(3, 8)[:, None])

  def __call__(self, x):
    out = self.body.bn1(self.body.conv1(x)).relu()
    out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2)
    out = out.sequential(self.body.layer1)
    p3 = out.sequential(self.body.layer2)
    p4 = p3.sequential(self.body.layer3)
    p5 = p4.sequential(self.body.layer4)
    return self.fpn([p3, p4, p5])

class ExtraFPNBlock:
  def __init__(self, in_channels, out_channels):
    self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
    self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
    self.use_P5 = in_channels == out_channels

  def __call__(self, p, c):
    p5, c5 = p[-1], c[-1]
    x = p5 if self.use_P5 else c5
    p6 = self.p6(x)
    p7 = self.p7(p6.relu())
    p.extend([p6, p7])
    return p

class FPN:
  def __init__(self, in_channels_list, out_channels, extra_blocks=None):
    self.inner_blocks, self.layer_blocks = [], []
    for in_channels in in_channels_list:
      self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
      self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
    self.extra_blocks = ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks

  def __call__(self, x):
    last_inner = self.inner_blocks[-1](x[-1])
    results = [self.layer_blocks[-1](last_inner)]
    for idx in range(len(x) - 2, -1, -1):
      inner_lateral = self.inner_blocks[idx](x[idx])

      # upsample to inner_lateral's shape
      (ih, iw), (oh, ow), prefix = last_inner.shape[-2:], inner_lateral.shape[-2:], last_inner.shape[:-2]
      eh, ew = math.ceil(oh / ih), math.ceil(ow / iw)
      inner_top_down = last_inner.reshape(*prefix, ih, 1, iw, 1).expand(*prefix, ih, eh, iw, ew).reshape(*prefix, ih*eh, iw*ew)[:, :, :oh, :ow]

      last_inner = inner_lateral + inner_top_down
      results.insert(0, self.layer_blocks[idx](last_inner))

    if self.extra_blocks is not None:
      results = self.extra_blocks(results, x)
    return results

if __name__ == "__main__":
  from models.resnet import ResNeXt50_32X4D
  backbone = ResNeXt50_32X4D()
  retina = RetinaNet(backbone)
  retina.load_from_pretrained()
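
For a quick sanity check (hypothetical, not part of this commit), the __main__ block above could be extended to push a random 800x800 batch through the pretrained model and decode it with the numpy post-processing; the Tensor import path is assumed from tinygrad's layout at the time:

  # hypothetical extension of the __main__ block: one random batch, end to end
  from tinygrad.tensor import Tensor
  x = Tensor.randn(1, 3, 800, 800)
  preds = retina(x).numpy()                        # (1, sum(HiWi)*A, 4 + num_classes), as in the comment above postprocess_detections
  dets = retina.postprocess_detections(preds, input_size=(800, 800))
  print({k: v.shape for k, v in dets[0].items()})  # boxes (N, 4), scores (N,), labels (N,)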