# tinygrad/examples/mask_rcnn.py
from models.mask_rcnn import MaskRCNN
from models.resnet import ResNet
from models.mask_rcnn import BoxList
from torch.nn import functional as F
from torchvision import transforms as T
from torchvision.transforms import functional as Ft
import random
from tinygrad.tensor import Tensor
from PIL import Image
import numpy as np
import torch
import argparse
import cv2
class Resize:
  """Resize a PIL image so its shorter side matches a chosen `min_size`,
  capping the longer side at `max_size` (shrinking the target if needed)."""

  def __init__(self, min_size, max_size):
    # accept either a scalar or a collection of candidate shorter-side sizes
    self.min_size = min_size if isinstance(min_size, (list, tuple)) else (min_size,)
    self.max_size = max_size

  # modified from torchvision to add support for max size
  def get_size(self, image_size):
    """Return the (height, width) the image should be resized to."""
    w, h = image_size
    size = random.choice(self.min_size)
    if self.max_size is not None:
      short_side = float(min(w, h))
      long_side = float(max(w, h))
      # shrink the target if scaling the short side would push the long
      # side beyond max_size
      if long_side / short_side * size > self.max_size:
        size = int(round(self.max_size * short_side / long_side))
    # already at the target size: keep dimensions as-is
    if (w <= h and w == size) or (h <= w and h == size):
      return (h, w)
    if w < h:
      return (int(size * h / w), size)
    return (size, int(size * w / h))

  def __call__(self, image):
    return Ft.resize(image, self.get_size(image.size))
class Normalize:
  """Reorder channels (optionally RGB->BGR), scale to 0-255, then apply
  per-channel mean/std normalization."""

  def __init__(self, mean, std, to_bgr255=True):
    self.mean = mean
    self.std = std
    self.to_bgr255 = to_bgr255

  def __call__(self, image):
    channel_order = [2, 1, 0] if self.to_bgr255 else [0, 1, 2]
    image = image[channel_order] * 255
    return Ft.normalize(image, mean=self.mean, std=self.std)
def transforms(size_scale):
  """Build the inference preprocessing pipeline, scaled by `size_scale`.

  Resizes (short side 800*size_scale, long side capped at 1333*size_scale),
  converts to tensor, then normalizes with the BGR-255 ImageNet means the
  pretrained Mask R-CNN weights expect.
  """
  # was a lambda assigned to a name (PEP 8 E731); a def is debuggable and
  # picklable, and behaves identically
  return T.Compose([
    Resize(int(800 * size_scale), int(1333 * size_scale)),
    T.ToTensor(),
    Normalize(
      mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True
    ),
  ])
def expand_boxes(boxes, scale):
  """Scale each xyxy box about its own center by `scale`.

  boxes: (N, 4) tensor of [x0, y0, x1, y1]; returns a new tensor of the
  same shape, leaving the input untouched.
  """
  x_center = (boxes[:, 2] + boxes[:, 0]) * .5
  y_center = (boxes[:, 3] + boxes[:, 1]) * .5
  half_w = (boxes[:, 2] - boxes[:, 0]) * .5 * scale
  half_h = (boxes[:, 3] - boxes[:, 1]) * .5 * scale

  expanded = torch.zeros_like(boxes)
  expanded[:, 0] = x_center - half_w
  expanded[:, 1] = y_center - half_h
  expanded[:, 2] = x_center + half_w
  expanded[:, 3] = y_center + half_h
  return expanded
def expand_masks(mask, padding):
  """Zero-pad each square mask by `padding` pixels on every side.

  mask: tensor whose first dim is N and last dim is the mask side M.
  Returns (padded_mask of shape (N, 1, M+2p, M+2p), scale) where scale is
  the size ratio used later to expand the boxes by the same factor.
  """
  N = mask.shape[0]
  M = mask.shape[-1]
  pad2 = 2 * padding
  scale = float(M + pad2) / M
  padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
  # use an explicit end index: the old `padding:-padding` slice is empty
  # when padding == 0, which silently dropped the mask content
  padded_mask[:, :, padding:M + padding, padding:M + padding] = mask
  return padded_mask, scale
