From 5d3310ce5639e102d7cc72057508b7e899d3c4c9 Mon Sep 17 00:00:00 2001 From: Kunwar Raj Singh Date: Mon, 26 Jun 2023 04:07:51 +0530 Subject: [PATCH] MaskRCNN Inference (#884) * MaskRCNN weights loading * backbone maybe works * backbone works, but resnet body atol 1e-3 * RPN Call, but veryy wrong output * fixed topk * RPN maybe works, not sure about nms * Fix cursed modules * add back editorconfig * Full call, wrong output * Full call works * fix mask * use NMS from retinanet * Removing extra funcs * refactor * readable * Add example to run model * remove filter * Fix split, batched inference is worse * Fix image sizes * Matching reference * merge master * add filter on top detections * cuda backend fixed * add model eval and spec * convert images to rgb * fix eval * simplify examples code * remove extra code * meshgrid using tinygrad * removing numpy * roi align, floor, ceil * remove numpy from level_mapper * remove numpy from pooler * Revert "Merge branch 'master' of github.com:kunwar31/tinygrad into mrcnn-inference" This reverts commit 4b95a3cb499393bb68b95500cd736d50a93d3ce4, reversing changes made to 98f2b1fa2ede20113b1b369ac00d4b2a7ca5fbfa. * roi align gather * fix master merge * revert to old floor, ceil as ints present in domain * use log2 op * fix indexes * weird bug with ints and gpu * weird bug with ints and gpu * refactors, add env var for gather * floor with contiguous, where * refactor topk, sort * remove staticmethod * refactor stride * remove log2 mlop * realize -> contiguous * refactor forward * remove num_classes, stride_in_1x1 from state * refactor forward * refactoring * flake8 * removing numpy in anchor gen, use numpy for gather, nonzero, optimize topk * keep using tinygrad for smaller gathers * fix empty tensors * comms * move from tensor.py * resnet test passing * add coco dataset back * fix spaces * add test for log2 * no need to create Tensors * no need to create Tensors --------- Co-authored-by: Kunwar Raj Singh --- datasets/coco.py | 200 ++++++ examples/mask_rcnn.py | 299 ++++++++ examples/mlperf/model_eval.py | 41 +- examples/mlperf/model_spec.py | 13 +- models/mask_rcnn.py | 1273 +++++++++++++++++++++++++++++++++ models/resnet.py | 45 +- test/test_ops.py | 3 + tinygrad/tensor.py | 3 +- 8 files changed, 1854 insertions(+), 23 deletions(-) create mode 100644 datasets/coco.py create mode 100644 examples/mask_rcnn.py create mode 100644 models/mask_rcnn.py diff --git a/datasets/coco.py b/datasets/coco.py new file mode 100644 index 00000000..190c26cb --- /dev/null +++ b/datasets/coco.py @@ -0,0 +1,200 @@ +import json +import pathlib +import zipfile +import numpy as np +from extra.utils import download_file +import pycocotools._mask as _mask +from examples.mask_rcnn import Masker +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +iou = _mask.iou +merge = _mask.merge +frPyObjects = _mask.frPyObjects + +BASEDIR = pathlib.Path(__file__).parent.parent / "datasets/COCO" + +def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for row in rows} + + +if not pathlib.Path(BASEDIR/'val2017').is_dir(): + fn = BASEDIR/'val2017.zip' + download_file('http://images.cocodataset.org/zips/val2017.zip',fn) + with zipfile.ZipFile(fn, 'r') as zip_ref: + zip_ref.extractall(BASEDIR) + fn.unlink() + + +if not pathlib.Path(BASEDIR/'annotations').is_dir(): + fn = BASEDIR/'annotations_trainval2017.zip' + download_file('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',fn) + with zipfile.ZipFile(fn, 'r') as zip_ref: + 
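+  # same pattern as the val2017 images above: extract the annotations archive into BASEDIR, then delete the zip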
zip_ref.extractall(BASEDIR) + fn.unlink() + +with open(BASEDIR/'annotations/instances_val2017.json', 'r') as f: + annotations_raw = json.loads(f.read()) +images = annotations_raw['images'] +categories = annotations_raw['categories'] +annotations = annotations_raw['annotations'] +file_name_to_id = create_dict('file_name', 'id', images) +id_to_width = create_dict('id', 'width', images) +id_to_height = create_dict('id', 'height', images) +json_category_id_to_contiguous_id = {v['id']: i + 1 for i, v in enumerate(categories)} +contiguous_category_id_to_json_id = {v:k for k,v in json_category_id_to_contiguous_id.items()} + + +def encode(bimask): + if len(bimask.shape) == 3: + return _mask.encode(bimask) + elif len(bimask.shape) == 2: + h, w = bimask.shape + return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0] + +def decode(rleObjs): + if type(rleObjs) == list: + return _mask.decode(rleObjs) + else: + return _mask.decode([rleObjs])[:,:,0] + +def area(rleObjs): + if type(rleObjs) == list: + return _mask.area(rleObjs) + else: + return _mask.area([rleObjs])[0] + +def toBbox(rleObjs): + if type(rleObjs) == list: + return _mask.toBbox(rleObjs) + else: + return _mask.toBbox([rleObjs])[0] + + +def convert_prediction_to_coco_bbox(file_name, prediction): + coco_results = [] + try: + original_id = file_name_to_id[file_name] + if len(prediction) == 0: + return coco_results + + image_width = id_to_width[original_id] + image_height = id_to_height[original_id] + prediction = prediction.resize((image_width, image_height)) + prediction = prediction.convert("xywh") + + boxes = prediction.bbox.numpy().tolist() + scores = prediction.get_field("scores").numpy().tolist() + labels = prediction.get_field("labels").numpy().tolist() + + mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels] + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": mapped_labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + except Exception as e: + print(file_name, e) + return coco_results + +masker = Masker(threshold=0.5, padding=1) + +def convert_prediction_to_coco_mask(file_name, prediction): + coco_results = [] + try: + original_id = file_name_to_id[file_name] + if len(prediction) == 0: + return coco_results + + image_width = id_to_width[original_id] + image_height = id_to_height[original_id] + prediction = prediction.resize((image_width, image_height)) + masks = prediction.get_field("mask") + + scores = prediction.get_field("scores").numpy().tolist() + labels = prediction.get_field("labels").numpy().tolist() + + masks = masker([masks], [prediction])[0].numpy() + + rles = [ + encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels] + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": mapped_labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + except Exception as e: + print(file_name, e) + return coco_results + + + +def accumulate_predictions_for_coco(coco_results, json_result_file, rm=False): + path = pathlib.Path(json_result_file) + if rm and path.exists(): path.unlink() + with open(path, "a") as f: + for s in coco_results: + f.write(json.dumps(s)) + f.write('\n') + +def remove_dup(l): + seen = set() + seen_add = seen.add + return [x for x in l if not (x in seen or seen_add(x))] + +class 
NpEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return super(NpEncoder, self).default(obj) + + +def evaluate_predictions_on_coco(json_result_file, iou_type="bbox"): + coco_results = [] + with open(json_result_file, "r") as f: + for line in f: + coco_results.append(json.loads(line)) + + coco_gt = COCO(str(BASEDIR/'annotations/instances_val2017.json')) + set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results]) + unique_list = [json.loads(s) for s in set_of_json] + + with open(f'{json_result_file}.flattend', "w") as f: + json.dump(unique_list, f) + + coco_dt = coco_gt.loadRes(str(f'{json_result_file}.flattend')) + coco_eval = COCOeval(coco_gt, coco_dt, iou_type) + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + return coco_eval + +def iterate(files, bs=1): + batch = [] + for file in files: + batch.append(file) + if len(batch) >= bs: yield batch; batch = [] + if len(batch) > 0: yield batch; batch = [] diff --git a/examples/mask_rcnn.py b/examples/mask_rcnn.py new file mode 100644 index 00000000..66a57275 --- /dev/null +++ b/examples/mask_rcnn.py @@ -0,0 +1,299 @@ +from models.mask_rcnn import MaskRCNN +from models.resnet import ResNet +from models.mask_rcnn import BoxList +from torch.nn import functional as F +from torchvision import transforms as T +from torchvision.transforms import functional as Ft +import random +from tinygrad.tensor import Tensor +from PIL import Image +import numpy as np +import torch +import argparse +import cv2 + + +class Resize: + def __init__(self, min_size, max_size): + if not isinstance(min_size, (list, tuple)): + min_size = (min_size,) + self.min_size = min_size + self.max_size = max_size + + # modified from torchvision to add support for max size + def get_size(self, image_size): + w, h = image_size + size = random.choice(self.min_size) + max_size = self.max_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def __call__(self, image): + size = self.get_size(image.size) + image = Ft.resize(image, size) + return image + + +class Normalize: + def __init__(self, mean, std, to_bgr255=True): + self.mean = mean + self.std = std + self.to_bgr255 = to_bgr255 + + def __call__(self, image): + if self.to_bgr255: + image = image[[2, 1, 0]] * 255 + else: + image = image[[0, 1, 2]] * 255 + image = Ft.normalize(image, mean=self.mean, std=self.std) + return image + +transforms = lambda size_scale: T.Compose( + [ + Resize(int(800*size_scale), int(1333*size_scale)), + T.ToTensor(), + Normalize( + mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True + ), + ] +) + +def expand_boxes(boxes, scale): + w_half = (boxes[:, 2] - boxes[:, 0]) * .5 + h_half = (boxes[:, 3] - boxes[:, 1]) * .5 + x_c = (boxes[:, 2] + boxes[:, 0]) * .5 + y_c = (boxes[:, 3] + boxes[:, 1]) * .5 + + w_half *= scale + h_half *= scale + + boxes_exp = torch.zeros_like(boxes) + boxes_exp[:, 0] = x_c - w_half + boxes_exp[:, 2] = x_c + w_half + boxes_exp[:, 1] = y_c - h_half + boxes_exp[:, 3] = y_c + h_half + 
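+  # boxes_exp holds the input boxes grown by `scale` about their centers, still (x1, y1, x2, y2);
+  # paste_mask_in_image expands boxes by the same factor used to pad the masks so the two stay aligned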
return boxes_exp + + +def expand_masks(mask, padding): + N = mask.shape[0] + M = mask.shape[-1] + pad2 = 2 * padding + scale = float(M + pad2) / M + padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2)) + padded_mask[:, :, padding:-padding, padding:-padding] = mask + return padded_mask, scale + + +def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1): + # TODO: remove torch + mask = torch.tensor(mask.numpy()) + box = torch.tensor(box.numpy()) + padded_mask, scale = expand_masks(mask[None], padding=padding) + mask = padded_mask[0, 0] + box = expand_boxes(box[None], scale)[0] + box = box.to(dtype=torch.int32) + + TO_REMOVE = 1 + w = int(box[2] - box[0] + TO_REMOVE) + h = int(box[3] - box[1] + TO_REMOVE) + w = max(w, 1) + h = max(h, 1) + + mask = mask.expand((1, 1, -1, -1)) + + mask = mask.to(torch.float32) + mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) + mask = mask[0][0] + + if thresh >= 0: + mask = mask > thresh + else: + mask = (mask * 255).to(torch.uint8) + + im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8) + x_0 = max(box[0], 0) + x_1 = min(box[2] + 1, im_w) + y_0 = max(box[1], 0) + y_1 = min(box[3] + 1, im_h) + + im_mask[y_0:y_1, x_0:x_1] = mask[ + (y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0]) + ] + return im_mask + + +class Masker: + def __init__(self, threshold=0.5, padding=1): + self.threshold = threshold + self.padding = padding + + def forward_single_image(self, masks, boxes): + boxes = boxes.convert("xyxy") + im_w, im_h = boxes.size + res = [ + paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding) + for mask, box in zip(masks, boxes.bbox) + ] + if len(res) > 0: + res = torch.stack(res, dim=0)[:, None] + else: + res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1])) + return Tensor(res.numpy()) + + def __call__(self, masks, boxes): + if isinstance(boxes, BoxList): + boxes = [boxes] + + results = [] + for mask, box in zip(masks, boxes): + result = self.forward_single_image(mask, box) + results.append(result) + return results + + +masker = Masker(threshold=0.5, padding=1) + +def select_top_predictions(predictions, confidence_threshold=0.9): + scores = predictions.get_field("scores").numpy() + keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold] + return predictions[keep] + +def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0): + image = transforms(size_scale)(original_image).numpy() + image = Tensor(image, requires_grad=False) + predictions = model(image) + prediction = predictions[0] + prediction = select_top_predictions(prediction, confidence_threshold) + width, height = original_image.size + prediction = prediction.resize((width, height)) + + if prediction.has_field("mask"): + masks = prediction.get_field("mask") + masks = masker([masks], [prediction])[0] + prediction.add_field("mask", masks) + return prediction + +def compute_prediction_batched(batch, model, size_scale=1.0): + imgs = [] + for img in batch: + imgs.append(transforms(size_scale)(img).numpy()) + image = [Tensor(image, requires_grad=False) for image in imgs] + predictions = model(image) + del image + return predictions + +palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) + +def findContours(*args, **kwargs): + if cv2.__version__.startswith('4'): + contours, hierarchy = cv2.findContours(*args, **kwargs) + elif cv2.__version__.startswith('3'): + _, contours, hierarchy = cv2.findContours(*args, **kwargs) + return contours, hierarchy + +def 
compute_colors_for_labels(labels): + l = labels[:, None] + colors = l * palette + colors = (colors % 255).astype("uint8") + return colors + +def overlay_mask(image, predictions): + image = np.asarray(image) + masks = predictions.get_field("mask").numpy() + labels = predictions.get_field("labels").numpy() + + colors = compute_colors_for_labels(labels).tolist() + + for mask, color in zip(masks, colors): + thresh = mask[0, :, :, None] + contours, hierarchy = findContours( + thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + ) + image = cv2.drawContours(image, contours, -1, color, 3) + + composite = image + + return composite + +CATEGORIES = [ + "__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", + "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", + "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", + "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", + "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", + "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", + "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", + "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", +] + +def overlay_boxes(image, predictions): + labels = predictions.get_field("labels").numpy() + boxes = predictions.bbox + image = np.asarray(image) + colors = compute_colors_for_labels(labels).tolist() + + for box, color in zip(boxes, colors): + box = torch.tensor(box.numpy()) + box = box.to(torch.int64) + top_left, bottom_right = box[:2].tolist(), box[2:].tolist() + image = cv2.rectangle( + image, tuple(top_left), tuple(bottom_right), tuple(color), 1 + ) + + return image + +def overlay_class_names(image, predictions): + scores = predictions.get_field("scores").numpy().tolist() + labels = predictions.get_field("labels").numpy().tolist() + labels = [CATEGORIES[int(i)] for i in labels] + boxes = predictions.bbox.numpy() + image = np.asarray(image) + template = "{}: {:.2f}" + for box, score, label in zip(boxes, scores, labels): + x, y = box[:2] + s = template.format(label, score) + x, y = int(x), int(y) + cv2.putText( + image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1 + ) + + return image + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--image', type=str, help="Path of the image to run") + parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold") + parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier") + parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename") + args = parser.parse_args() + + resnet = ResNet(50, num_classes=None, stride_in_1x1=True) + model_tiny = MaskRCNN(resnet) + model_tiny.load_from_pretrained() + img = Image.open(args.image) + top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale) + bbox_image = overlay_boxes(img, top_result_tiny) + mask_image = overlay_mask(bbox_image, top_result_tiny) + final_image = overlay_class_names(mask_image, top_result_tiny) + + 
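+  # overlay order: boxes first, then mask contours on that image, then class names/scores on the final composite
+  # example invocation (hypothetical paths): python3 examples/mask_rcnn.py --image /path/to/image.jpg --threshold 0.7 --size_scale 1.0 --out /tmp/rendered.png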
im = Image.fromarray(final_image) + print(f"saving {args.out}") + im.save(args.out) + im.show() diff --git a/examples/mlperf/model_eval.py b/examples/mlperf/model_eval.py index 0c6b9254..ab72f1eb 100644 --- a/examples/mlperf/model_eval.py +++ b/examples/mlperf/model_eval.py @@ -184,14 +184,51 @@ def eval_bert(): st = time.perf_counter() +def eval_mrcnn(): + from tqdm import tqdm + from models.mask_rcnn import MaskRCNN + from models.resnet import ResNet + from datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate + from examples.mask_rcnn import compute_prediction_batched, Image + mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True)) + mdl.load_from_pretrained() + + bbox_output = '/tmp/results_bbox.json' + mask_output = '/tmp/results_mask.json' + + accumulate_predictions_for_coco([], bbox_output, rm=True) + accumulate_predictions_for_coco([], mask_output, rm=True) + + #TODO: bs > 1 not as accurate + bs = 1 + + for batch in tqdm(iterate(images, bs=bs), total=len(images)//bs): + batch_imgs = [] + for image_row in batch: + image_name = image_row['file_name'] + img = Image.open(BASEDIR/f'val2017/{image_name}').convert("RGB") + batch_imgs.append(img) + batch_result = compute_prediction_batched(batch_imgs, mdl) + for image_row, result in zip(batch, batch_result): + image_name = image_row['file_name'] + box_pred = convert_prediction_to_coco_bbox(image_name, result) + mask_pred = convert_prediction_to_coco_mask(image_name, result) + accumulate_predictions_for_coco(box_pred, bbox_output) + accumulate_predictions_for_coco(mask_pred, mask_output) + del batch_imgs + del batch_result + + evaluate_predictions_on_coco(bbox_output, iou_type='bbox') + evaluate_predictions_on_coco(mask_output, iou_type='segm') + if __name__ == "__main__": # inference only Tensor.training = False Tensor.no_grad = True - models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert").split(",") + models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",") for m in models: nm = f"eval_{m}" if nm in globals(): print(f"eval {m}") - globals()[nm]() + globals()[nm]() \ No newline at end of file diff --git a/examples/mlperf/model_spec.py b/examples/mlperf/model_spec.py index 69ff7cae..9d885724 100644 --- a/examples/mlperf/model_spec.py +++ b/examples/mlperf/model_spec.py @@ -5,7 +5,8 @@ import numpy as np def test_model(model, *inputs): GlobalCounters.reset() - model(*inputs).numpy() + out = model(*inputs) + if isinstance(out, Tensor): out = out.numpy() # TODO: return event future to still get the time_sum_s without DEBUG=2 print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms") @@ -49,15 +50,21 @@ def spec_bert(): tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32)) test_model(mdl, x, am, tt) +def spec_mrcnn(): + from models.mask_rcnn import MaskRCNN, ResNet + mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True)) + mdl.load_from_pretrained() + x = Tensor.randn(3, 224, 224) + test_model(mdl, [x]) + if __name__ == "__main__": # inference only for now Tensor.training = False Tensor.no_grad = True - for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert").split(","): + for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","): nm = f"spec_{m}" if nm in globals(): print(f"testing {m}") globals()[nm]() - diff --git a/models/mask_rcnn.py b/models/mask_rcnn.py new file mode 100644 index 00000000..5f32b828 
--- /dev/null +++ b/models/mask_rcnn.py @@ -0,0 +1,1273 @@ +import re +import math +import os +import numpy as np +from pathlib import Path +from tinygrad import nn +from tinygrad.tensor import Tensor +from tinygrad.helpers import dtypes +from extra.utils import get_child, download_file +from tinygrad.state import torch_load +from models.resnet import ResNet +from models.retinanet import nms as _box_nms + + +USE_NP_GATHER = os.getenv('FULL_TINYGRAD', '0') == '0' + +def rint(tensor): + x = (tensor*2).cast(dtypes.int32).contiguous().cast(dtypes.float32)/2 + return (x<0).where(x.floor(), x.ceil()) + +def nearest_interpolate(tensor, scale_factor): + bs, c, py, px = tensor.shape + return tensor.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, scale_factor, px, scale_factor).reshape(bs, c, py * scale_factor, px * scale_factor) + +def meshgrid(x, y): + grid_x = Tensor.cat(*[x[idx:idx+1].expand(y.shape).unsqueeze(0) for idx in range(x.shape[0])]) + grid_y = Tensor.cat(*[y.unsqueeze(0)]*x.shape[0]) + return grid_x.reshape(-1, 1), grid_y.reshape(-1, 1) + +def topk(input_, k, dim=-1, largest=True, sorted=False): + k = min(k, input_.shape[dim]-1) + input_ = input_.numpy() + if largest: input_ *= -1 + ind = np.argpartition(input_, k, axis=dim) + if largest: input_ *= -1 + ind = np.take(ind, np.arange(k), axis=dim) # k non-sorted indices + input_ = np.take_along_axis(input_, ind, axis=dim) # k non-sorted values + if not sorted: return Tensor(input_), ind + if largest: input_ *= -1 + ind_part = np.argsort(input_, axis=dim) + ind = np.take_along_axis(ind, ind_part, axis=dim) + if largest: input_ *= -1 + val = np.take_along_axis(input_, ind_part, axis=dim) + return Tensor(val), ind + +# This is very slow for large arrays, or indices +def _gather(array, indices): + indices = indices.float().to(array.device) + reshape_arg = [1]*array.ndim + [array.shape[-1]] + return Tensor.where( + indices.unsqueeze(indices.ndim).expand(*indices.shape, array.shape[-1]) == Tensor.arange(array.shape[-1]).reshape(*reshape_arg).expand(*indices.shape, array.shape[-1]), + array, 0, + ).sum(indices.ndim) + +# TODO: replace npgather with a faster gather using tinygrad only +# NOTE: this blocks the gradient +def npgather(array,indices): + if isinstance(array, Tensor): array = array.numpy() + if isinstance(indices, Tensor): indices = indices.numpy() + if isinstance(indices, list): indices = np.asarray(indices) + return Tensor(array[indices.astype(int)]) + +def get_strides(shape): + prod = [1] + for idx in range(len(shape)-1, -1, -1): prod.append(prod[-1] * shape[idx]) + # something about ints is broken with gpu, cuda + return Tensor(prod[::-1][1:], dtype=dtypes.int32).unsqueeze(0).cpu() + +# with keys as integer array for all axes +def tensor_getitem(tensor, *keys): + # something about ints is broken with gpu, cuda + flat_keys = Tensor.stack([key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cpu().cast(dtypes.int32) + strides = get_strides(tensor.shape) + idxs = (flat_keys * strides).sum(1) + gatherer = npgather if USE_NP_GATHER else _gather + return gatherer(tensor.reshape(-1), idxs).reshape(sum(keys).shape) + + +# for gather with indicies only on axis=0 +def tensor_gather(tensor, indices): + if not isinstance(indices, Tensor): + indices = Tensor(indices, requires_grad=False) + if len(tensor.shape) > 2: + rem_shape = list(tensor.shape)[1:] + tensor = tensor.reshape(tensor.shape[0], -1) + else: + rem_shape = None + if len(tensor.shape) > 1: + tensor = tensor.T + repeat_arg = [1]*(tensor.ndim-1) + [tensor.shape[-2]] + 
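+    # replicate each gather index across the transposed feature dimension so _gather one-hot selects the corresponding row of the original tensor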
indices = indices.unsqueeze(indices.ndim).repeat(repeat_arg) + ret = _gather(tensor, indices) + if rem_shape: + ret = ret.reshape([indices.shape[0]] + rem_shape) + else: + ret = _gather(tensor, indices) + del indices + return ret + + +class LastLevelMaxPool: + def __call__(self, x): return [Tensor.max_pool2d(x, 1, 2)] + + +# transpose +FLIP_LEFT_RIGHT = 0 +FLIP_TOP_BOTTOM = 1 + + +def permute_and_flatten(layer:Tensor, N, A, C, H, W): + layer = layer.reshape(N, -1, C, H, W) + layer = layer.permute(0, 3, 4, 1, 2) + layer = layer.reshape(N, -1, C) + return layer + + +class BoxList: + def __init__(self, bbox, image_size, mode="xyxy"): + if not isinstance(bbox, Tensor): + bbox = Tensor(bbox) + if bbox.ndim != 2: + raise ValueError( + "bbox should have 2 dimensions, got {}".format(bbox.ndim) + ) + if bbox.shape[-1] != 4: + raise ValueError( + "last dimenion of bbox should have a " + "size of 4, got {}".format(bbox.shape[-1]) + ) + if mode not in ("xyxy", "xywh"): + raise ValueError("mode should be 'xyxy' or 'xywh'") + + self.bbox = bbox + self.size = image_size # (image_width, image_height) + self.mode = mode + self.extra_fields = {} + + def __repr__(self): + s = self.__class__.__name__ + "(" + s += "num_boxes={}, ".format(len(self)) + s += "image_width={}, ".format(self.size[0]) + s += "image_height={}, ".format(self.size[1]) + s += "mode={})".format(self.mode) + return s + + def area(self): + box = self.bbox + if self.mode == "xyxy": + TO_REMOVE = 1 + area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE) + elif self.mode == "xywh": + area = box[:, 2] * box[:, 3] + return area + + def add_field(self, field, field_data): + self.extra_fields[field] = field_data + + def get_field(self, field): + return self.extra_fields[field] + + def has_field(self, field): + return field in self.extra_fields + + def fields(self): + return list(self.extra_fields.keys()) + + def _copy_extra_fields(self, bbox): + for k, v in bbox.extra_fields.items(): + self.extra_fields[k] = v + + def convert(self, mode): + if mode == self.mode: + return self + xmin, ymin, xmax, ymax = self._split_into_xyxy() + if mode == "xyxy": + bbox = Tensor.cat(*(xmin, ymin, xmax, ymax), dim=-1) + bbox = BoxList(bbox, self.size, mode=mode) + else: + TO_REMOVE = 1 + bbox = Tensor.cat( + *(xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1 + ) + bbox = BoxList(bbox, self.size, mode=mode) + bbox._copy_extra_fields(self) + return bbox + + def _split_into_xyxy(self): + if self.mode == "xyxy": + xmin, ymin, xmax, ymax = self.bbox.chunk(4, dim=-1) + return xmin, ymin, xmax, ymax + elif self.mode == "xywh": + TO_REMOVE = 1 + xmin, ymin, w, h = self.bbox.chunk(4, dim=-1) + return ( + xmin, + ymin, + xmin + (w - TO_REMOVE).clamp(min=0), + ymin + (h - TO_REMOVE).clamp(min=0), + ) + + def resize(self, size, *args, **kwargs): + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size)) + if ratios[0] == ratios[1]: + ratio = ratios[0] + scaled_box = self.bbox * ratio + bbox = BoxList(scaled_box, size, mode=self.mode) + for k, v in self.extra_fields.items(): + if not isinstance(v, Tensor): + v = v.resize(size, *args, **kwargs) + bbox.add_field(k, v) + return bbox + + ratio_width, ratio_height = ratios + xmin, ymin, xmax, ymax = self._split_into_xyxy() + scaled_xmin = xmin * ratio_width + scaled_xmax = xmax * ratio_width + scaled_ymin = ymin * ratio_height + scaled_ymax = ymax * ratio_height + scaled_box = Tensor.cat( + *(scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1 + ) + bbox = 
BoxList(scaled_box, size, mode="xyxy") + for k, v in self.extra_fields.items(): + if not isinstance(v, Tensor): + v = v.resize(size, *args, **kwargs) + bbox.add_field(k, v) + + return bbox.convert(self.mode) + + def transpose(self, method): + image_width, image_height = self.size + xmin, ymin, xmax, ymax = self._split_into_xyxy() + if method == FLIP_LEFT_RIGHT: + TO_REMOVE = 1 + transposed_xmin = image_width - xmax - TO_REMOVE + transposed_xmax = image_width - xmin - TO_REMOVE + transposed_ymin = ymin + transposed_ymax = ymax + elif method == FLIP_TOP_BOTTOM: + transposed_xmin = xmin + transposed_xmax = xmax + transposed_ymin = image_height - ymax + transposed_ymax = image_height - ymin + + transposed_boxes = Tensor.cat( + *(transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1 + ) + bbox = BoxList(transposed_boxes, self.size, mode="xyxy") + for k, v in self.extra_fields.items(): + if not isinstance(v, Tensor): + v = v.transpose(method) + bbox.add_field(k, v) + return bbox.convert(self.mode) + + def clip_to_image(self, remove_empty=True): + TO_REMOVE = 1 + bb1 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 0] + bb2 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 1] + bb3 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 2] + bb4 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 3] + self.bbox = Tensor.stack((bb1, bb2, bb3, bb4), dim=1) + if remove_empty: + box = self.bbox + keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0]) + return self[keep] + return self + + def __getitem__(self, item): + if isinstance(item, list): + if len(item) == 0: + return [] + if sum(item) == len(item) and isinstance(item[0], bool): + return self + bbox = BoxList(tensor_gather(self.bbox, item), self.size, self.mode) + for k, v in self.extra_fields.items(): + bbox.add_field(k, tensor_gather(v, item)) + return bbox + + def __len__(self): + return self.bbox.shape[0] + + +def cat_boxlist(bboxes): + size = bboxes[0].size + mode = bboxes[0].mode + fields = set(bboxes[0].fields()) + cat_box_list = [bbox.bbox for bbox in bboxes if bbox.bbox.shape[0] > 0] + + if len(cat_box_list) > 0: + cat_boxes = BoxList(Tensor.cat(*cat_box_list, dim=0), size, mode) + else: + cat_boxes = BoxList(bboxes[0].bbox, size, mode) + for field in fields: + cat_field_list = [bbox.get_field(field) for bbox in bboxes if bbox.get_field(field).shape[0] > 0] + + if len(cat_box_list) > 0: + data = Tensor.cat(*cat_field_list, dim=0) + else: + data = bboxes[0].get_field(field) + + cat_boxes.add_field(field, data) + + return cat_boxes + + +class FPN: + def __init__(self, in_channels_list, out_channels): + self.inner_blocks, self.layer_blocks = [], [] + for in_channels in in_channels_list: + self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1)) + self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)) + self.top_block = LastLevelMaxPool() + + def __call__(self, x: Tensor): + last_inner = self.inner_blocks[-1](x[-1]) + results = [] + results.append(self.layer_blocks[-1](last_inner)) + for feature, inner_block, layer_block in zip( + x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1] + ): + if not inner_block: + continue + inner_top_down = nearest_interpolate(last_inner, scale_factor=2) + inner_lateral = inner_block(feature) + last_inner = inner_lateral + inner_top_down + layer_result = layer_block(last_inner) + results.insert(0, layer_result) + last_results = self.top_block(results[-1]) + 
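+    # top_block (LastLevelMaxPool) produces one extra, coarser level from the last output; results is ordered finest-to-coarsest thanks to insert(0, ...)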
results.extend(last_results) + + return tuple(results) + + +class ResNetFPN: + def __init__(self, resnet, out_channels=256): + self.out_channels = out_channels + self.body = resnet + in_channels_stage2 = 256 + in_channels_list = [ + in_channels_stage2, + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + self.fpn = FPN(in_channels_list, out_channels) + + def __call__(self, x): + x = self.body(x) + return self.fpn(x) + + +class AnchorGenerator: + def __init__( + self, + sizes=(32, 64, 128, 256, 512), + aspect_ratios=(0.5, 1.0, 2.0), + anchor_strides=(4, 8, 16, 32, 64), + straddle_thresh=0, + ): + if len(anchor_strides) == 1: + anchor_stride = anchor_strides[0] + cell_anchors = [ + generate_anchors(anchor_stride, sizes, aspect_ratios) + ] + else: + if len(anchor_strides) != len(sizes): + raise RuntimeError("FPN should have #anchor_strides == #sizes") + + cell_anchors = [ + generate_anchors( + anchor_stride, + size if isinstance(size, (tuple, list)) else (size,), + aspect_ratios + ) + for anchor_stride, size in zip(anchor_strides, sizes) + ] + self.strides = anchor_strides + self.cell_anchors = cell_anchors + self.straddle_thresh = straddle_thresh + + def num_anchors_per_location(self): + return [cell_anchors.shape[0] for cell_anchors in self.cell_anchors] + + def grid_anchors(self, grid_sizes): + anchors = [] + for size, stride, base_anchors in zip( + grid_sizes, self.strides, self.cell_anchors + ): + grid_height, grid_width = size + device = base_anchors.device + shifts_x = Tensor.arange( + start=0, stop=grid_width * stride, step=stride, dtype=dtypes.float32, device=device + ) + shifts_y = Tensor.arange( + start=0, stop=grid_height * stride, step=stride, dtype=dtypes.float32, device=device + ) + shift_y, shift_x = meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + shifts = Tensor.stack((shift_x, shift_y, shift_x, shift_y), dim=1) + + anchors.append( + (shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4) + ) + + return anchors + + def add_visibility_to(self, boxlist): + image_width, image_height = boxlist.size + anchors = boxlist.bbox + if self.straddle_thresh >= 0: + inds_inside = ( + (anchors[:, 0] >= -self.straddle_thresh) + * (anchors[:, 1] >= -self.straddle_thresh) + * (anchors[:, 2] < image_width + self.straddle_thresh) + * (anchors[:, 3] < image_height + self.straddle_thresh) + ) + else: + device = anchors.device + inds_inside = Tensor.ones(anchors.shape[0], dtype=dtypes.uint8, device=device) + boxlist.add_field("visibility", inds_inside) + + def __call__(self, image_list, feature_maps): + grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] + anchors_over_all_feature_maps = self.grid_anchors(grid_sizes) + anchors = [] + for (image_height, image_width) in image_list.image_sizes: + anchors_in_image = [] + for anchors_per_feature_map in anchors_over_all_feature_maps: + boxlist = BoxList( + anchors_per_feature_map, (image_width, image_height), mode="xyxy" + ) + self.add_visibility_to(boxlist) + anchors_in_image.append(boxlist) + anchors.append(anchors_in_image) + return anchors + + +def generate_anchors( + stride=16, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2) +): + return _generate_anchors(stride, Tensor(list(sizes)) / stride, Tensor(list(aspect_ratios))) + + +def _generate_anchors(base_size, scales, aspect_ratios): + anchor = Tensor([1, 1, base_size, base_size]) - 1 + anchors = _ratio_enum(anchor, aspect_ratios) + anchors = Tensor.cat( + 
*[_scale_enum(anchors[i, :], scales).reshape(-1, 4) for i in range(anchors.shape[0])] + ) + return anchors + + +def _whctrs(anchor): + w = anchor[2] - anchor[0] + 1 + h = anchor[3] - anchor[1] + 1 + x_ctr = anchor[0] + 0.5 * (w - 1) + y_ctr = anchor[1] + 0.5 * (h - 1) + return w, h, x_ctr, y_ctr + + +def _mkanchors(ws, hs, x_ctr, y_ctr): + ws = ws[:, None] + hs = hs[:, None] + anchors = Tensor.cat(*( + x_ctr - 0.5 * (ws - 1), + y_ctr - 0.5 * (hs - 1), + x_ctr + 0.5 * (ws - 1), + y_ctr + 0.5 * (hs - 1), + ), dim=1) + return anchors + + +def _ratio_enum(anchor, ratios): + w, h, x_ctr, y_ctr = _whctrs(anchor) + size = w * h + size_ratios = size / ratios + ws = rint(Tensor.sqrt(size_ratios)) + hs = rint(ws * ratios) + anchors = _mkanchors(ws, hs, x_ctr, y_ctr) + return anchors + + +def _scale_enum(anchor, scales): + w, h, x_ctr, y_ctr = _whctrs(anchor) + ws = w * scales + hs = h * scales + anchors = _mkanchors(ws, hs, x_ctr, y_ctr) + return anchors + + +class RPNHead: + def __init__(self, in_channels, num_anchors): + self.conv = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1) + self.cls_logits = nn.Conv2d(256, num_anchors, kernel_size=1) + self.bbox_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=1) + + def __call__(self, x): + logits = [] + bbox_reg = [] + for feature in x: + t = Tensor.relu(self.conv(feature)) + logits.append(self.cls_logits(t)) + bbox_reg.append(self.bbox_pred(t)) + return logits, bbox_reg + + +class BoxCoder(object): + def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)): + self.weights = weights + self.bbox_xform_clip = bbox_xform_clip + + def encode(self, reference_boxes, proposals): + TO_REMOVE = 1 # TODO remove + ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE + ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE + ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths + ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights + + gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE + gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE + gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths + gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights + + wx, wy, ww, wh = self.weights + targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths + targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights + targets_dw = ww * Tensor.log(gt_widths / ex_widths) + targets_dh = wh * Tensor.log(gt_heights / ex_heights) + + targets = Tensor.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) + return targets + + def decode(self, rel_codes, boxes): + boxes = boxes.cast(rel_codes.dtype) + rel_codes = rel_codes + + TO_REMOVE = 1 # TODO remove + widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE + heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE + ctr_x = boxes[:, 0] + 0.5 * widths + ctr_y = boxes[:, 1] + 0.5 * heights + + wx, wy, ww, wh = self.weights + dx = rel_codes[:, 0::4] / wx + dy = rel_codes[:, 1::4] / wy + dw = rel_codes[:, 2::4] / ww + dh = rel_codes[:, 3::4] / wh + + # Prevent sending too large values into Tensor.exp() + dw = dw.clip(min_=dw.min(), max_=self.bbox_xform_clip) + dh = dh.clip(min_=dh.min(), max_=self.bbox_xform_clip) + + pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] + pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] + pred_w = dw.exp() * widths[:, None] + pred_h = dh.exp() * heights[:, None] + x = pred_ctr_x - 0.5 * pred_w + y = pred_ctr_y - 0.5 * pred_h + w = pred_ctr_x + 0.5 * pred_w - 1 + h = pred_ctr_y + 0.5 * pred_h - 1 + pred_boxes = Tensor.stack([x, y, w, h]).permute(1,2,0).reshape(rel_codes.shape[0], 
rel_codes.shape[1]) + return pred_boxes + + +def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"): + if nms_thresh <= 0: + return boxlist + mode = boxlist.mode + boxlist = boxlist.convert("xyxy") + boxes = boxlist.bbox + score = boxlist.get_field(score_field) + keep = _box_nms(boxes.numpy(), score.numpy(), nms_thresh) + if max_proposals > 0: + keep = keep[:max_proposals] + boxlist = boxlist[keep] + return boxlist.convert(mode) + + +def remove_small_boxes(boxlist, min_size): + xywh_boxes = boxlist.convert("xywh").bbox + _, _, ws, hs = xywh_boxes.chunk(4, dim=1) + keep = (( + (ws >= min_size) * (hs >= min_size) + ) > 0).reshape(-1) + if keep.sum().numpy() == len(boxlist): + return boxlist + else: + keep = keep.numpy().nonzero()[0] + return boxlist[keep] + + +class RPNPostProcessor: + # Not used in Loss calculation + def __init__( + self, + pre_nms_top_n, + post_nms_top_n, + nms_thresh, + min_size, + box_coder=None, + fpn_post_nms_top_n=None, + ): + self.pre_nms_top_n = pre_nms_top_n + self.post_nms_top_n = post_nms_top_n + self.nms_thresh = nms_thresh + self.min_size = min_size + + if box_coder is None: + box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + self.box_coder = box_coder + + if fpn_post_nms_top_n is None: + fpn_post_nms_top_n = post_nms_top_n + self.fpn_post_nms_top_n = fpn_post_nms_top_n + + def forward_for_single_feature_map(self, anchors, objectness, box_regression): + device = objectness.device + N, A, H, W = objectness.shape + objectness = permute_and_flatten(objectness, N, A, 1, H, W).reshape(N, -1) + objectness = objectness.sigmoid() + + box_regression = permute_and_flatten(box_regression, N, A, 4, H, W) + + num_anchors = A * H * W + + pre_nms_top_n = min(self.pre_nms_top_n, num_anchors) + objectness, topk_idx = topk(objectness, pre_nms_top_n, dim=1, sorted=False) + concat_anchors = Tensor.cat(*[a.bbox for a in anchors], dim=0).reshape(N, -1, 4) + image_shapes = [box.size for box in anchors] + + box_regression_list = [] + concat_anchors_list = [] + for batch_idx in range(N): + box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx])) + concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx])) + + box_regression = Tensor.stack(box_regression_list) + concat_anchors = Tensor.stack(concat_anchors_list) + + proposals = self.box_coder.decode( + box_regression.reshape(-1, 4), concat_anchors.reshape(-1, 4) + ) + + proposals = proposals.reshape(N, -1, 4) + + result = [] + for proposal, score, im_shape in zip(proposals, objectness, image_shapes): + boxlist = BoxList(proposal, im_shape, mode="xyxy") + boxlist.add_field("objectness", score) + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = remove_small_boxes(boxlist, self.min_size) + boxlist = boxlist_nms( + boxlist, + self.nms_thresh, + max_proposals=self.post_nms_top_n, + score_field="objectness", + ) + result.append(boxlist) + return result + + def __call__(self, anchors, objectness, box_regression): + sampled_boxes = [] + num_levels = len(objectness) + anchors = list(zip(*anchors)) + for a, o, b in zip(anchors, objectness, box_regression): + sampled_boxes.append(self.forward_for_single_feature_map(a, o, b)) + + boxlists = list(zip(*sampled_boxes)) + boxlists = [cat_boxlist(boxlist) for boxlist in boxlists] + + if num_levels > 1: + boxlists = self.select_over_all_levels(boxlists) + + return boxlists + + def select_over_all_levels(self, boxlists): + num_images = len(boxlists) + for i in range(num_images): + objectness = 
boxlists[i].get_field("objectness") + post_nms_top_n = min(self.fpn_post_nms_top_n, objectness.shape[0]) + _, inds_sorted = topk(objectness, + post_nms_top_n, dim=0, sorted=False + ) + boxlists[i] = boxlists[i][inds_sorted] + return boxlists + + +class RPN: + def __init__(self, in_channels): + self.anchor_generator = AnchorGenerator() + + in_channels = 256 + head = RPNHead( + in_channels, self.anchor_generator.num_anchors_per_location()[0] + ) + rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) + box_selector_test = RPNPostProcessor( + pre_nms_top_n=1000, + post_nms_top_n=1000, + nms_thresh=0.7, + min_size=0, + box_coder=rpn_box_coder, + fpn_post_nms_top_n=1000 + ) + self.head = head + self.box_selector_test = box_selector_test + + def __call__(self, images, features, targets=None): + objectness, rpn_box_regression = self.head(features) + anchors = self.anchor_generator(images, features) + boxes = self.box_selector_test(anchors, objectness, rpn_box_regression) + return boxes, {} + + +def make_conv3x3( + in_channels, + out_channels, + dilation=1, + stride=1, + use_gn=False, +): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False if use_gn else True + ) + return conv + + +class MaskRCNNFPNFeatureExtractor: + def __init__(self): + resolution = 14 + scales = (0.25, 0.125, 0.0625, 0.03125) + sampling_ratio = 2 + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + input_size = 256 + self.pooler = pooler + + use_gn = False + layers = (256, 256, 256, 256) + dilation = 1 + self.mask_fcn1 = make_conv3x3(input_size, layers[0], dilation=dilation, stride=1, use_gn=use_gn) + self.mask_fcn2 = make_conv3x3(layers[0], layers[1], dilation=dilation, stride=1, use_gn=use_gn) + self.mask_fcn3 = make_conv3x3(layers[1], layers[2], dilation=dilation, stride=1, use_gn=use_gn) + self.mask_fcn4 = make_conv3x3(layers[2], layers[3], dilation=dilation, stride=1, use_gn=use_gn) + self.blocks = [self.mask_fcn1, self.mask_fcn2, self.mask_fcn3, self.mask_fcn4] + + def __call__(self, x, proposals): + x = self.pooler(x, proposals) + for layer in self.blocks: + if x is not None: + x = Tensor.relu(layer(x)) + return x + + +class MaskRCNNC4Predictor: + def __init__(self): + num_classes = 81 + dim_reduced = 256 + num_inputs = dim_reduced + self.conv5_mask = nn.ConvTranspose2d(num_inputs, dim_reduced, 2, 2, 0) + self.mask_fcn_logits = nn.Conv2d(dim_reduced, num_classes, 1, 1, 0) + + def __call__(self, x): + x = Tensor.relu(self.conv5_mask(x)) + return self.mask_fcn_logits(x) + + +class FPN2MLPFeatureExtractor: + def __init__(self, cfg): + resolution = 7 + scales = (0.25, 0.125, 0.0625, 0.03125) + sampling_ratio = 2 + pooler = Pooler( + output_size=(resolution, resolution), + scales=scales, + sampling_ratio=sampling_ratio, + ) + input_size = 256 * resolution ** 2 + representation_size = 1024 + self.pooler = pooler + self.fc6 = nn.Linear(input_size, representation_size) + self.fc7 = nn.Linear(representation_size, representation_size) + + def __call__(self, x, proposals): + x = self.pooler(x, proposals) + x = x.reshape(x.shape[0], -1) + x = Tensor.relu(self.fc6(x)) + x = Tensor.relu(self.fc7(x)) + return x + + +def _bilinear_interpolate( + input, # [N, C, H, W] + roi_batch_ind, # [K] + y, # [K, PH, IY] + x, # [K, PW, IX] + ymask, # [K, IY] + xmask, # [K, IX] +): + _, channels, height, width = input.shape + y = y.clip(min_=0.0, max_=float(height-1)) + x = x.clip(min_=0.0, 
max_=float(width-1)) + + # Tensor.where doesnt work well with int32 data so cast to float32 + y_low = y.cast(dtypes.int32).contiguous().float().contiguous() + x_low = x.cast(dtypes.int32).contiguous().float().contiguous() + + y_high = Tensor.where(y_low >= height - 1, float(height - 1), y_low + 1) + y_low = Tensor.where(y_low >= height - 1, float(height - 1), y_low) + + x_high = Tensor.where(x_low >= width - 1, float(width - 1), x_low + 1) + x_low = Tensor.where(x_low >= width - 1, float(width - 1), x_low) + + ly = y - y_low + lx = x - x_low + hy = 1.0 - ly + hx = 1.0 - lx + + def masked_index( + y, # [K, PH, IY] + x, # [K, PW, IX] + ): + if ymask is not None: + assert xmask is not None + y = Tensor.where(ymask[:, None, :], y, 0) + x = Tensor.where(xmask[:, None, :], x, 0) + key1 = roi_batch_ind[:, None, None, None, None, None] + key2 = Tensor.arange(channels, device=input.device)[None, :, None, None, None, None] + key3 = y[:, None, :, None, :, None] + key4 = x[:, None, None, :, None, :] + return tensor_getitem(input,key1,key2,key3,key4) # [K, C, PH, PW, IY, IX] + + v1 = masked_index(y_low, x_low) + v2 = masked_index(y_low, x_high) + v3 = masked_index(y_high, x_low) + v4 = masked_index(y_high, x_high) + + # all ws preemptively [K, C, PH, PW, IY, IX] + def outer_prod(y, x): + return y[:, None, :, None, :, None] * x[:, None, None, :, None, :] + + w1 = outer_prod(hy, hx) + w2 = outer_prod(hy, lx) + w3 = outer_prod(ly, hx) + w4 = outer_prod(ly, lx) + + val = w1*v1 + w2*v2 + w3*v3 + w4*v4 + return val + +#https://pytorch.org/vision/main/_modules/torchvision/ops/roi_align.html#roi_align +def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): + orig_dtype = input.dtype + _, _, height, width = input.shape + ph = Tensor.arange(pooled_height, device=input.device) + pw = Tensor.arange(pooled_width, device=input.device) + + roi_batch_ind = rois[:, 0].cast(dtypes.int32).contiguous() + offset = 0.5 if aligned else 0.0 + roi_start_w = rois[:, 1] * spatial_scale - offset + roi_start_h = rois[:, 2] * spatial_scale - offset + roi_end_w = rois[:, 3] * spatial_scale - offset + roi_end_h = rois[:, 4] * spatial_scale - offset + + roi_width = roi_end_w - roi_start_w + roi_height = roi_end_h - roi_start_h + if not aligned: + roi_width = roi_width.maximum(1.0) + roi_height = roi_height.maximum(1.0) + + bin_size_h = roi_height / pooled_height + bin_size_w = roi_width / pooled_width + + exact_sampling = sampling_ratio > 0 + roi_bin_grid_h = sampling_ratio if exact_sampling else (roi_height / pooled_height).ceil() + roi_bin_grid_w = sampling_ratio if exact_sampling else (roi_width / pooled_width).ceil() + + if exact_sampling: + count = max(roi_bin_grid_h * roi_bin_grid_w, 1) + iy = Tensor.arange(roi_bin_grid_h, device=input.device) + ix = Tensor.arange(roi_bin_grid_w, device=input.device) + ymask = None + xmask = None + else: + count = (roi_bin_grid_h * roi_bin_grid_w).maximum(1) + iy = Tensor.arange(height, device=input.device) + ix = Tensor.arange(width, device=input.device) + ymask = iy[None, :] < roi_bin_grid_h[:, None] + xmask = ix[None, :] < roi_bin_grid_w[:, None] + + def from_K(t): + return t[:, None, None] + + y = ( + from_K(roi_start_h) + + ph[None, :, None] * from_K(bin_size_h) + + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h) + ) + x = ( + from_K(roi_start_w) + + pw[None, :, None] * from_K(bin_size_w) + + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w) + ) + + val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) 
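+  # val: [K, C, PH, PW, IY, IX]; out-of-range sampling points are zeroed below when sampling is adaptive,
+  # then the IY/IX grid is averaged into the pooled [K, C, PH, PW] output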
+ if not exact_sampling: + val = ymask[:, None, None, None, :, None].where(val, 0) + val = xmask[:, None, None, None, None, :].where(val, 0) + + output = val.sum((-1, -2)) + if isinstance(count, Tensor): + output /= count[:, None, None, None] + else: + output /= count + + output = output.cast(orig_dtype) + return output + +class ROIAlign: + def __init__(self, output_size, spatial_scale, sampling_ratio): + self.output_size = output_size + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + + def __call__(self, input, rois): + output = _roi_align( + input, rois, self.spatial_scale, self.output_size[0], self.output_size[1], self.sampling_ratio, aligned=False + ) + return output + + +class LevelMapper: + def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6): + self.k_min = k_min + self.k_max = k_max + self.s0 = canonical_scale + self.lvl0 = canonical_level + self.eps = eps + + def __call__(self, boxlists): + s = Tensor.sqrt(Tensor.cat(*[boxlist.area() for boxlist in boxlists])) + target_lvls = (self.lvl0 + Tensor.log2(s / self.s0 + self.eps)).floor() + target_lvls = target_lvls.clip(min_=self.k_min, max_=self.k_max) + return target_lvls - self.k_min + + +class Pooler: + def __init__(self, output_size, scales, sampling_ratio): + self.output_size = output_size + self.scales = scales + self.sampling_ratio = sampling_ratio + poolers = [] + for scale in scales: + poolers.append( + ROIAlign( + output_size, spatial_scale=scale, sampling_ratio=sampling_ratio + ) + ) + self.poolers = poolers + self.output_size = output_size + lvl_min = -math.log2(scales[0]) + lvl_max = -math.log2(scales[-1]) + self.map_levels = LevelMapper(lvl_min, lvl_max) + + def convert_to_roi_format(self, boxes): + concat_boxes = Tensor.cat(*[b.bbox for b in boxes], dim=0) + device, dtype = concat_boxes.device, concat_boxes.dtype + ids = Tensor.cat( + *[ + Tensor.full((len(b), 1), i, dtype=dtype, device=device) + for i, b in enumerate(boxes) + ], + dim=0, + ) + if concat_boxes.shape[0] != 0: + rois = Tensor.cat(*[ids, concat_boxes], dim=1) + return rois + + def __call__(self, x, boxes): + num_levels = len(self.poolers) + rois = self.convert_to_roi_format(boxes) + if rois: + if num_levels == 1: + return self.poolers[0](x[0], rois) + + levels = self.map_levels(boxes) + results = [] + all_idxs = [] + for level, (per_level_feature, pooler) in enumerate(zip(x, self.poolers)): + # this is fine because no grad will flow through index + idx_in_level = (levels.numpy() == level).nonzero()[0] + if len(idx_in_level) > 0: + rois_per_level = tensor_gather(rois, idx_in_level) + pooler_output = pooler(per_level_feature, rois_per_level) + all_idxs.extend(idx_in_level) + results.append(pooler_output) + + return tensor_gather(Tensor.cat(*results), [x[0] for x in sorted({i:idx for i, idx in enumerate(all_idxs)}.items(), key=lambda x: x[1])]) + + +class FPNPredictor: + def __init__(self): + num_classes = 81 + representation_size = 1024 + self.cls_score = nn.Linear(representation_size, num_classes) + num_bbox_reg_classes = num_classes + self.bbox_pred = nn.Linear(representation_size, num_bbox_reg_classes * 4) + + def __call__(self, x): + scores = self.cls_score(x) + bbox_deltas = self.bbox_pred(x) + return scores, bbox_deltas + + +class PostProcessor: + # Not used in training + def __init__( + self, + score_thresh=0.05, + nms=0.5, + detections_per_img=100, + box_coder=None, + cls_agnostic_bbox_reg=False + ): + self.score_thresh = score_thresh + self.nms = nms + self.detections_per_img = 
detections_per_img + if box_coder is None: + box_coder = BoxCoder(weights=(10., 10., 5., 5.)) + self.box_coder = box_coder + self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg + + def __call__(self, x, boxes): + class_logits, box_regression = x + class_prob = Tensor.softmax(class_logits, -1) + image_shapes = [box.size for box in boxes] + boxes_per_image = [len(box) for box in boxes] + concat_boxes = Tensor.cat(*[a.bbox for a in boxes], dim=0) + + if self.cls_agnostic_bbox_reg: + box_regression = box_regression[:, -4:] + proposals = self.box_coder.decode( + box_regression.reshape(sum(boxes_per_image), -1), concat_boxes + ) + if self.cls_agnostic_bbox_reg: + proposals = proposals.repeat([1, class_prob.shape[1]]) + num_classes = class_prob.shape[1] + proposals = proposals.unsqueeze(0) + class_prob = class_prob.unsqueeze(0) + results = [] + for prob, boxes_per_img, image_shape in zip( + class_prob, proposals, image_shapes + ): + boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape) + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = self.filter_results(boxlist, num_classes) + results.append(boxlist) + return results + + def prepare_boxlist(self, boxes, scores, image_shape): + boxes = boxes.reshape(-1, 4) + scores = scores.reshape(-1) + boxlist = BoxList(boxes, image_shape, mode="xyxy") + boxlist.add_field("scores", scores) + return boxlist + + def filter_results(self, boxlist, num_classes): + boxes = boxlist.bbox.reshape(-1, num_classes * 4) + scores = boxlist.get_field("scores").reshape(-1, num_classes) + + device = scores.device + result = [] + scores = scores.numpy() + boxes = boxes.numpy() + inds_all = scores > self.score_thresh + for j in range(1, num_classes): + inds = inds_all[:, j].nonzero()[0] + # This needs to be done in numpy because it can create empty arrays + scores_j = scores[inds, j] + boxes_j = boxes[inds, j * 4: (j + 1) * 4] + boxes_j = Tensor(boxes_j) + scores_j = Tensor(scores_j) + boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") + boxlist_for_class.add_field("scores", scores_j) + if len(boxlist_for_class): + boxlist_for_class = boxlist_nms( + boxlist_for_class, self.nms + ) + num_labels = len(boxlist_for_class) + boxlist_for_class.add_field( + "labels", Tensor.full((num_labels,), j, device=device) + ) + result.append(boxlist_for_class) + + result = cat_boxlist(result) + number_of_detections = len(result) + + if number_of_detections > self.detections_per_img > 0: + cls_scores = result.get_field("scores") + image_thresh, _ = topk(cls_scores, k=self.detections_per_img) + image_thresh = image_thresh.numpy()[-1] + keep = (cls_scores.numpy() >= image_thresh).nonzero()[0] + result = result[keep] + return result + + +class RoIBoxHead: + def __init__(self, in_channels): + self.feature_extractor = FPN2MLPFeatureExtractor(in_channels) + self.predictor = FPNPredictor() + self.post_processor = PostProcessor( + score_thresh=0.05, + nms=0.5, + detections_per_img=100, + box_coder=BoxCoder(weights=(10., 10., 5., 5.)), + cls_agnostic_bbox_reg=False + ) + + def __call__(self, features, proposals, targets=None): + x = self.feature_extractor(features, proposals) + class_logits, box_regression = self.predictor(x) + if not Tensor.training: + result = self.post_processor((class_logits, box_regression), proposals) + return x, result, {} + + +class MaskPostProcessor: + # Not used in loss calculation + def __call__(self, x, boxes): + mask_prob = x.sigmoid().numpy() + num_masks = x.shape[0] + labels = [bbox.get_field("labels") for bbox in boxes] + labels = 
Tensor.cat(*labels).numpy().astype(np.int32) + index = np.arange(num_masks) + mask_prob = mask_prob[index, labels][:, None] + boxes_per_image, cumsum = [], 0 + for box in boxes: + cumsum += len(box) + boxes_per_image.append(cumsum) + # using numpy here as Tensor.chunk doesnt have custom chunk sizes + mask_prob = np.split(mask_prob, boxes_per_image, axis=0) + results = [] + for prob, box in zip(mask_prob, boxes): + bbox = BoxList(box.bbox, box.size, mode="xyxy") + for field in box.fields(): + bbox.add_field(field, box.get_field(field)) + prob = Tensor(prob) + bbox.add_field("mask", prob) + results.append(bbox) + + return results + + +class Mask: + def __init__(self): + self.feature_extractor = MaskRCNNFPNFeatureExtractor() + self.predictor = MaskRCNNC4Predictor() + self.post_processor = MaskPostProcessor() + + def __call__(self, features, proposals, targets=None): + x = self.feature_extractor(features, proposals) + if x: + mask_logits = self.predictor(x) + if not Tensor.training: + result = self.post_processor(mask_logits, proposals) + return x, result, {} + return x, [], {} + + +class RoIHeads: + def __init__(self, in_channels): + self.box = RoIBoxHead(in_channels) + self.mask = Mask() + + def __call__(self, features, proposals, targets=None): + x, detections, _ = self.box(features, proposals, targets) + x, detections, _ = self.mask(features, detections, targets) + return x, detections, {} + + +class ImageList(object): + def __init__(self, tensors, image_sizes): + self.tensors = tensors + self.image_sizes = image_sizes + + def to(self, *args, **kwargs): + cast_tensor = self.tensors.to(*args, **kwargs) + return ImageList(cast_tensor, self.image_sizes) + + +def to_image_list(tensors, size_divisible=32): + # Preprocessing + if isinstance(tensors, Tensor) and size_divisible > 0: + tensors = [tensors] + + if isinstance(tensors, ImageList): + return tensors + elif isinstance(tensors, Tensor): + # single tensor shape can be inferred + assert tensors.ndim == 4 + image_sizes = [tensor.shape[-2:] for tensor in tensors] + return ImageList(tensors, image_sizes) + elif isinstance(tensors, (tuple, list)): + max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors])) + if size_divisible > 0: + + stride = size_divisible + max_size = list(max_size) + max_size[1] = int(math.ceil(max_size[1] / stride) * stride) + max_size[2] = int(math.ceil(max_size[2] / stride) * stride) + max_size = tuple(max_size) + + batch_shape = (len(tensors),) + max_size + batched_imgs = np.zeros(batch_shape, dtype=tensors[0].dtype.np) + for img, pad_img in zip(tensors, batched_imgs): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]] += img.numpy() + + batched_imgs = Tensor(batched_imgs) + image_sizes = [im.shape[-2:] for im in tensors] + + return ImageList(batched_imgs, image_sizes) + else: + raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors))) + + +class MaskRCNN: + def __init__(self, backbone: ResNet): + self.backbone = ResNetFPN(backbone, out_channels=256) + self.rpn = RPN(self.backbone.out_channels) + self.roi_heads = RoIHeads(self.backbone.out_channels) + + def load_from_pretrained(self): + fn = Path('./') / "weights/maskrcnn.pt" + download_file("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn) + + state_dict = torch_load(fn)['model'] + loaded_keys = [] + for k, v in state_dict.items(): + if "module." in k: + k = k.replace("module.", "") + if "stem." 
in k: + k = k.replace("stem.", "") + if "fpn_inner" in k: + block_index = int(re.search(r"fpn_inner(\d+)", k).group(1)) + k = re.sub(r"fpn_inner\d+", f"inner_blocks.{block_index - 1}", k) + if "fpn_layer" in k: + block_index = int(re.search(r"fpn_layer(\d+)", k).group(1)) + k = re.sub(r"fpn_layer\d+", f"layer_blocks.{block_index - 1}", k) + loaded_keys.append(k) + get_child(self, k).assign(v.numpy()).realize() + return loaded_keys + + def __call__(self, images): + images = to_image_list(images) + features = self.backbone(images.tensors) + proposals, _ = self.rpn(images, features) + x, result, _ = self.roi_heads(features, proposals) + return result + + +if __name__ == '__main__': + resnet = resnet = ResNet(50, num_classes=None, stride_in_1x1=True) + model = MaskRCNN(backbone=resnet) + model.load_from_pretrained() diff --git a/models/resnet.py b/models/resnet.py index 8bc955a3..b5a756c0 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -27,14 +27,15 @@ class BasicBlock: class Bottleneck: - # NOTE: the original implementation places stride at the first convolution (self.conv1), this is the v1.5 variant + # NOTE: stride_in_1x1=False, this is the v1.5 variant expansion = 4 - def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64): + def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64): width = int(planes * (base_width / 64.0)) * groups - self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, bias=False) + # NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1 + self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False) self.bn1 = nn.BatchNorm2d(width) - self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=stride, groups=groups, bias=False) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False) self.bn2 = nn.BatchNorm2d(width) self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(self.expansion*planes) @@ -54,9 +55,8 @@ class Bottleneck: return out class ResNet: - def __init__(self, num, num_classes, groups=1, width_per_group=64): + def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False): self.num = num - self.block = { 18: BasicBlock, 34: BasicBlock, @@ -79,30 +79,41 @@ class ResNet: self.base_width = width_per_group self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3) self.bn1 = nn.BatchNorm2d(64) - self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1) - self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2) - self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2) - self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2) - self.fc = nn.Linear(512 * self.block.expansion, num_classes) + self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1) + self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1) + self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1) + self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1) + self.fc = nn.Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None - def _make_layer(self, 
block, planes, num_blocks, stride): + def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1): strides = [stride] + [1] * (num_blocks-1) layers = [] for stride in strides: - layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width)) + if block == Bottleneck: + layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width)) + else: + layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width)) self.in_planes = planes * block.expansion return layers def forward(self, x): + is_feature_only = self.fc is None + if is_feature_only: features = [] out = self.bn1(self.conv1(x)).relu() out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2) out = out.sequential(self.layer1) + if is_feature_only: features.append(out) out = out.sequential(self.layer2) + if is_feature_only: features.append(out) out = out.sequential(self.layer3) + if is_feature_only: features.append(out) out = out.sequential(self.layer4) - out = out.mean([2,3]) - out = self.fc(out).log_softmax() - return out + if is_feature_only: features.append(out) + if not is_feature_only: + out = out.mean([2,3]) + out = self.fc(out).log_softmax() + return out + return features def __call__(self, x): return self.forward(x) @@ -140,4 +151,4 @@ ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes) ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes) ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes) ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes) -ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4) +ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4) \ No newline at end of file diff --git a/test/test_ops.py b/test/test_ops.py index 2d21180b..db0952d1 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -243,6 +243,9 @@ class TestOps(unittest.TestCase): def test_log(self): helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log) helper_test_op([()], lambda x: torch.log(x), Tensor.log) + def test_log2(self): + helper_test_op([(45,65)], lambda x: torch.log2(x), Tensor.log2) + helper_test_op([()], lambda x: torch.log2(x), Tensor.log2) def test_exp(self): helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp) helper_test_op([()], lambda x: torch.exp(x), Tensor.exp) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 31da9886..d3f29b2b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -7,7 +7,7 @@ import operator import numpy as np from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes -from math import ceil, pi, prod, sqrt +from math import ceil, pi, prod, sqrt, log from tinygrad.lazy import Device, LazyBuffer from tinygrad.ops import LoadOps @@ -481,6 +481,7 @@ class Tensor: def contiguous(self): return mlops.Contiguous.apply(self) def log(self): return mlops.Log.apply(self) + def log2(self): return mlops.Log.apply(self)/log(2) def exp(self): return mlops.Exp.apply(self) def relu(self): return mlops.Relu.apply(self) def sin(self): return mlops.Sin.apply(self)
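
For reference, the __main__ block in models/mask_rcnn.py can be driven as a tiny smoke test. This is a sketch only: it assumes the repository root is on PYTHONPATH and that the weight download in load_from_pretrained succeeds.

# Construction and weight loading, mirroring the __main__ block above;
# load_from_pretrained returns the list of state-dict keys it assigned.
from models.resnet import ResNet
from models.mask_rcnn import MaskRCNN

resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
model = MaskRCNN(backbone=resnet)
loaded_keys = model.load_from_pretrained()
print(f"loaded {len(loaded_keys)} tensors")

At inference time, examples/mask_rcnn.py feeds preprocessed images to model(images), which returns one BoxList per image carrying the "scores", "labels" and "mask" fields produced by PostProcessor and MaskPostProcessor above.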
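
The models/resnet.py changes let a single ResNet serve both as a classifier and as the FPN backbone: with num_classes=None the fc head is dropped and forward returns the four stage feature maps instead of log-probabilities. A minimal shape check (sketch only; the 224x224 input size is illustrative and the weights are untrained):

import numpy as np
from tinygrad.tensor import Tensor
from models.resnet import ResNet

x = Tensor(np.random.rand(1, 3, 224, 224).astype(np.float32))

# Classifier mode: fc head present, forward returns log-softmax scores
clf = ResNet(50, num_classes=1000)
print(clf(x).shape)   # (1, 1000)

# Backbone mode: num_classes=None drops fc, forward returns the stage outputs
backbone = ResNet(50, num_classes=None, stride_in_1x1=True)
for f in backbone(x):
  print(f.shape)      # (1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)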
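
The new Tensor.log2 relies on the identity log2(x) = ln(x) / ln(2), which is exactly what the tensor.py hunk computes and what test_log2 checks against torch.log2. A quick numpy cross-check on illustrative values:

import numpy as np
from tinygrad.tensor import Tensor

x = np.array([[0.5, 1.0, 2.0, 8.0]], dtype=np.float32)
out = Tensor(x).log2().numpy()
np.testing.assert_allclose(out, np.log2(x), rtol=1e-5)  # ln(x)/ln(2) matches log2(x)
print(out)  # [[-1.  0.  1.  3.]]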