# mirror of https://github.com/commaai/tinygrad.git
import re
import math
import os
import numpy as np
from pathlib import Path
from tinygrad import nn, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.helpers import get_child, fetch
from tinygrad.nn.state import torch_load
from extra.models.resnet import ResNet
from extra.models.retinanet import nms as _box_nms

USE_NP_GATHER = os.getenv('FULL_TINYGRAD', '0') == '0'

def rint(tensor):
  x = (tensor*2).cast(dtypes.int32).contiguous().cast(dtypes.float32)/2
  return (x<0).where(x.floor(), x.ceil())
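
# Sketch of the intended behavior (not verified exhaustively): rint rounds to the
# nearest integer with ties going away from zero, e.g.
#   rint(Tensor([1.4, 1.5, -1.5])).numpy() -> [1., 2., -2.]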

def nearest_interpolate(tensor, scale_factor):
  bs, c, py, px = tensor.shape
  return tensor.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, scale_factor, px, scale_factor).reshape(bs, c, py * scale_factor, px * scale_factor)
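
# Nearest-neighbor upsampling by an integer factor: a (1, 256, 16, 16) feature map with
# scale_factor=2 becomes (1, 256, 32, 32); every pixel is repeated along both spatial axes.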

def meshgrid(x, y):
  grid_x = Tensor.cat(*[x[idx:idx+1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])])
  grid_y = Tensor.cat(*[y.unsqueeze(0)]*x.shape[0])
  return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)
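
# Behaves like torch.meshgrid with "ij" indexing, flattened to column vectors. Illustrative
# example: meshgrid(Tensor([0., 1.]), Tensor([0., 1., 2.])) gives
#   grid_x -> [0, 0, 0, 1, 1, 1] and grid_y -> [0, 1, 2, 0, 1, 2], each of shape (6, 1).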

def topk(input_, k, dim=-1, largest=True, sorted=False):
  k = min(k, input_.shape[dim]-1)
  input_ = input_.numpy()
  if largest: input_ *= -1
  ind = np.argpartition(input_, k, axis=dim)
  if largest: input_ *= -1
  ind = np.take(ind, np.arange(k), axis=dim)  # k non-sorted indices
  input_ = np.take_along_axis(input_, ind, axis=dim)  # k non-sorted values
  if not sorted: return Tensor(input_), ind
  if largest: input_ *= -1
  ind_part = np.argsort(input_, axis=dim)
  ind = np.take_along_axis(ind, ind_part, axis=dim)
  if largest: input_ *= -1
  val = np.take_along_axis(input_, ind_part, axis=dim)
  return Tensor(val), ind
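
# numpy-backed stand-in for torch.topk: values come back as a Tensor, indices as a numpy
# array, and no gradient flows through it. Note k is capped at input_.shape[dim] - 1.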

# This is very slow for large arrays or indices
def _gather(array, indices):
  indices = indices.float().to(array.device)
  reshape_arg = [1]*array.ndim + [array.shape[-1]]
  return Tensor.where(
    indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1]) == Tensor.arange(array.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, array.shape[-1]),
    array, 0,
  ).sum(indices.ndim)

# TODO: replace npgather with a faster gather using tinygrad only
# NOTE: this blocks the gradient
def npgather(array, indices):
  if isinstance(array, Tensor): array = array.numpy()
  if isinstance(indices, Tensor): indices = indices.numpy()
  if isinstance(indices, list): indices = np.asarray(indices)
  return Tensor(array[indices.astype(int)])

def get_strides(shape):
  prod = [1]
  for idx in range(len(shape)-1, -1, -1): prod.append(prod[-1] * shape[idx])
  # something about ints is broken with gpu, cuda
  return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0)
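
# Example: get_strides((2, 3, 4)) -> Tensor([[12, 4, 1]]), the row-major strides used to
# turn multi-dimensional indices into flat offsets.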

# with keys as integer array for all axes
def tensor_getitem(tensor, *keys):
  # something about ints is broken with gpu, cuda
  flat_keys = Tensor.stack(*[key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cast(dtypes.int32)
  strides = get_strides(tensor.shape)
  idxs = (flat_keys * strides).sum(1)
  gatherer = npgather if USE_NP_GATHER else _gather
  return gatherer(tensor.reshape(-1), idxs).reshape(sum(keys).shape)
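
# Fancy indexing built from primitives: broadcast the per-axis integer keys, convert them
# to flat offsets with the strides above, then gather from the flattened tensor.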

# for gather with indices only on axis=0
def tensor_gather(tensor, indices):
  if not isinstance(indices, Tensor):
    indices = Tensor(indices, requires_grad=False)
  if len(tensor.shape) > 2:
    rem_shape = list(tensor.shape)[1:]
    tensor = tensor.reshape(tensor.shape[0], -1)
  else:
    rem_shape = None
  if len(tensor.shape) > 1:
    tensor = tensor.T
    repeat_arg = [1]*(tensor.ndim-1) + [tensor.shape[-2]]
    indices = indices.unsqueeze(indices.ndim).repeat(repeat_arg)
    ret = _gather(tensor, indices)
    if rem_shape:
      ret = ret.reshape([indices.shape[0]] + rem_shape)
  else:
    ret = _gather(tensor, indices)
  del indices
  return ret

class LastLevelMaxPool:
  def __call__(self, x): return [Tensor.max_pool2d(x, 1, 2)]

# transpose
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1

def permute_and_flatten(layer:Tensor, N, A, C, H, W):
  layer = layer.reshape(N, -1, C, H, W)
  layer = layer.permute(0, 3, 4, 1, 2)
  layer = layer.reshape(N, -1, C)
  return layer
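
# Reshapes a head output from (N, A*C, H, W) to (N, H*W*A, C) so predictions line up with
# the anchor ordering, e.g. RPN objectness (N, A, H, W) becomes (N, H*W*A, 1).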

class BoxList:
  def __init__(self, bbox, image_size, mode="xyxy"):
    if not isinstance(bbox, Tensor):
      bbox = Tensor(bbox)
    if bbox.ndim != 2:
      raise ValueError(
        "bbox should have 2 dimensions, got {}".format(bbox.ndim)
      )
    if bbox.shape[-1] != 4:
      raise ValueError(
        "last dimension of bbox should have a "
        "size of 4, got {}".format(bbox.shape[-1])
      )
    if mode not in ("xyxy", "xywh"):
      raise ValueError("mode should be 'xyxy' or 'xywh'")

    self.bbox = bbox
    self.size = image_size  # (image_width, image_height)
    self.mode = mode
    self.extra_fields = {}

  def __repr__(self):
    s = self.__class__.__name__ + "("
    s += "num_boxes={}, ".format(len(self))
    s += "image_width={}, ".format(self.size[0])
    s += "image_height={}, ".format(self.size[1])
    s += "mode={})".format(self.mode)
    return s

  def area(self):
    box = self.bbox
    if self.mode == "xyxy":
      TO_REMOVE = 1
      area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
    elif self.mode == "xywh":
      area = box[:, 2] * box[:, 3]
    return area

  def add_field(self, field, field_data):
    self.extra_fields[field] = field_data

  def get_field(self, field):
    return self.extra_fields[field]

  def has_field(self, field):
    return field in self.extra_fields

  def fields(self):
    return list(self.extra_fields.keys())

  def _copy_extra_fields(self, bbox):
    for k, v in bbox.extra_fields.items():
      self.extra_fields[k] = v

  def convert(self, mode):
    if mode == self.mode:
      return self
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    if mode == "xyxy":
      bbox = Tensor.cat(*(xmin, ymin, xmax, ymax), dim=-1)
      bbox = BoxList(bbox, self.size, mode=mode)
    else:
      TO_REMOVE = 1
      bbox = Tensor.cat(
        *(xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1
      )
      bbox = BoxList(bbox, self.size, mode=mode)
    bbox._copy_extra_fields(self)
    return bbox

  def _split_into_xyxy(self):
    if self.mode == "xyxy":
      xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1)
      return xmin, ymin, xmax, ymax
    if self.mode == "xywh":
      TO_REMOVE = 1
      xmin, ymin, w, h = self.bbox.chunk(4, dim=-1)
      return (
        xmin,
        ymin,
        xmin + (w - TO_REMOVE).clamp(min=0),
        ymin + (h - TO_REMOVE).clamp(min=0),
      )

  def resize(self, size, *args, **kwargs):
    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
    if ratios[0] == ratios[1]:
      ratio = ratios[0]
      scaled_box = self.bbox * ratio
      bbox = BoxList(scaled_box, size, mode=self.mode)
      for k, v in self.extra_fields.items():
        if not isinstance(v, Tensor):
          v = v.resize(size, *args, **kwargs)
        bbox.add_field(k, v)
      return bbox

    ratio_width, ratio_height = ratios
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    scaled_xmin = xmin * ratio_width
    scaled_xmax = xmax * ratio_width
    scaled_ymin = ymin * ratio_height
    scaled_ymax = ymax * ratio_height
    scaled_box = Tensor.cat(
      *(scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
    )
    bbox = BoxList(scaled_box, size, mode="xyxy")
    for k, v in self.extra_fields.items():
      if not isinstance(v, Tensor):
        v = v.resize(size, *args, **kwargs)
      bbox.add_field(k, v)

    return bbox.convert(self.mode)

  def transpose(self, method):
    image_width, image_height = self.size
    xmin, ymin, xmax, ymax = self._split_into_xyxy()
    if method == FLIP_LEFT_RIGHT:
      TO_REMOVE = 1
      transposed_xmin = image_width - xmax - TO_REMOVE
      transposed_xmax = image_width - xmin - TO_REMOVE
      transposed_ymin = ymin
      transposed_ymax = ymax
    elif method == FLIP_TOP_BOTTOM:
      transposed_xmin = xmin
      transposed_xmax = xmax
      transposed_ymin = image_height - ymax
      transposed_ymax = image_height - ymin

    transposed_boxes = Tensor.cat(
      *(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
    )
    bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
    for k, v in self.extra_fields.items():
      if not isinstance(v, Tensor):
        v = v.transpose(method)
      bbox.add_field(k, v)
    return bbox.convert(self.mode)

  def clip_to_image(self, remove_empty=True):
    TO_REMOVE = 1
    bb1 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 0]
    bb2 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 1]
    bb3 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 2]
    bb4 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 3]
    self.bbox = Tensor.stack(bb1, bb2, bb3, bb4, dim=1)
    if remove_empty:
      box = self.bbox
      keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
      return self[keep]
    return self

  def __getitem__(self, item):
    if isinstance(item, list):
      if len(item) == 0:
        return []
      if sum(item) == len(item) and isinstance(item[0], bool):
        return self
    bbox = BoxList(tensor_gather(self.bbox, item), self.size, self.mode)
    for k, v in self.extra_fields.items():
      bbox.add_field(k, tensor_gather(v, item))
    return bbox

  def __len__(self):
    return self.bbox.shape[0]
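
# Illustrative usage (not part of the model): BoxList wraps an (N, 4) box tensor plus
# arbitrary per-box fields.
#   boxes = BoxList(Tensor([[10., 10., 50., 80.]]), (640, 480), mode="xyxy")
#   boxes.add_field("scores", Tensor([0.9]))
#   boxes.convert("xywh").bbox  # -> [[10., 10., 41., 71.]] (widths/heights use the +1 convention)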

def cat_boxlist(bboxes):
  size = bboxes[0].size
  mode = bboxes[0].mode
  fields = set(bboxes[0].fields())
  cat_box_list = [bbox.bbox for bbox in bboxes if bbox.bbox.shape[0] > 0]

  if len(cat_box_list) > 0:
    cat_boxes = BoxList(Tensor.cat(*cat_box_list, dim=0), size, mode)
  else:
    cat_boxes = BoxList(bboxes[0].bbox, size, mode)
  for field in fields:
    cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0]

    if len(cat_box_list) > 0:
      data = Tensor.cat(*cat_field_list, dim=0)
    else:
      data = bboxes[0].get_field(field)

    cat_boxes.add_field(field, data)

  return cat_boxes

class FPN:
  def __init__(self, in_channels_list, out_channels):
    self.inner_blocks, self.layer_blocks = [], []
    for in_channels in in_channels_list:
      self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
      self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
    self.top_block = LastLevelMaxPool()

  def __call__(self, x: Tensor):
    last_inner = self.inner_blocks[-1](x[-1])
    results = []
    results.append(self.layer_blocks[-1](last_inner))
    for feature, inner_block, layer_block in zip(
      x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
    ):
      if not inner_block:
        continue
      inner_top_down = nearest_interpolate(last_inner, scale_factor=2)
      inner_lateral = inner_block(feature)
      last_inner = inner_lateral + inner_top_down
      layer_result = layer_block(last_inner)
      results.insert(0, layer_result)
    last_results = self.top_block(results[-1])
    results.extend(last_results)

    return tuple(results)
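
# Standard FPN top-down pathway: 1x1 "inner" convs project each backbone stage to
# out_channels, the coarser map is upsampled 2x and added to the next lateral, and each sum
# goes through a 3x3 "layer" conv; LastLevelMaxPool appends a stride-2 pooled extra level.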

class ResNetFPN:
  def __init__(self, resnet, out_channels=256):
    self.out_channels = out_channels
    self.body = resnet
    in_channels_stage2 = 256
    in_channels_list = [
      in_channels_stage2,
      in_channels_stage2 * 2,
      in_channels_stage2 * 4,
      in_channels_stage2 * 8,
    ]
    self.fpn = FPN(in_channels_list, out_channels)

  def __call__(self, x):
    x = self.body(x)
    return self.fpn(x)

class AnchorGenerator:
  def __init__(
    self,
    sizes=(32, 64, 128, 256, 512),
    aspect_ratios=(0.5, 1.0, 2.0),
    anchor_strides=(4, 8, 16, 32, 64),
    straddle_thresh=0,
  ):
    if len(anchor_strides) == 1:
      anchor_stride = anchor_strides[0]
      cell_anchors = [
        generate_anchors(anchor_stride, sizes, aspect_ratios)
      ]
    else:
      if len(anchor_strides) != len(sizes):
        raise RuntimeError("FPN should have #anchor_strides == #sizes")

      cell_anchors = [
        generate_anchors(
          anchor_stride,
          size if isinstance(size, (tuple, list)) else (size,),
          aspect_ratios
        )
        for anchor_stride, size in zip(anchor_strides, sizes)
      ]
    self.strides = anchor_strides
    self.cell_anchors = cell_anchors
    self.straddle_thresh = straddle_thresh

  def num_anchors_per_location(self):
    return [cell_anchors.shape[0] for cell_anchors in self.cell_anchors]

  def grid_anchors(self, grid_sizes):
    anchors = []
    for size, stride, base_anchors in zip(
      grid_sizes, self.strides, self.cell_anchors
    ):
      grid_height, grid_width = size
      device = base_anchors.device
      shifts_x = Tensor.arange(
        start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=device
      )
      shifts_y = Tensor.arange(
        start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=device
      )
      shift_y, shift_x = meshgrid(shifts_y, shifts_x)
      shift_x = shift_x.reshape(-1)
      shift_y = shift_y.reshape(-1)
      shifts = Tensor.stack(shift_x, shift_y, shift_x, shift_y, dim=1)

      anchors.append(
        (shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4)
      )

    return anchors

  def add_visibility_to(self, boxlist):
    image_width, image_height = boxlist.size
    anchors = boxlist.bbox
    if self.straddle_thresh >= 0:
      inds_inside = (
        (anchors[:, 0] >= -self.straddle_thresh)
        * (anchors[:, 1] >= -self.straddle_thresh)
        * (anchors[:, 2] < image_width + self.straddle_thresh)
        * (anchors[:, 3] < image_height + self.straddle_thresh)
      )
    else:
      device = anchors.device
      inds_inside = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=device)
    boxlist.add_field("visibility", inds_inside)

  def __call__(self, image_list, feature_maps):
    grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
    anchors_over_all_feature_maps = self.grid_anchors(grid_sizes)
    anchors = []
    for (image_height, image_width) in image_list.image_sizes:
      anchors_in_image = []
      for anchors_per_feature_map in anchors_over_all_feature_maps:
        boxlist = BoxList(
          anchors_per_feature_map, (image_width, image_height), mode="xyxy"
        )
        self.add_visibility_to(boxlist)
        anchors_in_image.append(boxlist)
      anchors.append(anchors_in_image)
    return anchors
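
# With the defaults (one size per FPN level, 3 aspect ratios) there are 3 anchors per
# location, so a level with an HxW feature map contributes H*W*3 xyxy anchors in image
# coordinates.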

def generate_anchors(
  stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)
):
  return _generate_anchors(stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios)))

def _generate_anchors(base_size, scales, aspect_ratios):
  anchor = Tensor([1, 1, base_size, base_size]) - 1
  anchors = _ratio_enum(anchor, aspect_ratios)
  anchors = Tensor.cat(
    *[_scale_enum(anchors[i, :], scales).reshape(-1, 4) for i in range(anchors.shape[0])]
  )
  return anchors

def _whctrs(anchor):
  w = anchor[2] - anchor[0] + 1
  h = anchor[3] - anchor[1] + 1
  x_ctr = anchor[0] + 0.5 * (w - 1)
  y_ctr = anchor[1] + 0.5 * (h - 1)
  return w, h, x_ctr, y_ctr

def _mkanchors(ws, hs, x_ctr, y_ctr):
  ws = ws[:, None]
  hs = hs[:, None]
  anchors = Tensor.cat(*(
    x_ctr - 0.5 * (ws - 1),
    y_ctr - 0.5 * (hs - 1),
    x_ctr + 0.5 * (ws - 1),
    y_ctr + 0.5 * (hs - 1),
  ), dim=1)
  return anchors

def _ratio_enum(anchor, ratios):
  w, h, x_ctr, y_ctr = _whctrs(anchor)
  size = w * h
  size_ratios = size / ratios
  ws = rint(Tensor.sqrt(size_ratios))
  hs = rint(ws * ratios)
  anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
  return anchors

def _scale_enum(anchor, scales):
  w, h, x_ctr, y_ctr = _whctrs(anchor)
  ws = w * scales
  hs = h * scales
  anchors = _mkanchors(ws, hs, x_ctr, y_ctr)
  return anchors
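
# _ratio_enum keeps the anchor area roughly constant while changing the aspect ratio, then
# _scale_enum scales each result; all anchors stay centered on the original box. For
# example, a 16x16 base anchor with ratio 0.5 becomes the classic 23x12 anchor
# (ws = rint(sqrt(256/0.5)) = 23, hs = rint(23*0.5) = 12).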

class RPNHead:
  def __init__(self, in_channels, num_anchors):
    self.conv = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
    self.cls_logits = nn.Conv2d(256, num_anchors, kernel_size=1)
    self.bbox_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=1)

  def __call__(self, x):
    logits = []
    bbox_reg = []
    for feature in x:
      t = Tensor.relu(self.conv(feature))
      logits.append(self.cls_logits(t))
      bbox_reg.append(self.bbox_pred(t))
    return logits, bbox_reg

class BoxCoder(object):
  def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
    self.weights = weights
    self.bbox_xform_clip = bbox_xform_clip

  def encode(self, reference_boxes, proposals):
    TO_REMOVE = 1  # TODO remove
    ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE
    ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE
    ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths
    ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights

    gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE
    gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE
    gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths
    gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights

    wx, wy, ww, wh = self.weights
    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
    targets_dw = ww * Tensor.log(gt_widths / ex_widths)
    targets_dh = wh * Tensor.log(gt_heights / ex_heights)

    targets = Tensor.stack(targets_dx, targets_dy, targets_dw, targets_dh, dim=1)
    return targets

  def decode(self, rel_codes, boxes):
    boxes = boxes.cast(rel_codes.dtype)

    TO_REMOVE = 1  # TODO remove
    widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE
    heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    wx, wy, ww, wh = self.weights
    dx = rel_codes[:, 0::4] / wx
    dy = rel_codes[:, 1::4] / wy
    dw = rel_codes[:, 2::4] / ww
    dh = rel_codes[:, 3::4] / wh

    # Prevent sending too large values into Tensor.exp()
    dw = dw.clip(min_=dw.min(), max_=self.bbox_xform_clip)
    dh = dh.clip(min_=dh.min(), max_=self.bbox_xform_clip)

    pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
    pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
    pred_w = dw.exp() * widths[:, None]
    pred_h = dh.exp() * heights[:, None]
    x = pred_ctr_x - 0.5 * pred_w
    y = pred_ctr_y - 0.5 * pred_h
    w = pred_ctr_x + 0.5 * pred_w - 1
    h = pred_ctr_y + 0.5 * pred_h - 1
    pred_boxes = Tensor.stack(x, y, w, h).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1])
    return pred_boxes
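
# BoxCoder implements the usual R-CNN box parameterization:
#   dx = (gx - px)/pw, dy = (gy - py)/ph, dw = log(gw/pw), dh = log(gh/ph)
# each scaled by its weight; decode() inverts it, clipping dw/dh at bbox_xform_clip so
# exp() stays bounded.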

def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
  if nms_thresh <= 0:
    return boxlist
  mode = boxlist.mode
  boxlist = boxlist.convert("xyxy")
  boxes = boxlist.bbox
  score = boxlist.get_field(score_field)
  keep = _box_nms(boxes.numpy(), score.numpy(), nms_thresh)
  if max_proposals > 0:
    keep = keep[:max_proposals]
  boxlist = boxlist[keep]
  return boxlist.convert(mode)

def remove_small_boxes(boxlist, min_size):
  xywh_boxes = boxlist.convert("xywh").bbox
  _, _, ws, hs = xywh_boxes.chunk(4, dim=1)
  keep = ((
    (ws >= min_size) * (hs >= min_size)
  ) > 0).reshape(-1)
  if keep.sum().numpy() == len(boxlist):
    return boxlist
  else:
    keep = keep.numpy().nonzero()[0]
    return boxlist[keep]

class RPNPostProcessor:
  # Not used in loss calculation
  def __init__(
    self,
    pre_nms_top_n,
    post_nms_top_n,
    nms_thresh,
    min_size,
    box_coder=None,
    fpn_post_nms_top_n=None,
  ):
    self.pre_nms_top_n = pre_nms_top_n
    self.post_nms_top_n = post_nms_top_n
    self.nms_thresh = nms_thresh
    self.min_size = min_size

    if box_coder is None:
      box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
    self.box_coder = box_coder

    if fpn_post_nms_top_n is None:
      fpn_post_nms_top_n = post_nms_top_n
    self.fpn_post_nms_top_n = fpn_post_nms_top_n

  def forward_for_single_feature_map(self, anchors, objectness, box_regression):
    device = objectness.device
    N, A, H, W = objectness.shape
    objectness = permute_and_flatten(objectness, N, A, 1, H, W).reshape(N, -1)
    objectness = objectness.sigmoid()

    box_regression = permute_and_flatten(box_regression, N, A, 4, H, W)

    num_anchors = A * H * W

    pre_nms_top_n = min(self.pre_nms_top_n, num_anchors)
    objectness, topk_idx = topk(objectness, pre_nms_top_n, dim=1, sorted=False)
    concat_anchors = Tensor.cat(*[a.bbox for a in anchors], dim=0).reshape(N, -1, 4)
    image_shapes = [box.size for box in anchors]

    box_regression_list = []
    concat_anchors_list = []
    for batch_idx in range(N):
      box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx]))
      concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx]))

    box_regression = Tensor.stack(*box_regression_list)
    concat_anchors = Tensor.stack(*concat_anchors_list)

    proposals = self.box_coder.decode(
      box_regression.reshape(-1, 4), concat_anchors.reshape(-1, 4)
    )

    proposals = proposals.reshape(N, -1, 4)

    result = []
    for proposal, score, im_shape in zip(proposals, objectness, image_shapes):
      boxlist = BoxList(proposal, im_shape, mode="xyxy")
      boxlist.add_field("objectness", score)
      boxlist = boxlist.clip_to_image(remove_empty=False)
      boxlist = remove_small_boxes(boxlist, self.min_size)
      boxlist = boxlist_nms(
        boxlist,
        self.nms_thresh,
        max_proposals=self.post_nms_top_n,
        score_field="objectness",
      )
      result.append(boxlist)
    return result

  def __call__(self, anchors, objectness, box_regression):
    sampled_boxes = []
    num_levels = len(objectness)
    anchors = list(zip(*anchors))
    for a, o, b in zip(anchors, objectness, box_regression):
      sampled_boxes.append(self.forward_for_single_feature_map(a, o, b))

    boxlists = list(zip(*sampled_boxes))
    boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]

    if num_levels > 1:
      boxlists = self.select_over_all_levels(boxlists)

    return boxlists

  def select_over_all_levels(self, boxlists):
    num_images = len(boxlists)
    for i in range(num_images):
      objectness = boxlists[i].get_field("objectness")
      post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0])
      _, inds_sorted = topk(objectness, post_nms_top_n, dim=0, sorted=False)
      boxlists[i] = boxlists[i][inds_sorted]
    return boxlists
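
# Per FPN level: sigmoid the objectness, keep the top pre_nms_top_n anchors, decode their
# box deltas, clip to the image, drop tiny boxes and run NMS; then, across levels, keep the
# top fpn_post_nms_top_n proposals per image.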

class RPN:
  def __init__(self, in_channels):
    self.anchor_generator = AnchorGenerator()

    in_channels = 256
    head = RPNHead(
      in_channels, self.anchor_generator.num_anchors_per_location()[0]
    )
    rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
    box_selector_test = RPNPostProcessor(
      pre_nms_top_n=1000,
      post_nms_top_n=1000,
      nms_thresh=0.7,
      min_size=0,
      box_coder=rpn_box_coder,
      fpn_post_nms_top_n=1000
    )
    self.head = head
    self.box_selector_test = box_selector_test

  def __call__(self, images, features, targets=None):
    objectness, rpn_box_regression = self.head(features)
    anchors = self.anchor_generator(images, features)
    boxes = self.box_selector_test(anchors, objectness, rpn_box_regression)
    return boxes, {}

def make_conv3x3(
  in_channels,
  out_channels,
  dilation=1,
  stride=1,
  use_gn=False,
):
  conv = nn.Conv2d(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=stride,
    padding=dilation,
    dilation=dilation,
    bias=False if use_gn else True
  )
  return conv

class MaskRCNNFPNFeatureExtractor:
  def __init__(self):
    resolution = 14
    scales = (0.25, 0.125, 0.0625, 0.03125)
    sampling_ratio = 2
    pooler = Pooler(
      output_size=(resolution, resolution),
      scales=scales,
      sampling_ratio=sampling_ratio,
    )
    input_size = 256
    self.pooler = pooler

    use_gn = False
    layers = (256, 256, 256, 256)
    dilation = 1
    self.mask_fcn1 = make_conv3x3(input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn2 = make_conv3x3(layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn3 = make_conv3x3(layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn)
    self.mask_fcn4 = make_conv3x3(layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn)
    self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4]

  def __call__(self, x, proposals):
    x = self.pooler(x, proposals)
    for layer in self.blocks:
      if x is not None:
        x = Tensor.relu(layer(x))
    return x

class MaskRCNNC4Predictor:
  def __init__(self):
    num_classes = 81
    dim_reduced = 256
    num_inputs = dim_reduced
    self.conv5_mask = nn.ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0)
    self.mask_fcn_logits = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)

  def __call__(self, x):
    x = Tensor.relu(self.conv5_mask(x))
    return self.mask_fcn_logits(x)
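
# Mask head path: 14x14 RoIAlign crops -> four 3x3 convs -> 2x transposed conv (28x28) ->
# per-class mask logits (81 COCO classes including background).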

class FPN2MLPFeatureExtractor:
  def __init__(self, cfg):
    resolution = 7
    scales = (0.25, 0.125, 0.0625, 0.03125)
    sampling_ratio = 2
    pooler = Pooler(
      output_size=(resolution, resolution),
      scales=scales,
      sampling_ratio=sampling_ratio,
    )
    input_size = 256 * resolution ** 2
    representation_size = 1024
    self.pooler = pooler
    self.fc6 = nn.Linear(input_size, representation_size)
    self.fc7 = nn.Linear(representation_size, representation_size)

  def __call__(self, x, proposals):
    x = self.pooler(x, proposals)
    x = x.reshape(x.shape[0], -1)
    x = Tensor.relu(self.fc6(x))
    x = Tensor.relu(self.fc7(x))
    return x

def _bilinear_interpolate(
  input,          # [N, C, H, W]
  roi_batch_ind,  # [K]
  y,              # [K, PH, IY]
  x,              # [K, PW, IX]
  ymask,          # [K, IY]
  xmask,          # [K, IX]
):
  _, channels, height, width = input.shape
  y = y.clip(min_=0.0, max_=float(height-1))
  x = x.clip(min_=0.0, max_=float(width-1))

  # Tensor.where doesn't work well with int32 data so cast to float32
  y_low = y.cast(dtypes.int32).contiguous().float().contiguous()
  x_low = x.cast(dtypes.int32).contiguous().float().contiguous()

  y_high = Tensor.where(y_low >= height - 1, float(height - 1), y_low + 1)
  y_low = Tensor.where(y_low >= height - 1, float(height - 1), y_low)

  x_high = Tensor.where(x_low >= width - 1, float(width - 1), x_low + 1)
  x_low = Tensor.where(x_low >= width - 1, float(width - 1), x_low)

  ly = y - y_low
  lx = x - x_low
  hy = 1.0 - ly
  hx = 1.0 - lx

  def masked_index(
    y,  # [K, PH, IY]
    x,  # [K, PW, IX]
  ):
    if ymask is not None:
      assert xmask is not None
      y = Tensor.where(ymask[:, None, :], y, 0)
      x = Tensor.where(xmask[:, None, :], x, 0)
    key1 = roi_batch_ind[:, None, None, None, None, None]
    key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None]
    key3 = y[:, None, :, None, :, None]
    key4 = x[:, None, None, :, None, :]
    return tensor_getitem(input, key1, key2, key3, key4)  # [K, C, PH, PW, IY, IX]

  v1 = masked_index(y_low, x_low)
  v2 = masked_index(y_low, x_high)
  v3 = masked_index(y_high, x_low)
  v4 = masked_index(y_high, x_high)

  # all ws preemptively [K, C, PH, PW, IY, IX]
  def outer_prod(y, x):
    return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

  w1 = outer_prod(hy, hx)
  w2 = outer_prod(hy, lx)
  w3 = outer_prod(ly, hx)
  w4 = outer_prod(ly, lx)

  val = w1*v1 + w2*v2 + w3*v3 + w4*v4
  return val
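
# For every sample point the four surrounding feature-map pixels are gathered per ROI and
# channel, then blended with bilinear weights; when sampling_ratio <= 0 the masks zero out
# sample points beyond each ROI's adaptive grid.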

# https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
  orig_dtype = input.dtype
  _, _, height, width = input.shape
  ph = Tensor.arange(pooled_height, device=input.device)
  pw = Tensor.arange(pooled_width, device=input.device)

  roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous()
  offset = 0.5 if aligned else 0.0
  roi_start_w = rois[:, 1] * spatial_scale - offset
  roi_start_h = rois[:, 2] * spatial_scale - offset
  roi_end_w = rois[:, 3] * spatial_scale - offset
  roi_end_h = rois[:, 4] * spatial_scale - offset

  roi_width = roi_end_w - roi_start_w
  roi_height = roi_end_h - roi_start_h
  if not aligned:
    roi_width = roi_width.maximum(1.0)
    roi_height = roi_height.maximum(1.0)

  bin_size_h = roi_height / pooled_height
  bin_size_w = roi_width / pooled_width

  exact_sampling = sampling_ratio > 0
  roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil()
  roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil()

  if exact_sampling:
    count = max(roi_bin_grid_h * roi_bin_grid_w, 1)
    iy = Tensor.arange(roi_bin_grid_h, device=input.device)
    ix = Tensor.arange(roi_bin_grid_w, device=input.device)
    ymask = None
    xmask = None
  else:
    count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1)
    iy = Tensor.arange(height, device=input.device)
    ix = Tensor.arange(width, device=input.device)
    ymask = iy[None, :] < roi_bin_grid_h[:, None]
    xmask = ix[None, :] < roi_bin_grid_w[:, None]

  def from_K(t):
    return t[:, None, None]

  y = (
    from_K(roi_start_h)
    + ph[None, :, None] * from_K(bin_size_h)
    + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
  )
  x = (
    from_K(roi_start_w)
    + pw[None, :, None] * from_K(bin_size_w)
    + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
  )

  val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)
  if not exact_sampling:
    val = ymask[:, None, None, None, :, None].where(val, 0)
    val = xmask[:, None, None, None, None, :].where(val, 0)

  output = val.sum((-1, -2))
  if isinstance(count, Tensor):
    output /= count[:, None, None, None]
  else:
    output /= count

  output = output.cast(orig_dtype)
  return output

class ROIAlign:
  def __init__(self, output_size, spatial_scale, sampling_ratio):
    self.output_size = output_size
    self.spatial_scale = spatial_scale
    self.sampling_ratio = sampling_ratio

  def __call__(self, input, rois):
    output = _roi_align(
      input, rois, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio, aligned=False
    )
    return output
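
# Each ROI is (batch_index, x1, y1, x2, y2) in image coordinates; spatial_scale maps it onto
# the feature map and the output is a fixed (C, output_size, output_size) crop regardless of
# ROI size.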

class LevelMapper:
  def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
    self.k_min = k_min
    self.k_max = k_max
    self.s0 = canonical_scale
    self.lvl0 = canonical_level
    self.eps = eps

  def __call__(self, boxlists):
    s = Tensor.sqrt(Tensor.cat(*[boxlist.area() for boxlist in boxlists]))
    target_lvls = (self.lvl0 + Tensor.log2(s / self.s0 + self.eps)).floor()
    target_lvls = target_lvls.clip(min_=self.k_min, max_=self.k_max)
    return target_lvls - self.k_min
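
# FPN level assignment from the FPN paper: k = floor(k0 + log2(sqrt(area)/224)), clipped to
# [k_min, k_max]; a 224x224 box maps to the canonical level (P4).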

class Pooler:
  def __init__(self, output_size, scales, sampling_ratio):
    self.output_size = output_size
    self.scales = scales
    self.sampling_ratio = sampling_ratio
    poolers = []
    for scale in scales:
      poolers.append(
        ROIAlign(
          output_size, spatial_scale=scale, sampling_ratio=sampling_ratio
        )
      )
    self.poolers = poolers
    self.output_size = output_size
    lvl_min = -math.log2(scales[0])
    lvl_max = -math.log2(scales[-1])
    self.map_levels = LevelMapper(lvl_min, lvl_max)

  def convert_to_roi_format(self, boxes):
    concat_boxes = Tensor.cat(*[b.bbox for b in boxes], dim=0)
    device, dtype = concat_boxes.device, concat_boxes.dtype
    ids = Tensor.cat(
      *[
        Tensor.full((len(b), 1), i, dtype=dtype, device=device)
        for i, b in enumerate(boxes)
      ],
      dim=0,
    )
    if concat_boxes.shape[0] != 0:
      rois = Tensor.cat(*[ids, concat_boxes], dim=1)
      return rois

  def __call__(self, x, boxes):
    num_levels = len(self.poolers)
    rois = self.convert_to_roi_format(boxes)
    if rois is not None:
      if num_levels == 1:
        return self.poolers[0](x[0], rois)

      levels = self.map_levels(boxes)
      results = []
      all_idxs = []
      for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)):
        # this is fine because no grad will flow through index
        idx_in_level = (levels.numpy() == level).nonzero()[0]
        if len(idx_in_level) > 0:
          rois_per_level = tensor_gather(rois, idx_in_level)
          pooler_output = pooler(per_level_feature, rois_per_level)
          all_idxs.extend(idx_in_level)
          results.append(pooler_output)

      return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])])
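
# Results are produced level by level, so they come back out of ROI order; the final
# tensor_gather uses all_idxs to restore the original ROI ordering.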

class FPNPredictor:
  def __init__(self):
    num_classes = 81
    representation_size = 1024
    self.cls_score = nn.Linear(representation_size, num_classes)
    num_bbox_reg_classes = num_classes
    self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4)

  def __call__(self, x):
    scores = self.cls_score(x)
    bbox_deltas = self.bbox_pred(x)
    return scores, bbox_deltas

class PostProcessor:
  # Not used in training
  def __init__(
    self,
    score_thresh=0.05,
    nms=0.5,
    detections_per_img=100,
    box_coder=None,
    cls_agnostic_bbox_reg=False
  ):
    self.score_thresh = score_thresh
    self.nms = nms
    self.detections_per_img = detections_per_img
    if box_coder is None:
      box_coder = BoxCoder(weights=(10., 10., 5., 5.))
    self.box_coder = box_coder
    self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg

  def __call__(self, x, boxes):
    class_logits, box_regression = x
    class_prob = Tensor.softmax(class_logits, -1)
    image_shapes = [box.size for box in boxes]
    boxes_per_image = [len(box) for box in boxes]
    concat_boxes = Tensor.cat(*[a.bbox for a in boxes], dim=0)

    if self.cls_agnostic_bbox_reg:
      box_regression = box_regression[:, -4:]
    proposals = self.box_coder.decode(
      box_regression.reshape(sum(boxes_per_image), -1), concat_boxes
    )
    if self.cls_agnostic_bbox_reg:
      proposals = proposals.repeat([1, class_prob.shape[1]])
    num_classes = class_prob.shape[1]
    proposals = proposals.unsqueeze(0)
    class_prob = class_prob.unsqueeze(0)
    results = []
    for prob, boxes_per_img, image_shape in zip(
      class_prob, proposals, image_shapes
    ):
      boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape)
      boxlist = boxlist.clip_to_image(remove_empty=False)
      boxlist = self.filter_results(boxlist, num_classes)
      results.append(boxlist)
    return results

  def prepare_boxlist(self, boxes, scores, image_shape):
    boxes = boxes.reshape(-1, 4)
    scores = scores.reshape(-1)
    boxlist = BoxList(boxes, image_shape, mode="xyxy")
    boxlist.add_field("scores", scores)
    return boxlist

  def filter_results(self, boxlist, num_classes):
    boxes = boxlist.bbox.reshape(-1, num_classes * 4)
    scores = boxlist.get_field("scores").reshape(-1, num_classes)

    device = scores.device
    result = []
    scores = scores.numpy()
    boxes = boxes.numpy()
    inds_all = scores > self.score_thresh
    for j in range(1, num_classes):
      inds = inds_all[:, j].nonzero()[0]
      # This needs to be done in numpy because it can create empty arrays
      scores_j = scores[inds, j]
      boxes_j = boxes[inds, j * 4: (j + 1) * 4]
      boxes_j = Tensor(boxes_j)
      scores_j = Tensor(scores_j)
      boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy")
      boxlist_for_class.add_field("scores", scores_j)
      if len(boxlist_for_class):
        boxlist_for_class = boxlist_nms(
          boxlist_for_class, self.nms
        )
      num_labels = len(boxlist_for_class)
      boxlist_for_class.add_field(
        "labels", Tensor.full((num_labels,), j, device=device)
      )
      result.append(boxlist_for_class)

    result = cat_boxlist(result)
    number_of_detections = len(result)

    if number_of_detections > self.detections_per_img > 0:
      cls_scores = result.get_field("scores")
      image_thresh, _ = topk(cls_scores, k=self.detections_per_img)
      image_thresh = image_thresh.numpy()[-1]
      keep = (cls_scores.numpy() >= image_thresh).nonzero()[0]
      result = result[keep]
    return result
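
# Detection post-processing: softmax over class logits, decode per-class boxes, keep scores
# above score_thresh, run per-class NMS, then cap the result at detections_per_img by score.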

class RoIBoxHead:
  def __init__(self, in_channels):
    self.feature_extractor = FPN2MLPFeatureExtractor(in_channels)
    self.predictor = FPNPredictor()
    self.post_processor = PostProcessor(
      score_thresh=0.05,
      nms=0.5,
      detections_per_img=100,
      box_coder=BoxCoder(weights=(10., 10., 5., 5.)),
      cls_agnostic_bbox_reg=False
    )

  def __call__(self, features, proposals, targets=None):
    x = self.feature_extractor(features, proposals)
    class_logits, box_regression = self.predictor(x)
    if not Tensor.training:
      result = self.post_processor((class_logits, box_regression), proposals)
      return x, result, {}

class MaskPostProcessor:
  # Not used in loss calculation
  def __call__(self, x, boxes):
    mask_prob = x.sigmoid().numpy()
    num_masks = x.shape[0]
    labels = [bbox.get_field("labels") for bbox in boxes]
    labels = Tensor.cat(*labels).numpy().astype(np.int32)
    index = np.arange(num_masks)
    mask_prob = mask_prob[index, labels][:, None]
    boxes_per_image, cumsum = [], 0
    for box in boxes:
      cumsum += len(box)
      boxes_per_image.append(cumsum)
    # using numpy here as Tensor.chunk doesn't have custom chunk sizes
    mask_prob = np.split(mask_prob, boxes_per_image, axis=0)
    results = []
    for prob, box in zip(mask_prob, boxes):
      bbox = BoxList(box.bbox, box.size, mode="xyxy")
      for field in box.fields():
        bbox.add_field(field, box.get_field(field))
      prob = Tensor(prob)
      bbox.add_field("mask", prob)
      results.append(bbox)

    return results

class Mask:
  def __init__(self):
    self.feature_extractor = MaskRCNNFPNFeatureExtractor()
    self.predictor = MaskRCNNC4Predictor()
    self.post_processor = MaskPostProcessor()

  def __call__(self, features, proposals, targets=None):
    x = self.feature_extractor(features, proposals)
    if x:
      mask_logits = self.predictor(x)
      if not Tensor.training:
        result = self.post_processor(mask_logits, proposals)
        return x, result, {}
    return x, [], {}

class RoIHeads:
  def __init__(self, in_channels):
    self.box = RoIBoxHead(in_channels)
    self.mask = Mask()

  def __call__(self, features, proposals, targets=None):
    x, detections, _ = self.box(features, proposals, targets)
    x, detections, _ = self.mask(features, detections, targets)
    return x, detections, {}

class ImageList(object):
  def __init__(self, tensors, image_sizes):
    self.tensors = tensors
    self.image_sizes = image_sizes

  def to(self, *args, **kwargs):
    cast_tensor = self.tensors.to(*args, **kwargs)
    return ImageList(cast_tensor, self.image_sizes)

def to_image_list(tensors, size_divisible=32):
  # Preprocessing
  if isinstance(tensors, Tensor) and size_divisible > 0:
    tensors = [tensors]

  if isinstance(tensors, ImageList):
    return tensors
  elif isinstance(tensors, Tensor):
    # single tensor shape can be inferred
    assert tensors.ndim == 4
    image_sizes = [tensor.shape[-2:] for tensor in tensors]
    return ImageList(tensors, image_sizes)
  elif isinstance(tensors, (tuple, list)):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))
    if size_divisible > 0:
      stride = size_divisible
      max_size = list(max_size)
      max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
      max_size[2] = int(math.ceil(max_size[2] / stride) * stride)
      max_size = tuple(max_size)

    batch_shape = (len(tensors),) + max_size
    batched_imgs = np.zeros(batch_shape, dtype=_to_np_dtype(tensors[0].dtype))
    for img, pad_img in zip(tensors, batched_imgs):
      pad_img[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy()

    batched_imgs = Tensor(batched_imgs)
    image_sizes = [im.shape[-2:] for im in tensors]

    return ImageList(batched_imgs, image_sizes)
  else:
    raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors)))
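
# Batching pads each (C, H, W) image with zeros up to the largest H and W in the batch,
# rounded up to a multiple of size_divisible (32) so every FPN stride divides evenly; the
# original sizes are kept in image_sizes.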

class MaskRCNN:
  def __init__(self, backbone: ResNet):
    self.backbone = ResNetFPN(backbone, out_channels=256)
    self.rpn = RPN(self.backbone.out_channels)
    self.roi_heads = RoIHeads(self.backbone.out_channels)

  def load_from_pretrained(self):
    fn = Path('./') / "weights/maskrcnn.pt"
    fetch("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn)

    state_dict = torch_load(fn)['model']
    loaded_keys = []
    for k, v in state_dict.items():
      if "module." in k:
        k = k.replace("module.", "")
      if "stem." in k:
        k = k.replace("stem.", "")
      if "fpn_inner" in k:
        block_index = int(re.search(r"fpn_inner(\d+)", k).group(1))
        k = re.sub(r"fpn_inner\d+", f"inner_blocks.{block_index - 1}", k)
      if "fpn_layer" in k:
        block_index = int(re.search(r"fpn_layer(\d+)", k).group(1))
        k = re.sub(r"fpn_layer\d+", f"layer_blocks.{block_index - 1}", k)
      loaded_keys.append(k)
      get_child(self, k).assign(v.numpy()).realize()
    return loaded_keys

  def __call__(self, images):
    images = to_image_list(images)
    features = self.backbone(images.tensors)
    proposals, _ = self.rpn(images, features)
    x, result, _ = self.roi_heads(features, proposals)
    return result
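
# Checkpoint keys are remapped to this module's attribute names: "module." and "stem."
# prefixes are stripped, and "fpn_inner<N>"/"fpn_layer<N>" become
# "inner_blocks.<N-1>"/"layer_blocks.<N-1>" before each tensor is assigned.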

if __name__ == '__main__':
  resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
  model = MaskRCNN(backbone=resnet)
  model.load_from_pretrained()
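
# Rough inference sketch (assumes a preprocessed CHW float image; not from this file):
#   images = [Tensor(np.zeros((3, 800, 800), dtype=np.float32))]
#   predictions = model(images)                    # list of BoxList, one per image
#   boxes = predictions[0].bbox                    # (num_detections, 4), xyxy
#   scores = predictions[0].get_field("scores")
#   masks = predictions[0].get_field("mask")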