mirror of https://github.com/commaai/tinygrad.git
MaskRCNN Inference (#884)
* MaskRCNN weights loading
* backbone maybe works
* backbone works, but resnet body atol 1e-3
* RPN call, but very wrong output
* fixed topk
* RPN maybe works, not sure about nms
* Fix cursed modules
* add back editorconfig
* Full call, wrong output
* Full call works
* fix mask
* use NMS from retinanet
* Removing extra funcs
* refactor
* readable
* Add example to run model
* remove filter
* Fix split, batched inference is worse
* Fix image sizes
* Matching reference
* merge master
* add filter on top detections
* cuda backend fixed
* add model eval and spec
* convert images to rgb
* fix eval
* simplify examples code
* remove extra code
* meshgrid using tinygrad
* removing numpy
* roi align, floor, ceil
* remove numpy from level_mapper
* remove numpy from pooler
* Revert "Merge branch 'master' of github.com:kunwar31/tinygrad into mrcnn-inference"
  This reverts commit 4b95a3cb499393bb68b95500cd736d50a93d3ce4, reversing changes made to 98f2b1fa2ede20113b1b369ac00d4b2a7ca5fbfa.
* roi align gather
* fix master merge
* revert to old floor, ceil as ints present in domain
* use log2 op
* fix indexes
* weird bug with ints and gpu
* weird bug with ints and gpu
* refactors, add env var for gather
* floor with contiguous, where
* refactor topk, sort
* remove staticmethod
* refactor stride
* remove log2 mlop
* realize -> contiguous
* refactor forward
* remove num_classes, stride_in_1x1 from state
* refactor forward
* refactoring
* flake8
* removing numpy in anchor gen, use numpy for gather, nonzero, optimize topk
* keep using tinygrad for smaller gathers
* fix empty tensors
* comments
* move from tensor.py
* resnet test passing
* add coco dataset back
* fix spaces
* add test for log2
* no need to create Tensors
* no need to create Tensors

---------

Co-authored-by: Kunwar Raj Singh <kunwar31@pop-os.localdomain>
This commit is contained in:
parent 0f281e7b18
commit 5d3310ce56
datasets/coco.py (new file):
@@ -0,0 +1,200 @@
import json
import pathlib
import zipfile
import numpy as np
from extra.utils import download_file
import pycocotools._mask as _mask
from examples.mask_rcnn import Masker
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

iou = _mask.iou
merge = _mask.merge
frPyObjects = _mask.frPyObjects

BASEDIR = pathlib.Path(__file__).parent.parent / "datasets/COCO"

def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for row in rows}


if not pathlib.Path(BASEDIR/'val2017').is_dir():
  fn = BASEDIR/'val2017.zip'
  download_file('http://images.cocodataset.org/zips/val2017.zip', fn)
  with zipfile.ZipFile(fn, 'r') as zip_ref:
    zip_ref.extractall(BASEDIR)
  fn.unlink()

if not pathlib.Path(BASEDIR/'annotations').is_dir():
  fn = BASEDIR/'annotations_trainval2017.zip'
  download_file('http://images.cocodataset.org/annotations/annotations_trainval2017.zip', fn)
  with zipfile.ZipFile(fn, 'r') as zip_ref:
    zip_ref.extractall(BASEDIR)
  fn.unlink()

with open(BASEDIR/'annotations/instances_val2017.json', 'r') as f:
  annotations_raw = json.loads(f.read())
images = annotations_raw['images']
categories = annotations_raw['categories']
annotations = annotations_raw['annotations']
file_name_to_id = create_dict('file_name', 'id', images)
id_to_width = create_dict('id', 'width', images)
id_to_height = create_dict('id', 'height', images)
json_category_id_to_contiguous_id = {v['id']: i + 1 for i, v in enumerate(categories)}
contiguous_category_id_to_json_id = {v:k for k,v in json_category_id_to_contiguous_id.items()}


# thin wrappers over pycocotools' C mask ops that accept either a single RLE or a list of RLEs
def encode(bimask):
  if len(bimask.shape) == 3:
    return _mask.encode(bimask)
  elif len(bimask.shape) == 2:
    h, w = bimask.shape
    return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0]

def decode(rleObjs):
  if type(rleObjs) == list:
    return _mask.decode(rleObjs)
  else:
    return _mask.decode([rleObjs])[:,:,0]

def area(rleObjs):
  if type(rleObjs) == list:
    return _mask.area(rleObjs)
  else:
    return _mask.area([rleObjs])[0]

def toBbox(rleObjs):
  if type(rleObjs) == list:
    return _mask.toBbox(rleObjs)
  else:
    return _mask.toBbox([rleObjs])[0]


def convert_prediction_to_coco_bbox(file_name, prediction):
  coco_results = []
  try:
    original_id = file_name_to_id[file_name]
    if len(prediction) == 0:
      return coco_results

    image_width = id_to_width[original_id]
    image_height = id_to_height[original_id]
    prediction = prediction.resize((image_width, image_height))
    prediction = prediction.convert("xywh")

    boxes = prediction.bbox.numpy().tolist()
    scores = prediction.get_field("scores").numpy().tolist()
    labels = prediction.get_field("labels").numpy().tolist()

    mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]

    coco_results.extend(
      [
        {
          "image_id": original_id,
          "category_id": mapped_labels[k],
          "bbox": box,
          "score": scores[k],
        }
        for k, box in enumerate(boxes)
      ]
    )
  except Exception as e:
    print(file_name, e)
  return coco_results

masker = Masker(threshold=0.5, padding=1)

def convert_prediction_to_coco_mask(file_name, prediction):
  coco_results = []
  try:
    original_id = file_name_to_id[file_name]
    if len(prediction) == 0:
      return coco_results

    image_width = id_to_width[original_id]
    image_height = id_to_height[original_id]
    prediction = prediction.resize((image_width, image_height))
    masks = prediction.get_field("mask")

    scores = prediction.get_field("scores").numpy().tolist()
    labels = prediction.get_field("labels").numpy().tolist()

    masks = masker([masks], [prediction])[0].numpy()

    rles = [
      encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
      for mask in masks
    ]
    for rle in rles:
      rle["counts"] = rle["counts"].decode("utf-8")

    mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]

    coco_results.extend(
      [
        {
          "image_id": original_id,
          "category_id": mapped_labels[k],
          "segmentation": rle,
          "score": scores[k],
        }
        for k, rle in enumerate(rles)
      ]
    )
  except Exception as e:
    print(file_name, e)
  return coco_results


# predictions are appended as JSON lines so results accumulate across batches
def accumulate_predictions_for_coco(coco_results, json_result_file, rm=False):
  path = pathlib.Path(json_result_file)
  if rm and path.exists(): path.unlink()
  with open(path, "a") as f:
    for s in coco_results:
      f.write(json.dumps(s))
      f.write('\n')

def remove_dup(l):
  seen = set()
  seen_add = seen.add
  return [x for x in l if not (x in seen or seen_add(x))]

class NpEncoder(json.JSONEncoder):
  def default(self, obj):
    if isinstance(obj, np.integer):
      return int(obj)
    if isinstance(obj, np.floating):
      return float(obj)
    if isinstance(obj, np.ndarray):
      return obj.tolist()
    return super(NpEncoder, self).default(obj)


def evaluate_predictions_on_coco(json_result_file, iou_type="bbox"):
  coco_results = []
  with open(json_result_file, "r") as f:
    for line in f:
      coco_results.append(json.loads(line))

  coco_gt = COCO(str(BASEDIR/'annotations/instances_val2017.json'))
  set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results])
  unique_list = [json.loads(s) for s in set_of_json]

  with open(f'{json_result_file}.flattened', "w") as f:
    json.dump(unique_list, f)

  coco_dt = coco_gt.loadRes(f'{json_result_file}.flattened')
  coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
  coco_eval.evaluate()
  coco_eval.accumulate()
  coco_eval.summarize()
  return coco_eval

def iterate(files, bs=1):
  batch = []
  for file in files:
    batch.append(file)
    if len(batch) >= bs: yield batch; batch = []
  if len(batch) > 0: yield batch; batch = []
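A quick sanity check of the RLE wrappers above; this is a minimal sketch, not part of the commit, assuming the wrappers are in scope (note that importing this module downloads COCO val2017 if it is missing):

import numpy as np
from datasets.coco import encode, decode, area

bimask = (np.random.rand(32, 32) > 0.5).astype(np.uint8)  # random binary mask, uint8 as the C ops expect
rle = encode(bimask)                                      # a 2-D input takes the reshape(order='F') path
assert decode(rle).shape == (32, 32)                      # decoding a single RLE returns the H x W mask
assert area(rle) == bimask.sum()                          # RLE area is the number of set pixels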
examples/mask_rcnn.py (new file):
@@ -0,0 +1,299 @@
from models.mask_rcnn import MaskRCNN, BoxList
from models.resnet import ResNet
from torch.nn import functional as F
from torchvision import transforms as T
from torchvision.transforms import functional as Ft
import random
from tinygrad.tensor import Tensor
from PIL import Image
import numpy as np
import torch
import argparse
import cv2


class Resize:
  def __init__(self, min_size, max_size):
    if not isinstance(min_size, (list, tuple)):
      min_size = (min_size,)
    self.min_size = min_size
    self.max_size = max_size

  # modified from torchvision to add support for max size
  def get_size(self, image_size):
    w, h = image_size
    size = random.choice(self.min_size)
    max_size = self.max_size
    if max_size is not None:
      min_original_size = float(min((w, h)))
      max_original_size = float(max((w, h)))
      if max_original_size / min_original_size * size > max_size:
        size = int(round(max_size * min_original_size / max_original_size))

    if (w <= h and w == size) or (h <= w and h == size):
      return (h, w)

    if w < h:
      ow = size
      oh = int(size * h / w)
    else:
      oh = size
      ow = int(size * w / h)

    return (oh, ow)

  def __call__(self, image):
    size = self.get_size(image.size)
    image = Ft.resize(image, size)
    return image


class Normalize:
  def __init__(self, mean, std, to_bgr255=True):
    self.mean = mean
    self.std = std
    self.to_bgr255 = to_bgr255

  def __call__(self, image):
    if self.to_bgr255:
      image = image[[2, 1, 0]] * 255
    else:
      image = image[[0, 1, 2]] * 255
    image = Ft.normalize(image, mean=self.mean, std=self.std)
    return image

# ToTensor yields RGB in [0,1]; the reference weights expect mean-subtracted BGR in [0,255]
transforms = lambda size_scale: T.Compose(
  [
    Resize(int(800*size_scale), int(1333*size_scale)),
    T.ToTensor(),
    Normalize(
      mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True
    ),
  ]
)

def expand_boxes(boxes, scale):
  w_half = (boxes[:, 2] - boxes[:, 0]) * .5
  h_half = (boxes[:, 3] - boxes[:, 1]) * .5
  x_c = (boxes[:, 2] + boxes[:, 0]) * .5
  y_c = (boxes[:, 3] + boxes[:, 1]) * .5

  w_half *= scale
  h_half *= scale

  boxes_exp = torch.zeros_like(boxes)
  boxes_exp[:, 0] = x_c - w_half
  boxes_exp[:, 2] = x_c + w_half
  boxes_exp[:, 1] = y_c - h_half
  boxes_exp[:, 3] = y_c + h_half
  return boxes_exp


def expand_masks(mask, padding):
  N = mask.shape[0]
  M = mask.shape[-1]
  pad2 = 2 * padding
  scale = float(M + pad2) / M
  padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
  padded_mask[:, :, padding:-padding, padding:-padding] = mask
  return padded_mask, scale


def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
  # TODO: remove torch
  mask = torch.tensor(mask.numpy())
  box = torch.tensor(box.numpy())
  padded_mask, scale = expand_masks(mask[None], padding=padding)
  mask = padded_mask[0, 0]
  box = expand_boxes(box[None], scale)[0]
  box = box.to(dtype=torch.int32)

  TO_REMOVE = 1
  w = int(box[2] - box[0] + TO_REMOVE)
  h = int(box[3] - box[1] + TO_REMOVE)
  w = max(w, 1)
  h = max(h, 1)

  mask = mask.expand((1, 1, -1, -1))

  mask = mask.to(torch.float32)
  mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
  mask = mask[0][0]

  if thresh >= 0:
    mask = mask > thresh
  else:
    mask = (mask * 255).to(torch.uint8)

  im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
  x_0 = max(box[0], 0)
  x_1 = min(box[2] + 1, im_w)
  y_0 = max(box[1], 0)
  y_1 = min(box[3] + 1, im_h)

  im_mask[y_0:y_1, x_0:x_1] = mask[
    (y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])
  ]
  return im_mask


class Masker:
  def __init__(self, threshold=0.5, padding=1):
    self.threshold = threshold
    self.padding = padding

  def forward_single_image(self, masks, boxes):
    boxes = boxes.convert("xyxy")
    im_w, im_h = boxes.size
    res = [
      paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
      for mask, box in zip(masks, boxes.bbox)
    ]
    if len(res) > 0:
      res = torch.stack(res, dim=0)[:, None]
    else:
      res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))
    return Tensor(res.numpy())

  def __call__(self, masks, boxes):
    if isinstance(boxes, BoxList):
      boxes = [boxes]

    results = []
    for mask, box in zip(masks, boxes):
      result = self.forward_single_image(mask, box)
      results.append(result)
    return results


masker = Masker(threshold=0.5, padding=1)

def select_top_predictions(predictions, confidence_threshold=0.9):
  scores = predictions.get_field("scores").numpy()
  keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
  return predictions[keep]

def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
  image = transforms(size_scale)(original_image).numpy()
  image = Tensor(image, requires_grad=False)
  predictions = model(image)
  prediction = predictions[0]
  prediction = select_top_predictions(prediction, confidence_threshold)
  width, height = original_image.size
  prediction = prediction.resize((width, height))

  if prediction.has_field("mask"):
    masks = prediction.get_field("mask")
    masks = masker([masks], [prediction])[0]
    prediction.add_field("mask", masks)
  return prediction

def compute_prediction_batched(batch, model, size_scale=1.0):
  imgs = []
  for img in batch:
    imgs.append(transforms(size_scale)(img).numpy())
  image = [Tensor(image, requires_grad=False) for image in imgs]
  predictions = model(image)
  del image
  return predictions

# deterministic pseudo-colors: label * palette, taken modulo 255
palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])

def findContours(*args, **kwargs):
  # cv2.findContours returns (contours, hierarchy) in v4 but (image, contours, hierarchy) in v3
  if cv2.__version__.startswith('4'):
    contours, hierarchy = cv2.findContours(*args, **kwargs)
  elif cv2.__version__.startswith('3'):
    _, contours, hierarchy = cv2.findContours(*args, **kwargs)
  else:
    raise AssertionError('cv2 must be version 3 or 4')
  return contours, hierarchy

def compute_colors_for_labels(labels):
  l = labels[:, None]
  colors = l * palette
  colors = (colors % 255).astype("uint8")
  return colors

def overlay_mask(image, predictions):
  image = np.asarray(image)
  masks = predictions.get_field("mask").numpy()
  labels = predictions.get_field("labels").numpy()

  colors = compute_colors_for_labels(labels).tolist()

  for mask, color in zip(masks, colors):
    thresh = mask[0, :, :, None]
    contours, hierarchy = findContours(
      thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )
    image = cv2.drawContours(image, contours, -1, color, 3)

  return image

CATEGORIES = [
  "__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
  "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
  "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
  "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
  "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
  "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
  "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]

def overlay_boxes(image, predictions):
  labels = predictions.get_field("labels").numpy()
  boxes = predictions.bbox
  image = np.asarray(image)
  colors = compute_colors_for_labels(labels).tolist()

  for box, color in zip(boxes, colors):
    box = torch.tensor(box.numpy())
    box = box.to(torch.int64)
    top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
    image = cv2.rectangle(
      image, tuple(top_left), tuple(bottom_right), tuple(color), 1
    )

  return image

def overlay_class_names(image, predictions):
  scores = predictions.get_field("scores").numpy().tolist()
  labels = predictions.get_field("labels").numpy().tolist()
  labels = [CATEGORIES[int(i)] for i in labels]
  boxes = predictions.bbox.numpy()
  image = np.asarray(image)
  template = "{}: {:.2f}"
  for box, score, label in zip(boxes, scores, labels):
    x, y = box[:2]
    s = template.format(label, score)
    x, y = int(x), int(y)
    cv2.putText(
      image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
    )

  return image


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('--image', type=str, required=True, help="Path of the image to run")
  parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
  parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
  parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
  args = parser.parse_args()

  resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
  model_tiny = MaskRCNN(resnet)
  model_tiny.load_from_pretrained()
  img = Image.open(args.image)
  top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)
  bbox_image = overlay_boxes(img, top_result_tiny)
  mask_image = overlay_mask(bbox_image, top_result_tiny)
  final_image = overlay_class_names(mask_image, top_result_tiny)

  im = Image.fromarray(final_image)
  print(f"saving {args.out}")
  im.save(args.out)
  im.show()
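For reference, a typical invocation of this example from the repo root, using the flags defined above (the image path is a placeholder):

python examples/mask_rcnn.py --image /path/to/image.jpg --threshold 0.7 --out /tmp/rendered.png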
@@ -184,14 +184,51 @@ def eval_bert():
   st = time.perf_counter()
 
+def eval_mrcnn():
+  from tqdm import tqdm
+  from models.mask_rcnn import MaskRCNN
+  from models.resnet import ResNet
+  from datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate
+  from examples.mask_rcnn import compute_prediction_batched, Image
+  mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
+  mdl.load_from_pretrained()
+
+  bbox_output = '/tmp/results_bbox.json'
+  mask_output = '/tmp/results_mask.json'
+
+  accumulate_predictions_for_coco([], bbox_output, rm=True)
+  accumulate_predictions_for_coco([], mask_output, rm=True)
+
+  # TODO: bs > 1 not as accurate
+  bs = 1
+
+  for batch in tqdm(iterate(images, bs=bs), total=len(images)//bs):
+    batch_imgs = []
+    for image_row in batch:
+      image_name = image_row['file_name']
+      img = Image.open(BASEDIR/f'val2017/{image_name}').convert("RGB")
+      batch_imgs.append(img)
+    batch_result = compute_prediction_batched(batch_imgs, mdl)
+    for image_row, result in zip(batch, batch_result):
+      image_name = image_row['file_name']
+      box_pred = convert_prediction_to_coco_bbox(image_name, result)
+      mask_pred = convert_prediction_to_coco_mask(image_name, result)
+      accumulate_predictions_for_coco(box_pred, bbox_output)
+      accumulate_predictions_for_coco(mask_pred, mask_output)
+    del batch_imgs
+    del batch_result
+
+  evaluate_predictions_on_coco(bbox_output, iou_type='bbox')
+  evaluate_predictions_on_coco(mask_output, iou_type='segm')
+
 if __name__ == "__main__":
   # inference only
   Tensor.training = False
   Tensor.no_grad = True
 
-  models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert").split(",")
+  models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",")
   for m in models:
     nm = f"eval_{m}"
     if nm in globals():
       print(f"eval {m}")
       globals()[nm]()
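The eval_* functions are dispatched from the MODEL environment variable via getenv, so the new path can be exercised with something like the following; the script path is an assumption, since the diff viewer omits filenames:

MODEL=mrcnn python3 examples/mlperf/model_eval.py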
@@ -5,7 +5,8 @@ import numpy as np
 
 def test_model(model, *inputs):
   GlobalCounters.reset()
-  model(*inputs).numpy()
+  out = model(*inputs)
+  if isinstance(out, Tensor): out = out.numpy()
   # TODO: return event future to still get the time_sum_s without DEBUG=2
   print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")
 
@@ -49,15 +50,21 @@ def spec_bert():
   tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
   test_model(mdl, x, am, tt)
 
+def spec_mrcnn():
+  from models.mask_rcnn import MaskRCNN, ResNet
+  mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
+  mdl.load_from_pretrained()
+  x = Tensor.randn(3, 224, 224)
+  test_model(mdl, [x])
+
 if __name__ == "__main__":
   # inference only for now
   Tensor.training = False
   Tensor.no_grad = True
 
-  for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert").split(","):
+  for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","):
     nm = f"spec_{m}"
     if nm in globals():
       print(f"testing {m}")
       globals()[nm]()
File diff suppressed because it is too large
models/resnet.py:
@@ -27,14 +27,15 @@ class BasicBlock:
 
 
 class Bottleneck:
-  # NOTE: the original implementation places stride at the first convolution (self.conv1), this is the v1.5 variant
+  # NOTE: stride_in_1x1=False, this is the v1.5 variant
   expansion = 4
 
-  def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
+  def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64):
     width = int(planes * (base_width / 64.0)) * groups
-    self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, bias=False)
+    # NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1
+    self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False)
     self.bn1 = nn.BatchNorm2d(width)
-    self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=stride, groups=groups, bias=False)
+    self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False)
     self.bn2 = nn.BatchNorm2d(width)
     self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
     self.bn3 = nn.BatchNorm2d(self.expansion*planes)
 
@@ -54,9 +55,8 @@ class Bottleneck:
     return out
 
 class ResNet:
-  def __init__(self, num, num_classes, groups=1, width_per_group=64):
+  def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False):
     self.num = num
 
     self.block = {
       18: BasicBlock,
       34: BasicBlock,
 
@@ -79,30 +79,41 @@ class ResNet:
     self.base_width = width_per_group
     self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
     self.bn1 = nn.BatchNorm2d(64)
-    self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1)
-    self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2)
-    self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2)
-    self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2)
-    self.fc = nn.Linear(512 * self.block.expansion, num_classes)
+    self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1)
+    self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1)
+    self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1)
+    self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1)
+    self.fc = nn.Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None
 
-  def _make_layer(self, block, planes, num_blocks, stride):
+  def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
     strides = [stride] + [1] * (num_blocks-1)
     layers = []
     for stride in strides:
-      layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
+      if block == Bottleneck:
+        layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width))
+      else:
+        layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
       self.in_planes = planes * block.expansion
     return layers
 
   def forward(self, x):
+    is_feature_only = self.fc is None
+    if is_feature_only: features = []
     out = self.bn1(self.conv1(x)).relu()
     out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2)
     out = out.sequential(self.layer1)
+    if is_feature_only: features.append(out)
     out = out.sequential(self.layer2)
+    if is_feature_only: features.append(out)
     out = out.sequential(self.layer3)
+    if is_feature_only: features.append(out)
     out = out.sequential(self.layer4)
-    out = out.mean([2,3])
-    out = self.fc(out).log_softmax()
-    return out
+    if is_feature_only: features.append(out)
+    if not is_feature_only:
+      out = out.mean([2,3])
+      out = self.fc(out).log_softmax()
+      return out
+    return features
 
   def __call__(self, x):
     return self.forward(x)
 
@@ -140,4 +151,4 @@ ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
 ResNet50 = lambda num_classes=1000: ResNet(50, num_classes=num_classes)
 ResNet101 = lambda num_classes=1000: ResNet(101, num_classes=num_classes)
 ResNet152 = lambda num_classes=1000: ResNet(152, num_classes=num_classes)
-ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
+ResNeXt50_32X4D = lambda num_classes=1000: ResNet(50, num_classes=num_classes, groups=32, width_per_group=4)
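A minimal sketch of the backbone's new feature-extractor mode; with num_classes=None the fc head is disabled and forward() returns the four stage outputs. The shape comment assumes a 224x224 input and the standard ResNet-50 stage strides (4, 8, 16, 32); untrained weights, illustration only:

from models.resnet import ResNet
from tinygrad.tensor import Tensor

backbone = ResNet(50, num_classes=None, stride_in_1x1=True)  # no classifier head
features = backbone(Tensor.randn(1, 3, 224, 224))            # list of four feature maps
# expected shapes: (1,256,56,56), (1,512,28,28), (1,1024,14,14), (1,2048,7,7)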
@@ -243,6 +243,9 @@ class TestOps(unittest.TestCase):
   def test_log(self):
     helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
     helper_test_op([()], lambda x: torch.log(x), Tensor.log)
+  def test_log2(self):
+    helper_test_op([(45,65)], lambda x: torch.log2(x), Tensor.log2)
+    helper_test_op([()], lambda x: torch.log2(x), Tensor.log2)
   def test_exp(self):
     helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp)
     helper_test_op([()], lambda x: torch.exp(x), Tensor.exp)
tinygrad/tensor.py:
@@ -7,7 +7,7 @@ import operator
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
 from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
-from math import ceil, pi, prod, sqrt
+from math import ceil, pi, prod, sqrt, log
 from tinygrad.lazy import Device, LazyBuffer
 from tinygrad.ops import LoadOps
 
@@ -481,6 +481,7 @@ class Tensor:
 
   def contiguous(self): return mlops.Contiguous.apply(self)
   def log(self): return mlops.Log.apply(self)
+  def log2(self): return mlops.Log.apply(self)/log(2)
   def exp(self): return mlops.Exp.apply(self)
   def relu(self): return mlops.Relu.apply(self)
   def sin(self): return mlops.Sin.apply(self)
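The log2 one-liner leans on the change-of-base identity rather than a dedicated mlop (the commit message's "remove log2 mlop"): log2(x) = ln(x) / ln(2), e.g. log2(8) = 2.0794 / 0.6931 ≈ 3.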