def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
  """Resample one fixed-resolution detection mask into a full-image mask.

  mask: tinygrad tensor holding a square soft mask for one detection
        (assumed (M, M) — TODO confirm against Masker.forward_single_image).
  box: tinygrad tensor, xyxy box for the detection in image coordinates.
  im_h, im_w: size of the output canvas.
  thresh: binarization threshold; if negative, the soft mask is kept and
          scaled to uint8 0-255 instead of thresholded.
  padding: zero-padding added around the mask before resizing (the box is
           expanded by the same ratio so the content stays aligned).
  Returns a torch uint8 tensor of shape (im_h, im_w).
  """
  # TODO: remove torch
  mask = torch.tensor(mask.numpy())
  box = torch.tensor(box.numpy())
  padded_mask, scale = expand_masks(mask[None], padding=padding)
  mask = padded_mask[0, 0]
  # grow the box by the same factor the padding grew the mask
  box = expand_boxes(box[None], scale)[0]
  box = box.to(dtype=torch.int32)
  TO_REMOVE = 1  # boxes are inclusive on both ends
  w = int(box[2] - box[0] + TO_REMOVE)
  h = int(box[3] - box[1] + TO_REMOVE)
  # guard against degenerate (zero-area) boxes
  w = max(w, 1)
  h = max(h, 1)
  # resize the mask to the box size with bilinear interpolation
  mask = mask.expand((1, 1, -1, -1))
  mask = mask.to(torch.float32)
  mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
  mask = mask[0][0]
  if thresh >= 0:
    mask = mask > thresh
  else:
    # negative threshold: keep soft values, rescaled to 0-255
    mask = (mask * 255).to(torch.uint8)
  # paste the box region into a blank image-sized canvas, clipping the box
  # to the image bounds and offsetting into the mask accordingly
  im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
  x_0 = max(box[0], 0)
  x_1 = min(box[2] + 1, im_w)
  y_0 = max(box[1], 0)
  y_1 = min(box[3] + 1, im_h)
  im_mask[y_0:y_1, x_0:x_1] = mask[
    (y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])
  ]
  return im_mask
class Masker:
  """Pastes per-detection fixed-size masks back onto full-image canvases."""

  def __init__(self, threshold=0.5, padding=1):
    self.threshold = threshold
    self.padding = padding

  def forward_single_image(self, masks, boxes):
    """Paste every mask of one image into its box; returns a tinygrad Tensor."""
    boxes = boxes.convert("xyxy")
    im_w, im_h = boxes.size
    pasted = []
    for mask, box in zip(masks, boxes.bbox):
      pasted.append(
        paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
      )
    if pasted:
      stacked = torch.stack(pasted, dim=0)[:, None]
    else:
      # no detections: empty tensor with the expected trailing dimensions
      stacked = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))
    return Tensor(stacked.numpy())

  def __call__(self, masks, boxes):
    # allow a bare BoxList in place of a one-element list
    if isinstance(boxes, BoxList):
      boxes = [boxes]
    return [self.forward_single_image(m, b) for m, b in zip(masks, boxes)]
# module-level Masker instance used by compute_prediction below
masker = Masker(threshold=0.5, padding=1)
def select_top_predictions(predictions, confidence_threshold=0.9):
  """Filter a BoxList, keeping only detections scoring above the threshold."""
  scores = predictions.get_field("scores").numpy()
  keep = [i for i, s in enumerate(scores) if s > confidence_threshold]
  return predictions[keep]
def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
  """Run the model on one PIL image and return its high-confidence
  predictions, resized back to the original resolution, with full-image
  masks pasted in when the model produced a "mask" field."""
  preprocessed = transforms(size_scale)(original_image).numpy()
  tensor_img = Tensor(preprocessed, requires_grad=False)
  prediction = model(tensor_img)[0]
  prediction = select_top_predictions(prediction, confidence_threshold)

  # map boxes back from the network input size to the original image size
  width, height = original_image.size
  prediction = prediction.resize((width, height))

  if prediction.has_field("mask"):
    pasted = masker([prediction.get_field("mask")], [prediction])[0]
    prediction.add_field("mask", pasted)
  return prediction
