import re import math import os import numpy as np from pathlib import Path from tinygrad import nn, Tensor, dtypes from tinygrad.tensor import _to_np_dtype from tinygrad.helpers import get_child, fetch from tinygrad.nn.state import torch_load from extra.models.resnet import ResNet from extra.models.retinanet import nms as _box_nms USE_NP_GATHER = os.getenv('FULL_TINYGRAD', '0') == '0' def rint(tensor): x = (tensor*2).cast(dtypes.int32).contiguous().cast(dtypes.float32)/2 return (x<0).where(x.floor(), x.ceil()) def nearest_interpolate(tensor, scale_factor): bs, c, py, px = tensor.shape return tensor.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, scale_factor, px, scale_factor).reshape(bs, c, py * scale_factor, px * scale_factor) def meshgrid(x, y): grid_x = Tensor.cat(*[x[idx:idx+1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])]) grid_y = Tensor.cat(*[y.unsqueeze(0)]*x.shape[0]) return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1) def topk(input_, k, dim=-1, largest=True, sorted=False): k = min(k, input_.shape[dim]-1) input_ = input_.numpy() if largest: input_ *= -1 ind = np.argpartition(input_, k, axis=dim) if largest: input_ *= -1 ind = np.take(ind, np.arange(k), axis=dim) # k non-sorted indices input_ = np.take_along_axis(input_, ind, axis=dim) # k non-sorted values if not sorted: return Tensor(input_), ind if largest: input_ *= -1 ind_part = np.argsort(input_, axis=dim) ind = np.take_along_axis(ind, ind_part, axis=dim) if largest: input_ *= -1 val = np.take_along_axis(input_, ind_part, axis=dim) return Tensor(val), ind # This is very slow for large arrays, or indices def _gather(array, indices): indices = indices.float().to(array.device) reshape_arg = [1]*array.ndim + [array.shape[-1]] return Tensor.where( indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1]) == Tensor.arange(array.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, array.shape[-1]), array, 0, ).sum(indices.ndim) # TODO: replace npgather with a faster gather using tinygrad only # NOTE: this blocks the gradient def npgather(array,indices): if isinstance(array, Tensor): array = array.numpy() if isinstance(indices, Tensor): indices = indices.numpy() if isinstance(indices, list): indices = np.asarray(indices) return Tensor(array[indices.astype(int)]) def get_strides(shape): prod = [1] for idx in range(len(shape)-1, -1, -1): prod.append(prod[-1] * shape[idx]) # something about ints is broken with gpu, cuda return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0) # with keys as integer array for all axes def tensor_getitem(tensor, *keys): # something about ints is broken with gpu, cuda flat_keys = Tensor.stack(*[key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cast(dtypes.int32) strides = get_strides(tensor.shape) idxs = (flat_keys * strides).sum(1) gatherer = npgather if USE_NP_GATHER else _gather return gatherer(tensor.reshape(-1), idxs).reshape(sum(keys).shape) # for gather with indicies only on axis=0 def tensor_gather(tensor, indices): if not isinstance(indices, Tensor): indices = Tensor(indices, requires_grad=False) if len(tensor.shape) > 2: rem_shape = list(tensor.shape)[1:] tensor = tensor.reshape(tensor.shape[0], -1) else: rem_shape = None if len(tensor.shape) > 1: tensor = tensor.T repeat_arg = [1]*(tensor.ndim-1) + [tensor.shape[-2]] indices = indices.unsqueeze(indices.ndim).repeat(repeat_arg) ret = _gather(tensor, indices) if rem_shape: ret = ret.reshape([indices.shape[0]] + rem_shape) else: ret = _gather(tensor, indices) del indices return ret class LastLevelMaxPool: def __call__(self, x): return [Tensor.max_pool2d(x, 1, 2)] # transpose FLIP_LEFT_RIGHT = 0 FLIP_TOP_BOTTOM = 1 def permute_and_flatten(layer:Tensor, N, A, C, H, W): layer = layer.reshape(N, -1, C, H, W) layer = layer.permute(0, 3, 4, 1, 2) layer = layer.reshape(N, -1, C) return layer class BoxList: def __init__(self, bbox, image_size, mode="xyxy"): if not isinstance(bbox, Tensor): bbox = Tensor(bbox) if bbox.ndim != 2: raise ValueError( "bbox should have 2 dimensions, got {}".format(bbox.ndim) ) if bbox.shape[-1] != 4: raise ValueError( "last dimenion of bbox should have a " "size of 4, got {}".format(bbox.shape[-1]) ) if mode not in ("xyxy", "xywh"): raise ValueError("mode should be 'xyxy' or 'xywh'") self.bbox = bbox self.size = image_size # (image_width, image_height) self.mode = mode self.extra_fields = {} def __repr__(self): s = self.__class__.__name__ + "(" s += "num_boxes={}, ".format(len(self)) s += "image_width={}, ".format(self.size[0]) s += "image_height={}, ".format(self.size[1]) s += "mode={})".format(self.mode) return s def area(self): box = self.bbox if self.mode == "xyxy": TO_REMOVE = 1 area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE) elif self.mode == "xywh": area = box[:, 2] * box[:, 3] return area def add_field(self, field, field_data): self.extra_fields[field] = field_data def get_field(self, field): return self.extra_fields[field] def has_field(self, field): return field in self.extra_fields def fields(self): return list(self.extra_fields.keys()) def _copy_extra_fields(self, bbox): for k, v in bbox.extra_fields.items(): self.extra_fields[k] = v def convert(self, mode): if mode == self.mode: return self xmin, ymin, xmax, ymax = self._split_into_xyxy() if mode == "xyxy": bbox = Tensor.cat(*(xmin, ymin, xmax, ymax), dim=-1) bbox = BoxList(bbox, self.size, mode=mode) else: TO_REMOVE = 1 bbox = Tensor.cat( *(xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1 ) bbox = BoxList(bbox, self.size, mode=mode) bbox._copy_extra_fields(self) return bbox def _split_into_xyxy(self): if self.mode == "xyxy": xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1) return xmin, ymin, xmax, ymax if self.mode == "xywh": TO_REMOVE = 1 xmin, ymin, w, h = self.bbox.chunk(4, dim=-1) return ( xmin, ymin, xmin + (w - TO_REMOVE).clamp(min=0), ymin + (h - TO_REMOVE).clamp(min=0), ) def resize(self, size, *args, **kwargs): ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) if ratios[0] == ratios[1]: ratio = ratios[0] scaled_box = self.bbox * ratio bbox = BoxList(scaled_box, size, mode=self.mode) for k, v in self.extra_fields.items(): if not isinstance(v, Tensor): v = v.resize(size, *args, **kwargs) bbox.add_field(k, v) return bbox ratio_width, ratio_height = ratios xmin, ymin, xmax, ymax = self._split_into_xyxy() scaled_xmin = xmin * ratio_width scaled_xmax = xmax * ratio_width scaled_ymin = ymin * ratio_height scaled_ymax = ymax * ratio_height scaled_box = Tensor.cat( *(scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1 ) bbox = BoxList(scaled_box, size, mode="xyxy") for k, v in self.extra_fields.items(): if not isinstance(v, Tensor): v = v.resize(size, *args, **kwargs) bbox.add_field(k, v) return bbox.convert(self.mode) def transpose(self, method): image_width, image_height = self.size xmin, ymin, xmax, ymax = self._split_into_xyxy() if method == FLIP_LEFT_RIGHT: TO_REMOVE = 1 transposed_xmin = image_width - xmax - TO_REMOVE transposed_xmax = image_width - xmin - TO_REMOVE transposed_ymin = ymin transposed_ymax = ymax elif method == FLIP_TOP_BOTTOM: transposed_xmin = xmin transposed_xmax = xmax transposed_ymin = image_height - ymax transposed_ymax = image_height - ymin transposed_boxes = Tensor.cat( *(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1 ) bbox = BoxList(transposed_boxes, self.size, mode="xyxy") for k, v in self.extra_fields.items(): if not isinstance(v, Tensor): v = v.transpose(method) bbox.add_field(k, v) return bbox.convert(self.mode) def clip_to_image(self, remove_empty=True): TO_REMOVE = 1 bb1 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 0] bb2 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 1] bb3 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 2] bb4 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 3] self.bbox = Tensor.stack(bb1, bb2, bb3, bb4, dim=1) if remove_empty: box = self.bbox keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]) return self[keep] return self def __getitem__(self, item): if isinstance(item, list): if len(item) == 0: return [] if sum(item) == len(item) and isinstance(item[0], bool): return self bbox = BoxList(tensor_gather(self.bbox, item), self.size, self.mode) for k, v in self.extra_fields.items(): bbox.add_field(k, tensor_gather(v, item)) return bbox def __len__(self): return self.bbox.shape[0] def cat_boxlist(bboxes): size = bboxes[0].size mode = bboxes[0].mode fields = set(bboxes[0].fields()) cat_box_list = [bbox.bbox for bbox in bboxes if bbox.bbox.shape[0] > 0] if len(cat_box_list) > 0: cat_boxes = BoxList(Tensor.cat(*cat_box_list, dim=0), size, mode) else: cat_boxes = BoxList(bboxes[0].bbox, size, mode) for field in fields: cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0] if len(cat_box_list) > 0: data = Tensor.cat(*cat_field_list, dim=0) else: data = bboxes[0].get_field(field) cat_boxes.add_field(field, data) return cat_boxes class FPN: def __init__(self, in_channels_list, out_channels): self.inner_blocks, self.layer_blocks = [], [] for in_channels in in_channels_list: self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1)) self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)) self.top_block = LastLevelMaxPool() def __call__(self, x: Tensor): last_inner = self.inner_blocks[-1](x[-1]) results = [] results.append(self.layer_blocks[-1](last_inner)) for feature, inner_block, layer_block in zip( x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1] ): if not inner_block: continue inner_top_down = nearest_interpolate(last_inner, scale_factor=2) inner_lateral = inner_block(feature) last_inner = inner_lateral + inner_top_down layer_result = layer_block(last_inner) results.insert(0, layer_result) last_results = self.top_block(results[-1]) results.extend(last_results) return tuple(results) class ResNetFPN: def __init__(self, resnet, out_channels=256): self.out_channels = out_channels self.body = resnet in_channels_stage2 = 256 in_channels_list = [ in_channels_stage2, in_channels_stage2 * 2, in_channels_stage2 * 4, in_channels_stage2 * 8, ] self.fpn = FPN(in_channels_list, out_channels) def __call__(self, x): x = self.body(x) return self.fpn(x) class AnchorGenerator: def __init__( self, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1.0, 2.0), anchor_strides=(4, 8, 16, 32, 64), straddle_thresh=0, ): if len(anchor_strides) == 1: anchor_stride = anchor_strides[0] cell_anchors = [ generate_anchors(anchor_stride, sizes, aspect_ratios) ] else: if len(anchor_strides) != len(sizes): raise RuntimeError("FPN should have #anchor_strides == #sizes") cell_anchors = [ generate_anchors( anchor_stride, size if isinstance(size, (tuple, list)) else (size,), aspect_ratios ) for anchor_stride, size in zip(anchor_strides, sizes) ] self.strides = anchor_strides self.cell_anchors = cell_anchors self.straddle_thresh = straddle_thresh def num_anchors_per_location(self): return [cell_anchors.shape[0] for cell_anchors in self.cell_anchors] def grid_anchors(self, grid_sizes): anchors = [] for size, stride, base_anchors in zip( grid_sizes, self.strides, self.cell_anchors ): grid_height, grid_width = size device = base_anchors.device shifts_x = Tensor.arange( start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=device ) shifts_y = Tensor.arange( start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=device ) shift_y, shift_x = meshgrid(shifts_y, shifts_x) shift_x = shift_x.reshape(-1) shift_y = shift_y.reshape(-1) shifts = Tensor.stack(shift_x, shift_y, shift_x, shift_y, dim=1) anchors.append( (shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4) ) return anchors def add_visibility_to(self, boxlist): image_width, image_height = boxlist.size anchors = boxlist.bbox if self.straddle_thresh >= 0: inds_inside = ( (anchors[:, 0] >= -self.straddle_thresh) * (anchors[:, 1] >= -self.straddle_thresh) * (anchors[:, 2] < image_width + self.straddle_thresh) * (anchors[:, 3] < image_height + self.straddle_thresh) ) else: device = anchors.device inds_inside = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=device) boxlist.add_field("visibility", inds_inside) def __call__(self, image_list, feature_maps): grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] anchors_over_all_feature_maps = self.grid_anchors(grid_sizes) anchors = [] for (image_height, image_width) in image_list.image_sizes: anchors_in_image = [] for anchors_per_feature_map in anchors_over_all_feature_maps: boxlist = BoxList( anchors_per_feature_map, (image_width, image_height), mode="xyxy" ) self.add_visibility_to(boxlist) anchors_in_image.append(boxlist) anchors.append(anchors_in_image) return anchors def generate_anchors( stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2) ): return _generate_anchors(stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios))) def _generate_anchors(base_size, scales, aspect_ratios): anchor = Tensor([1, 1, base_size, base_size]) - 1 anchors = _ratio_enum(anchor, aspect_ratios) anchors = Tensor.cat( *[_scale_enum(anchors[i, :], scales).reshape(-1, 4) for i in range(anchors.shape[0])] ) return anchors def _whctrs(anchor): w = anchor[2] - anchor[0] + 1 h = anchor[3] - anchor[1] + 1 x_ctr = anchor[0] + 0.5 * (w - 1) y_ctr = anchor[1] + 0.5 * (h - 1) return w, h, x_ctr, y_ctr def _mkanchors(ws, hs, x_ctr, y_ctr): ws = ws[:, None] hs = hs[:, None] anchors = Tensor.cat(*( x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1), x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1), ), dim=1) return anchors def _ratio_enum(anchor, ratios): w, h, x_ctr, y_ctr = _whctrs(anchor) size = w * h size_ratios = size / ratios ws = rint(Tensor.sqrt(size_ratios)) hs = rint(ws * ratios) anchors = _mkanchors(ws, hs, x_ctr, y_ctr) return anchors def _scale_enum(anchor, scales): w, h, x_ctr, y_ctr = _whctrs(anchor) ws = w * scales hs = h * scales anchors = _mkanchors(ws, hs, x_ctr, y_ctr) return anchors class RPNHead: def __init__(self, in_channels, num_anchors): self.conv = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1) self.cls_logits = nn.Conv2d(256, num_anchors, kernel_size=1) self.bbox_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=1) def __call__(self, x): logits = [] bbox_reg = [] for feature in x: t = Tensor.relu(self.conv(feature)) logits.append(self.cls_logits(t)) bbox_reg.append(self.bbox_pred(t)) return logits, bbox_reg class BoxCoder(object): def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)): self.weights = weights self.bbox_xform_clip = bbox_xform_clip def encode(self, reference_boxes, proposals): TO_REMOVE = 1 # TODO remove ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights wx, wy, ww, wh = self.weights targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights targets_dw = ww * Tensor.log(gt_widths / ex_widths) targets_dh = wh * Tensor.log(gt_heights / ex_heights) targets = Tensor.stack(targets_dx, targets_dy, targets_dw, targets_dh, dim=1) return targets def decode(self, rel_codes, boxes): boxes = boxes.cast(rel_codes.dtype) rel_codes = rel_codes TO_REMOVE = 1 # TODO remove widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE ctr_x = boxes[:, 0] + 0.5 * widths ctr_y = boxes[:, 1] + 0.5 * heights wx, wy, ww, wh = self.weights dx = rel_codes[:, 0::4] / wx dy = rel_codes[:, 1::4] / wy dw = rel_codes[:, 2::4] / ww dh = rel_codes[:, 3::4] / wh # Prevent sending too large values into Tensor.exp() dw = dw.clip(min_=dw.min(), max_=self.bbox_xform_clip) dh = dh.clip(min_=dh.min(), max_=self.bbox_xform_clip) pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] pred_w = dw.exp() * widths[:, None] pred_h = dh.exp() * heights[:, None] x = pred_ctr_x - 0.5 * pred_w y = pred_ctr_y - 0.5 * pred_h w = pred_ctr_x + 0.5 * pred_w - 1 h = pred_ctr_y + 0.5 * pred_h - 1 pred_boxes = Tensor.stack(x, y, w, h).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1]) return pred_boxes def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"): if nms_thresh <= 0: return boxlist mode = boxlist.mode boxlist = boxlist.convert("xyxy") boxes = boxlist.bbox score = boxlist.get_field(score_field) keep = _box_nms(boxes.numpy(), score.numpy(), nms_thresh) if max_proposals > 0: keep = keep[:max_proposals] boxlist = boxlist[keep] return boxlist.convert(mode) def remove_small_boxes(boxlist, min_size): xywh_boxes = boxlist.convert("xywh").bbox _, _, ws, hs = xywh_boxes.chunk(4, dim=1) keep = (( (ws >= min_size) * (hs >= min_size) ) > 0).reshape(-1) if keep.sum().numpy() == len(boxlist): return boxlist else: keep = keep.numpy().nonzero()[0] return boxlist[keep] class RPNPostProcessor: # Not used in Loss calculation def __init__( self, pre_nms_top_n, post_nms_top_n, nms_thresh, min_size, box_coder=None, fpn_post_nms_top_n=None, ): self.pre_nms_top_n = pre_nms_top_n self.post_nms_top_n = post_nms_top_n self.nms_thresh = nms_thresh self.min_size = min_size if box_coder is None: box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) self.box_coder = box_coder if fpn_post_nms_top_n is None: fpn_post_nms_top_n = post_nms_top_n self.fpn_post_nms_top_n = fpn_post_nms_top_n def forward_for_single_feature_map(self, anchors, objectness, box_regression): device = objectness.device N, A, H, W = objectness.shape objectness = permute_and_flatten(objectness, N, A, 1, H, W).reshape(N, -1) objectness = objectness.sigmoid() box_regression = permute_and_flatten(box_regression, N, A, 4, H, W) num_anchors = A * H * W pre_nms_top_n = min(self.pre_nms_top_n, num_anchors) objectness, topk_idx = topk(objectness, pre_nms_top_n, dim=1, sorted=False) concat_anchors = Tensor.cat(*[a.bbox for a in anchors], dim=0).reshape(N, -1, 4) image_shapes = [box.size for box in anchors] box_regression_list = [] concat_anchors_list = [] for batch_idx in range(N): box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx])) concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx])) box_regression = Tensor.stack(*box_regression_list) concat_anchors = Tensor.stack(*concat_anchors_list) proposals = self.box_coder.decode( box_regression.reshape(-1, 4), concat_anchors.reshape(-1, 4) ) proposals = proposals.reshape(N, -1, 4) result = [] for proposal, score, im_shape in zip(proposals, objectness, image_shapes): boxlist = BoxList(proposal, im_shape, mode="xyxy") boxlist.add_field("objectness", score) boxlist = boxlist.clip_to_image(remove_empty=False) boxlist = remove_small_boxes(boxlist, self.min_size) boxlist = boxlist_nms( boxlist, self.nms_thresh, max_proposals=self.post_nms_top_n, score_field="objectness", ) result.append(boxlist) return result def __call__(self, anchors, objectness, box_regression): sampled_boxes = [] num_levels = len(objectness) anchors = list(zip(*anchors)) for a, o, b in zip(anchors, objectness, box_regression): sampled_boxes.append(self.forward_for_single_feature_map(a, o, b)) boxlists = list(zip(*sampled_boxes)) boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] if num_levels > 1: boxlists = self.select_over_all_levels(boxlists) return boxlists def select_over_all_levels(self, boxlists): num_images = len(boxlists) for i in range(num_images): objectness = boxlists[i].get_field("objectness") post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0]) _, inds_sorted = topk(objectness, post_nms_top_n, dim=0, sorted=False ) boxlists[i] = boxlists[i][inds_sorted] return boxlists class RPN: def __init__(self, in_channels): self.anchor_generator = AnchorGenerator() in_channels = 256 head = RPNHead( in_channels, self.anchor_generator.num_anchors_per_location()[0] ) rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) box_selector_test = RPNPostProcessor( pre_nms_top_n=1000, post_nms_top_n=1000, nms_thresh=0.7, min_size=0, box_coder=rpn_box_coder, fpn_post_nms_top_n=1000 ) self.head = head self.box_selector_test = box_selector_test def __call__(self, images, features, targets=None): objectness, rpn_box_regression = self.head(features) anchors = self.anchor_generator(images, features) boxes = self.box_selector_test(anchors, objectness, rpn_box_regression) return boxes, {} def make_conv3x3( in_channels, out_channels, dilation=1, stride=1, use_gn=False, ): conv = nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False if use_gn else True ) return conv class MaskRCNNFPNFeatureExtractor: def __init__(self): resolution = 14 scales = (0.25, 0.125, 0.0625, 0.03125) sampling_ratio = 2 pooler = Pooler( output_size=(resolution, resolution), scales=scales, sampling_ratio=sampling_ratio, ) input_size = 256 self.pooler = pooler use_gn = False layers = (256, 256, 256, 256) dilation = 1 self.mask_fcn1 = make_conv3x3(input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn) self.mask_fcn2 = make_conv3x3(layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn) self.mask_fcn3 = make_conv3x3(layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn) self.mask_fcn4 = make_conv3x3(layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn) self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4] def __call__(self, x, proposals): x = self.pooler(x, proposals) for layer in self.blocks: if x is not None: x = Tensor.relu(layer(x)) return x class MaskRCNNC4Predictor: def __init__(self): num_classes = 81 dim_reduced = 256 num_inputs = dim_reduced self.conv5_mask = nn.ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) self.mask_fcn_logits = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0) def __call__(self, x): x = Tensor.relu(self.conv5_mask(x)) return self.mask_fcn_logits(x) class FPN2MLPFeatureExtractor: def __init__(self, cfg): resolution = 7 scales = (0.25, 0.125, 0.0625, 0.03125) sampling_ratio = 2 pooler = Pooler( output_size=(resolution, resolution), scales=scales, sampling_ratio=sampling_ratio, ) input_size = 256 * resolution ** 2 representation_size = 1024 self.pooler = pooler self.fc6 = nn.Linear(input_size, representation_size) self.fc7 = nn.Linear(representation_size, representation_size) def __call__(self, x, proposals): x = self.pooler(x, proposals) x = x.reshape(x.shape[0], -1) x = Tensor.relu(self.fc6(x)) x = Tensor.relu(self.fc7(x)) return x def _bilinear_interpolate( input, # [N, C, H, W] roi_batch_ind, # [K] y, # [K, PH, IY] x, # [K, PW, IX] ymask, # [K, IY] xmask, # [K, IX] ): _, channels, height, width = input.shape y = y.clip(min_=0.0, max_=float(height-1)) x = x.clip(min_=0.0, max_=float(width-1)) # Tensor.where doesnt work well with int32 data so cast to float32 y_low = y.cast(dtypes.int32).contiguous().float().contiguous() x_low = x.cast(dtypes.int32).contiguous().float().contiguous() y_high = Tensor.where(y_low >= height - 1, float(height - 1), y_low + 1) y_low = Tensor.where(y_low >= height - 1, float(height - 1), y_low) x_high = Tensor.where(x_low >= width - 1, float(width - 1), x_low + 1) x_low = Tensor.where(x_low >= width - 1, float(width - 1), x_low) ly = y - y_low lx = x - x_low hy = 1.0 - ly hx = 1.0 - lx def masked_index( y, # [K, PH, IY] x, # [K, PW, IX] ): if ymask is not None: assert xmask is not None y = Tensor.where(ymask[:, None, :], y, 0) x = Tensor.where(xmask[:, None, :], x, 0) key1 = roi_batch_ind[:, None, None, None, None, None] key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None] key3 = y[:, None, :, None, :, None] key4 = x[:, None, None, :, None, :] return tensor_getitem(input,key1,key2,key3,key4) # [K, C, PH, PW, IY, IX] v1 = masked_index(y_low, x_low) v2 = masked_index(y_low, x_high) v3 = masked_index(y_high, x_low) v4 = masked_index(y_high, x_high) # all ws preemptively [K, C, PH, PW, IY, IX] def outer_prod(y, x): return y[:, None, :, None, :, None] * x[:, None, None, :, None, :] w1 = outer_prod(hy, hx) w2 = outer_prod(hy, lx) w3 = outer_prod(ly, hx) w4 = outer_prod(ly, lx) val = w1*v1 + w2*v2 + w3*v3 + w4*v4 return val #https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): orig_dtype = input.dtype _, _, height, width = input.shape ph = Tensor.arange(pooled_height, device=input.device) pw = Tensor.arange(pooled_width, device=input.device) roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous() offset = 0.5 if aligned else 0.0 roi_start_w = rois[:, 1] * spatial_scale - offset roi_start_h = rois[:, 2] * spatial_scale - offset roi_end_w = rois[:, 3] * spatial_scale - offset roi_end_h = rois[:, 4] * spatial_scale - offset roi_width = roi_end_w - roi_start_w roi_height = roi_end_h - roi_start_h if not aligned: roi_width = roi_width.maximum(1.0) roi_height = roi_height.maximum(1.0) bin_size_h = roi_height / pooled_height bin_size_w = roi_width / pooled_width exact_sampling = sampling_ratio > 0 roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil() roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil() if exact_sampling: count = max(roi_bin_grid_h * roi_bin_grid_w, 1) iy = Tensor.arange(roi_bin_grid_h, device=input.device) ix = Tensor.arange(roi_bin_grid_w, device=input.device) ymask = None xmask = None else: count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1) iy = Tensor.arange(height, device=input.device) ix = Tensor.arange(width, device=input.device) ymask = iy[None, :] < roi_bin_grid_h[:, None] xmask = ix[None, :] < roi_bin_grid_w[:, None] def from_K(t): return t[:, None, None] y = ( from_K(roi_start_h) + ph[None, :, None] * from_K(bin_size_h) + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h) ) x = ( from_K(roi_start_w) + pw[None, :, None] * from_K(bin_size_w) + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w) ) val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) if not exact_sampling: val = ymask[:, None, None, None, :, None].where(val, 0) val = xmask[:, None, None, None, None, :].where(val, 0) output = val.sum((-1, -2)) if isinstance(count, Tensor): output /= count[:, None, None, None] else: output /= count output = output.cast(orig_dtype) return output class ROIAlign: def __init__(self, output_size, spatial_scale, sampling_ratio): self.output_size = output_size self.spatial_scale = spatial_scale self.sampling_ratio = sampling_ratio def __call__(self, input, rois): output = _roi_align( input, rois, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio, aligned=False ) return output class LevelMapper: def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6): self.k_min = k_min self.k_max = k_max self.s0 = canonical_scale self.lvl0 = canonical_level self.eps = eps def __call__(self, boxlists): s = Tensor.sqrt(Tensor.cat(*[boxlist.area() for boxlist in boxlists])) target_lvls = (self.lvl0 + Tensor.log2(s / self.s0 + self.eps)).floor() target_lvls = target_lvls.clip(min_=self.k_min, max_=self.k_max) return target_lvls - self.k_min class Pooler: def __init__(self, output_size, scales, sampling_ratio): self.output_size = output_size self.scales = scales self.sampling_ratio = sampling_ratio poolers = [] for scale in scales: poolers.append( ROIAlign( output_size, spatial_scale=scale, sampling_ratio=sampling_ratio ) ) self.poolers = poolers self.output_size = output_size lvl_min = -math.log2(scales[0]) lvl_max = -math.log2(scales[-1]) self.map_levels = LevelMapper(lvl_min, lvl_max) def convert_to_roi_format(self, boxes): concat_boxes = Tensor.cat(*[b.bbox for b in boxes], dim=0) device, dtype = concat_boxes.device, concat_boxes.dtype ids = Tensor.cat( *[ Tensor.full((len(b), 1), i, dtype=dtype, device=device) for i, b in enumerate(boxes) ], dim=0, ) if concat_boxes.shape[0] != 0: rois = Tensor.cat(*[ids, concat_boxes], dim=1) return rois def __call__(self, x, boxes): num_levels = len(self.poolers) rois = self.convert_to_roi_format(boxes) if rois is not None: if num_levels == 1: return self.poolers[0](x[0], rois) levels = self.map_levels(boxes) results = [] all_idxs = [] for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)): # this is fine because no grad will flow through index idx_in_level = (levels.numpy() == level).nonzero()[0] if len(idx_in_level) > 0: rois_per_level = tensor_gather(rois, idx_in_level) pooler_output = pooler(per_level_feature, rois_per_level) all_idxs.extend(idx_in_level) results.append(pooler_output) return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])]) class FPNPredictor: def __init__(self): num_classes = 81 representation_size = 1024 self.cls_score = nn.Linear(representation_size, num_classes) num_bbox_reg_classes = num_classes self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4) def __call__(self, x): scores = self.cls_score(x) bbox_deltas = self.bbox_pred(x) return scores, bbox_deltas class PostProcessor: # Not used in training def __init__( self, score_thresh=0.05, nms=0.5, detections_per_img=100, box_coder=None, cls_agnostic_bbox_reg=False ): self.score_thresh = score_thresh self.nms = nms self.detections_per_img = detections_per_img if box_coder is None: box_coder = BoxCoder(weights=(10., 10., 5., 5.)) self.box_coder = box_coder self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg def __call__(self, x, boxes): class_logits, box_regression = x class_prob = Tensor.softmax(class_logits, -1) image_shapes = [box.size for box in boxes] boxes_per_image = [len(box) for box in boxes] concat_boxes = Tensor.cat(*[a.bbox for a in boxes], dim=0) if self.cls_agnostic_bbox_reg: box_regression = box_regression[:, -4:] proposals = self.box_coder.decode( box_regression.reshape(sum(boxes_per_image), -1), concat_boxes ) if self.cls_agnostic_bbox_reg: proposals = proposals.repeat([1, class_prob.shape[1]]) num_classes = class_prob.shape[1] proposals = proposals.unsqueeze(0) class_prob = class_prob.unsqueeze(0) results = [] for prob, boxes_per_img, image_shape in zip( class_prob, proposals, image_shapes ): boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape) boxlist = boxlist.clip_to_image(remove_empty=False) boxlist = self.filter_results(boxlist, num_classes) results.append(boxlist) return results def prepare_boxlist(self, boxes, scores, image_shape): boxes = boxes.reshape(-1, 4) scores = scores.reshape(-1) boxlist = BoxList(boxes, image_shape, mode="xyxy") boxlist.add_field("scores", scores) return boxlist def filter_results(self, boxlist, num_classes): boxes = boxlist.bbox.reshape(-1, num_classes * 4) scores = boxlist.get_field("scores").reshape(-1, num_classes) device = scores.device result = [] scores = scores.numpy() boxes = boxes.numpy() inds_all = scores > self.score_thresh for j in range(1, num_classes): inds = inds_all[:, j].nonzero()[0] # This needs to be done in numpy because it can create empty arrays scores_j = scores[inds, j] boxes_j = boxes[inds, j * 4: (j + 1) * 4] boxes_j = Tensor(boxes_j) scores_j = Tensor(scores_j) boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") boxlist_for_class.add_field("scores", scores_j) if len(boxlist_for_class): boxlist_for_class = boxlist_nms( boxlist_for_class, self.nms ) num_labels = len(boxlist_for_class) boxlist_for_class.add_field( "labels", Tensor.full((num_labels,), j, device=device) ) result.append(boxlist_for_class) result = cat_boxlist(result) number_of_detections = len(result) if number_of_detections > self.detections_per_img > 0: cls_scores = result.get_field("scores") image_thresh, _ = topk(cls_scores, k=self.detections_per_img) image_thresh = image_thresh.numpy()[-1] keep = (cls_scores.numpy() >= image_thresh).nonzero()[0] result = result[keep] return result class RoIBoxHead: def __init__(self, in_channels): self.feature_extractor = FPN2MLPFeatureExtractor(in_channels) self.predictor = FPNPredictor() self.post_processor = PostProcessor( score_thresh=0.05, nms=0.5, detections_per_img=100, box_coder=BoxCoder(weights=(10., 10., 5., 5.)), cls_agnostic_bbox_reg=False ) def __call__(self, features, proposals, targets=None): x = self.feature_extractor(features, proposals) class_logits, box_regression = self.predictor(x) if not Tensor.training: result = self.post_processor((class_logits, box_regression), proposals) return x, result, {} class MaskPostProcessor: # Not used in loss calculation def __call__(self, x, boxes): mask_prob = x.sigmoid().numpy() num_masks = x.shape[0] labels = [bbox.get_field("labels") for bbox in boxes] labels = Tensor.cat(*labels).numpy().astype(np.int32) index = np.arange(num_masks) mask_prob = mask_prob[index, labels][:, None] boxes_per_image, cumsum = [], 0 for box in boxes: cumsum += len(box) boxes_per_image.append(cumsum) # using numpy here as Tensor.chunk doesnt have custom chunk sizes mask_prob = np.split(mask_prob, boxes_per_image, axis=0) results = [] for prob, box in zip(mask_prob, boxes): bbox = BoxList(box.bbox, box.size, mode="xyxy") for field in box.fields(): bbox.add_field(field, box.get_field(field)) prob = Tensor(prob) bbox.add_field("mask", prob) results.append(bbox) return results class Mask: def __init__(self): self.feature_extractor = MaskRCNNFPNFeatureExtractor() self.predictor = MaskRCNNC4Predictor() self.post_processor = MaskPostProcessor() def __call__(self, features, proposals, targets=None): x = self.feature_extractor(features, proposals) if x: mask_logits = self.predictor(x) if not Tensor.training: result = self.post_processor(mask_logits, proposals) return x, result, {} return x, [], {} class RoIHeads: def __init__(self, in_channels): self.box = RoIBoxHead(in_channels) self.mask = Mask() def __call__(self, features, proposals, targets=None): x, detections, _ = self.box(features, proposals, targets) x, detections, _ = self.mask(features, detections, targets) return x, detections, {} class ImageList(object): def __init__(self, tensors, image_sizes): self.tensors = tensors self.image_sizes = image_sizes def to(self, *args, **kwargs): cast_tensor = self.tensors.to(*args, **kwargs) return ImageList(cast_tensor, self.image_sizes) def to_image_list(tensors, size_divisible=32): # Preprocessing if isinstance(tensors, Tensor) and size_divisible > 0: tensors = [tensors] if isinstance(tensors, ImageList): return tensors elif isinstance(tensors, Tensor): # single tensor shape can be inferred assert tensors.ndim == 4 image_sizes = [tensor.shape[-2:] for tensor in tensors] return ImageList(tensors, image_sizes) elif isinstance(tensors, (tuple, list)): max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) if size_divisible > 0: stride = size_divisible max_size = list(max_size) max_size[1] = int(math.ceil(max_size[1] / stride) * stride) max_size[2] = int(math.ceil(max_size[2] / stride) * stride) max_size = tuple(max_size) batch_shape = (len(tensors),) + max_size batched_imgs = np.zeros(batch_shape, dtype=_to_np_dtype(tensors[0].dtype)) for img, pad_img in zip(tensors, batched_imgs): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy() batched_imgs = Tensor(batched_imgs) image_sizes = [im.shape[-2:] for im in tensors] return ImageList(batched_imgs, image_sizes) else: raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors))) class MaskRCNN: def __init__(self, backbone: ResNet): self.backbone = ResNetFPN(backbone, out_channels=256) self.rpn = RPN(self.backbone.out_channels) self.roi_heads = RoIHeads(self.backbone.out_channels) def load_from_pretrained(self): fn = Path('./') / "weights/maskrcnn.pt" fetch("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn) state_dict = torch_load(fn)['model'] loaded_keys = [] for k, v in state_dict.items(): if "module." in k: k = k.replace("module.", "") if "stem." in k: k = k.replace("stem.", "") if "fpn_inner" in k: block_index = int(re.search(r"fpn_inner(\d+)", k).group(1)) k = re.sub(r"fpn_inner\d+", f"inner_blocks.{block_index - 1}", k) if "fpn_layer" in k: block_index = int(re.search(r"fpn_layer(\d+)", k).group(1)) k = re.sub(r"fpn_layer\d+", f"layer_blocks.{block_index - 1}", k) loaded_keys.append(k) get_child(self, k).assign(v.numpy()).realize() return loaded_keys def __call__(self, images): images = to_image_list(images) features = self.backbone(images.tensors) proposals, _ = self.rpn(images, features) x, result, _ = self.roi_heads(features, proposals) return result if __name__ == '__main__': resnet = resnet = ResNet(50, num_classes=None, stride_in_1x1=True) model = MaskRCNN(backbone=resnet) model.load_from_pretrained()