def compute_prediction_batched(batch, model, size_scale=1.0):
  """Preprocess a batch of PIL images and run the model once over all of them."""
  pipeline = transforms(size_scale)
  tensors = [Tensor(pipeline(img).numpy(), requires_grad=False) for img in batch]
  predictions = model(tensors)
  # free the (potentially large) input tensors before returning
  del tensors
  return predictions
# color seed: multiplied by a label id and taken mod 255 to get a
# deterministic per-class BGR color (see compute_colors_for_labels)
palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
def findContours(*args, **kwargs):
  """Version-agnostic wrapper around cv2.findContours.

  OpenCV 3.x returns (image, contours, hierarchy); 2.x and 4.x+ return
  (contours, hierarchy). Always returns (contours, hierarchy).
  """
  if cv2.__version__.startswith('3'):
    _, contours, hierarchy = cv2.findContours(*args, **kwargs)
  else:
    # 2.x, 4.x and later: two-value return. The old code only handled
    # '3' and '4' and raised UnboundLocalError on any other version.
    contours, hierarchy = cv2.findContours(*args, **kwargs)
  return contours, hierarchy
def compute_colors_for_labels(labels):
  """Map integer class labels to deterministic per-class colors (uint8)."""
  colors = (labels[:, None] * palette) % 255
  return colors.astype("uint8")
def overlay_mask(image, predictions):
  """Draw the contour of every predicted mask onto the image.

  Returns the annotated numpy image. With zero detections the input image
  is returned unchanged — the old code only assigned the result inside the
  loop and raised UnboundLocalError when `masks` was empty.
  """
  image = np.asarray(image)
  masks = predictions.get_field("mask").numpy()
  labels = predictions.get_field("labels").numpy()
  colors = compute_colors_for_labels(labels).tolist()
  for mask, color in zip(masks, colors):
    thresh = mask[0, :, :, None]
    contours, hierarchy = findContours(
      thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )
    image = cv2.drawContours(image, contours, -1, color, 3)
  return image
# COCO class names indexed by the model's integer label ids;
# index 0 is the background class
CATEGORIES = [
  "__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
  "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
  "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
  "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
  "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
  "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
  "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
  "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
def overlay_boxes(image, predictions):
  """Draw a 1-pixel rectangle per detection, colored by its class label."""
  labels = predictions.get_field("labels").numpy()
  image = np.asarray(image)
  colors = compute_colors_for_labels(labels).tolist()
  for box, color in zip(predictions.bbox, colors):
    # round box coordinates down to integer pixel positions
    coords = torch.tensor(box.numpy()).to(torch.int64)
    top_left = tuple(coords[:2].tolist())
    bottom_right = tuple(coords[2:].tolist())
    image = cv2.rectangle(image, top_left, bottom_right, tuple(color), 1)
  return image
def overlay_class_names(image, predictions):
  """Write '<label>: <score>' at the top-left corner of every box."""
  scores = predictions.get_field("scores").numpy().tolist()
  label_ids = predictions.get_field("labels").numpy().tolist()
  names = [CATEGORIES[int(i)] for i in label_ids]
  boxes = predictions.bbox.numpy()
  image = np.asarray(image)
  template = "{}: {:.2f}"
  for box, score, name in zip(boxes, scores, names):
    origin = (int(box[0]), int(box[1]))
    caption = template.format(name, score)
    cv2.putText(
      image, caption, origin, cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
    )
  return image
if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  # required=True: previously a missing --image crashed much later inside
  # Image.open(None) with an opaque error instead of a clean usage message
  parser.add_argument('--image', type=str, required=True, help="Path of the image to run")
  parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
  parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
  parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
  args = parser.parse_args()

  # build the ResNet-50 backbone + Mask R-CNN head and load pretrained weights
  resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
  model_tiny = MaskRCNN(resnet)
  model_tiny.load_from_pretrained()

  img = Image.open(args.image)
  top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)

  # layer annotations: boxes first, then mask contours, then class captions
  bbox_image = overlay_boxes(img, top_result_tiny)
  mask_image = overlay_mask(bbox_image, top_result_tiny)
  final_image = overlay_class_names(mask_image, top_result_tiny)

  im = Image.fromarray(final_image)
  print(f"saving {args.out}")
  im.save(args.out)
  im.show()