MaskRCNN Inference (#884)

* MaskRCNN weights loading

* backbone maybe works

* backbone works, but resnet body atol 1e-3

* RPN Call, but veryy wrong output

* fixed topk

* RPN maybe works, not sure about nms

* Fix cursed modules

* add back editorconfig

* Full call, wrong output

* Full call works

* fix mask

* use NMS from retinanet

* Removing extra funcs

* refactor

* readable

* Add example to run model

* remove filter

* Fix split, batched inference is worse

* Fix image sizes

* Matching reference

* merge master

* add filter on top detections

* cuda backend fixed

* add model eval and spec

* convert images to rgb

* fix eval

* simplify examples code

* remove extra code

* meshgrid using tinygrad

* removing numpy

* roi align, floor, ceil

* remove numpy from level_mapper

* remove numpy from pooler

* Revert "Merge branch 'master' of github.com:kunwar31/tinygrad into mrcnn-inference"

This reverts commit 4b95a3cb499393bb68b95500cd736d50a93d3ce4, reversing
changes made to 98f2b1fa2ede20113b1b369ac00d4b2a7ca5fbfa.

* roi align gather

* fix master merge

* revert to old floor, ceil as ints present in domain

* use log2 op

* fix indexes

* weird bug with ints and gpu

* weird bug with ints and gpu

* refactors, add env var for gather

* floor with contiguous, where

* refactor topk, sort

* remove staticmethod

* refactor stride

* remove log2 mlop

* realize -> contiguous

* refactor forward

* remove num_classes, stride_in_1x1 from state

* refactor forward

* refactoring

* flake8

* removing numpy in anchor gen, use numpy for gather, nonzero, optimize topk

* keep using tinygrad for smaller gathers

* fix empty tensors

* comms

* move from tensor.py

* resnet test passing

* add coco dataset back

* fix spaces

* add test for log2

* no need to create Tensors

* no need to create Tensors

---------

Co-authored-by: Kunwar Raj Singh <kunwar31@pop-os.localdomain>
This commit is contained in:
Kunwar Raj Singh 2023-06-26 04:07:51 +05:30 committed by GitHub
parent 0f281e7b18
commit 5d3310ce56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 1854 additions and 23 deletions

200
datasets/coco.py Normal file
View File

@ -0,0 +1,200 @@
import json
import pathlib
import zipfile
import numpy as np
from extra.utils import download_file
import pycocotools._mask as _mask
from examples.mask_rcnn import Masker
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
iou = _mask.iou
merge = _mask.merge
frPyObjects = _mask.frPyObjects
BASEDIR = pathlib.Path(__file__).parent.parent / "datasets/COCO"
def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for row in rows}
if not pathlib.Path(BASEDIR/'val2017').is_dir():
fn = BASEDIR/'val2017.zip'
download_file('http://images.cocodataset.org/zips/val2017.zip',fn)
with zipfile.ZipFile(fn, 'r') as zip_ref:
zip_ref.extractall(BASEDIR)
fn.unlink()
if not pathlib.Path(BASEDIR/'annotations').is_dir():
fn = BASEDIR/'annotations_trainval2017.zip'
download_file('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',fn)
with zipfile.ZipFile(fn, 'r') as zip_ref:
zip_ref.extractall(BASEDIR)
fn.unlink()
with open(BASEDIR/'annotations/instances_val2017.json', 'r') as f:
annotations_raw = json.loads(f.read())
images = annotations_raw['images']
categories = annotations_raw['categories']
annotations = annotations_raw['annotations']
file_name_to_id = create_dict('file_name', 'id', images)
id_to_width = create_dict('id', 'width', images)
id_to_height = create_dict('id', 'height', images)
json_category_id_to_contiguous_id = {v['id']: i + 1 for i, v in enumerate(categories)}
contiguous_category_id_to_json_id = {v:k for k,v in json_category_id_to_contiguous_id.items()}
def encode(bimask):
if len(bimask.shape) == 3:
return _mask.encode(bimask)
elif len(bimask.shape) == 2:
h, w = bimask.shape
return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0]
def decode(rleObjs):
if type(rleObjs) == list:
return _mask.decode(rleObjs)
else:
return _mask.decode([rleObjs])[:,:,0]
def area(rleObjs):
if type(rleObjs) == list:
return _mask.area(rleObjs)
else:
return _mask.area([rleObjs])[0]
def toBbox(rleObjs):
if type(rleObjs) == list:
return _mask.toBbox(rleObjs)
else:
return _mask.toBbox([rleObjs])[0]
def convert_prediction_to_coco_bbox(file_name, prediction):
coco_results = []
try:
original_id = file_name_to_id[file_name]
if len(prediction) == 0:
return coco_results
image_width = id_to_width[original_id]
image_height = id_to_height[original_id]
prediction = prediction.resize((image_width, image_height))
prediction = prediction.convert("xywh")
boxes = prediction.bbox.numpy().tolist()
scores = prediction.get_field("scores").numpy().tolist()
labels = prediction.get_field("labels").numpy().tolist()
mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]
coco_results.extend(
[
{
"image_id": original_id,
"category_id": mapped_labels[k],
"bbox": box,
"score": scores[k],
}
for k, box in enumerate(boxes)
]
)
except Exception as e:
print(file_name, e)
return coco_results
masker = Masker(threshold=0.5, padding=1)
def convert_prediction_to_coco_mask(file_name, prediction):
coco_results = []
try:
original_id = file_name_to_id[file_name]
if len(prediction) == 0:
return coco_results
image_width = id_to_width[original_id]
image_height = id_to_height[original_id]
prediction = prediction.resize((image_width, image_height))
masks = prediction.get_field("mask")
scores = prediction.get_field("scores").numpy().tolist()
labels = prediction.get_field("labels").numpy().tolist()
masks = masker([masks], [prediction])[0].numpy()
rles = [
encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
for mask in masks
]
for rle in rles:
rle["counts"] = rle["counts"].decode("utf-8")
mapped_labels = [contiguous_category_id_to_json_id[int(i)] for i in labels]
coco_results.extend(
[
{
"image_id": original_id,
"category_id": mapped_labels[k],
"segmentation": rle,
"score": scores[k],
}
for k, rle in enumerate(rles)
]
)
except Exception as e:
print(file_name, e)
return coco_results
def accumulate_predictions_for_coco(coco_results, json_result_file, rm=False):
path = pathlib.Path(json_result_file)
if rm and path.exists(): path.unlink()
with open(path, "a") as f:
for s in coco_results:
f.write(json.dumps(s))
f.write('\n')
def remove_dup(l):
seen = set()
seen_add = seen.add
return [x for x in l if not (x in seen or seen_add(x))]
class NpEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
return super(NpEncoder, self).default(obj)
def evaluate_predictions_on_coco(json_result_file, iou_type="bbox"):
coco_results = []
with open(json_result_file, "r") as f:
for line in f:
coco_results.append(json.loads(line))
coco_gt = COCO(str(BASEDIR/'annotations/instances_val2017.json'))
set_of_json = remove_dup([json.dumps(d, cls=NpEncoder) for d in coco_results])
unique_list = [json.loads(s) for s in set_of_json]
with open(f'{json_result_file}.flattend', "w") as f:
json.dump(unique_list, f)
coco_dt = coco_gt.loadRes(str(f'{json_result_file}.flattend'))
coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
return coco_eval
def iterate(files, bs=1):
batch = []
for file in files:
batch.append(file)
if len(batch) >= bs: yield batch; batch = []
if len(batch) > 0: yield batch; batch = []

299
examples/mask_rcnn.py Normal file
View File

@ -0,0 +1,299 @@
from models.mask_rcnn import MaskRCNN
from models.resnet import ResNet
from models.mask_rcnn import BoxList
from torch.nn import functional as F
from torchvision import transforms as T
from torchvision.transforms import functional as Ft
import random
from tinygrad.tensor import Tensor
from PIL import Image
import numpy as np
import torch
import argparse
import cv2
class Resize:
def __init__(self, min_size, max_size):
if not isinstance(min_size, (list, tuple)):
min_size = (min_size,)
self.min_size = min_size
self.max_size = max_size
# modified from torchvision to add support for max size
def get_size(self, image_size):
w, h = image_size
size = random.choice(self.min_size)
max_size = self.max_size
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = int(round(max_size * min_original_size / max_original_size))
if (w <= h and w == size) or (h <= w and h == size):
return (h, w)
if w < h:
ow = size
oh = int(size * h / w)
else:
oh = size
ow = int(size * w / h)
return (oh, ow)
def __call__(self, image):
size = self.get_size(image.size)
image = Ft.resize(image, size)
return image
class Normalize:
def __init__(self, mean, std, to_bgr255=True):
self.mean = mean
self.std = std
self.to_bgr255 = to_bgr255
def __call__(self, image):
if self.to_bgr255:
image = image[[2, 1, 0]] * 255
else:
image = image[[0, 1, 2]] * 255
image = Ft.normalize(image, mean=self.mean, std=self.std)
return image
transforms = lambda size_scale: T.Compose(
[
Resize(int(800*size_scale), int(1333*size_scale)),
T.ToTensor(),
Normalize(
mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.], to_bgr255=True
),
]
)
def expand_boxes(boxes, scale):
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
w_half *= scale
h_half *= scale
boxes_exp = torch.zeros_like(boxes)
boxes_exp[:, 0] = x_c - w_half
boxes_exp[:, 2] = x_c + w_half
boxes_exp[:, 1] = y_c - h_half
boxes_exp[:, 3] = y_c + h_half
return boxes_exp
def expand_masks(mask, padding):
N = mask.shape[0]
M = mask.shape[-1]
pad2 = 2 * padding
scale = float(M + pad2) / M
padded_mask = mask.new_zeros((N, 1, M + pad2, M + pad2))
padded_mask[:, :, padding:-padding, padding:-padding] = mask
return padded_mask, scale
def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
# TODO: remove torch
mask = torch.tensor(mask.numpy())
box = torch.tensor(box.numpy())
padded_mask, scale = expand_masks(mask[None], padding=padding)
mask = padded_mask[0, 0]
box = expand_boxes(box[None], scale)[0]
box = box.to(dtype=torch.int32)
TO_REMOVE = 1
w = int(box[2] - box[0] + TO_REMOVE)
h = int(box[3] - box[1] + TO_REMOVE)
w = max(w, 1)
h = max(h, 1)
mask = mask.expand((1, 1, -1, -1))
mask = mask.to(torch.float32)
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask = mask[0][0]
if thresh >= 0:
mask = mask > thresh
else:
mask = (mask * 255).to(torch.uint8)
im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
x_0 = max(box[0], 0)
x_1 = min(box[2] + 1, im_w)
y_0 = max(box[1], 0)
y_1 = min(box[3] + 1, im_h)
im_mask[y_0:y_1, x_0:x_1] = mask[
(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])
]
return im_mask
class Masker:
def __init__(self, threshold=0.5, padding=1):
self.threshold = threshold
self.padding = padding
def forward_single_image(self, masks, boxes):
boxes = boxes.convert("xyxy")
im_w, im_h = boxes.size
res = [
paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
for mask, box in zip(masks, boxes.bbox)
]
if len(res) > 0:
res = torch.stack(res, dim=0)[:, None]
else:
res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))
return Tensor(res.numpy())
def __call__(self, masks, boxes):
if isinstance(boxes, BoxList):
boxes = [boxes]
results = []
for mask, box in zip(masks, boxes):
result = self.forward_single_image(mask, box)
results.append(result)
return results
masker = Masker(threshold=0.5, padding=1)
def select_top_predictions(predictions, confidence_threshold=0.9):
scores = predictions.get_field("scores").numpy()
keep = [idx for idx, score in enumerate(scores) if score > confidence_threshold]
return predictions[keep]
def compute_prediction(original_image, model, confidence_threshold, size_scale=1.0):
image = transforms(size_scale)(original_image).numpy()
image = Tensor(image, requires_grad=False)
predictions = model(image)
prediction = predictions[0]
prediction = select_top_predictions(prediction, confidence_threshold)
width, height = original_image.size
prediction = prediction.resize((width, height))
if prediction.has_field("mask"):
masks = prediction.get_field("mask")
masks = masker([masks], [prediction])[0]
prediction.add_field("mask", masks)
return prediction
def compute_prediction_batched(batch, model, size_scale=1.0):
imgs = []
for img in batch:
imgs.append(transforms(size_scale)(img).numpy())
image = [Tensor(image, requires_grad=False) for image in imgs]
predictions = model(image)
del image
return predictions
palette = np.array([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
def findContours(*args, **kwargs):
if cv2.__version__.startswith('4'):
contours, hierarchy = cv2.findContours(*args, **kwargs)
elif cv2.__version__.startswith('3'):
_, contours, hierarchy = cv2.findContours(*args, **kwargs)
return contours, hierarchy
def compute_colors_for_labels(labels):
l = labels[:, None]
colors = l * palette
colors = (colors % 255).astype("uint8")
return colors
def overlay_mask(image, predictions):
image = np.asarray(image)
masks = predictions.get_field("mask").numpy()
labels = predictions.get_field("labels").numpy()
colors = compute_colors_for_labels(labels).tolist()
for mask, color in zip(masks, colors):
thresh = mask[0, :, :, None]
contours, hierarchy = findContours(
thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
)
image = cv2.drawContours(image, contours, -1, color, 3)
composite = image
return composite
CATEGORIES = [
"__background", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
"sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
"carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
"toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster",
"sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
def overlay_boxes(image, predictions):
labels = predictions.get_field("labels").numpy()
boxes = predictions.bbox
image = np.asarray(image)
colors = compute_colors_for_labels(labels).tolist()
for box, color in zip(boxes, colors):
box = torch.tensor(box.numpy())
box = box.to(torch.int64)
top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
image = cv2.rectangle(
image, tuple(top_left), tuple(bottom_right), tuple(color), 1
)
return image
def overlay_class_names(image, predictions):
scores = predictions.get_field("scores").numpy().tolist()
labels = predictions.get_field("labels").numpy().tolist()
labels = [CATEGORIES[int(i)] for i in labels]
boxes = predictions.bbox.numpy()
image = np.asarray(image)
template = "{}: {:.2f}"
for box, score, label in zip(boxes, scores, labels):
x, y = box[:2]
s = template.format(label, score)
x, y = int(x), int(y)
cv2.putText(
image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1
)
return image
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Run MaskRCNN', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--image', type=str, help="Path of the image to run")
parser.add_argument('--threshold', type=float, default=0.7, help="Detector threshold")
parser.add_argument('--size_scale', type=float, default=1.0, help="Image resize multiplier")
parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
args = parser.parse_args()
resnet = ResNet(50, num_classes=None, stride_in_1x1=True)
model_tiny = MaskRCNN(resnet)
model_tiny.load_from_pretrained()
img = Image.open(args.image)
top_result_tiny = compute_prediction(img, model_tiny, confidence_threshold=args.threshold, size_scale=args.size_scale)
bbox_image = overlay_boxes(img, top_result_tiny)
mask_image = overlay_mask(bbox_image, top_result_tiny)
final_image = overlay_class_names(mask_image, top_result_tiny)
im = Image.fromarray(final_image)
print(f"saving {args.out}")
im.save(args.out)
im.show()

View File

@ -184,12 +184,49 @@ def eval_bert():
st = time.perf_counter()
def eval_mrcnn():
from tqdm import tqdm
from models.mask_rcnn import MaskRCNN
from models.resnet import ResNet
from datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate
from examples.mask_rcnn import compute_prediction_batched, Image
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
mdl.load_from_pretrained()
bbox_output = '/tmp/results_bbox.json'
mask_output = '/tmp/results_mask.json'
accumulate_predictions_for_coco([], bbox_output, rm=True)
accumulate_predictions_for_coco([], mask_output, rm=True)
#TODO: bs > 1 not as accurate
bs = 1
for batch in tqdm(iterate(images, bs=bs), total=len(images)//bs):
batch_imgs = []
for image_row in batch:
image_name = image_row['file_name']
img = Image.open(BASEDIR/f'val2017/{image_name}').convert("RGB")
batch_imgs.append(img)
batch_result = compute_prediction_batched(batch_imgs, mdl)
for image_row, result in zip(batch, batch_result):
image_name = image_row['file_name']
box_pred = convert_prediction_to_coco_bbox(image_name, result)
mask_pred = convert_prediction_to_coco_mask(image_name, result)
accumulate_predictions_for_coco(box_pred, bbox_output)
accumulate_predictions_for_coco(mask_pred, mask_output)
del batch_imgs
del batch_result
evaluate_predictions_on_coco(bbox_output, iou_type='bbox')
evaluate_predictions_on_coco(mask_output, iou_type='segm')
if __name__ == "__main__":
# inference only
Tensor.training = False
Tensor.no_grad = True
models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert").split(",")
models = getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(",")
for m in models:
nm = f"eval_{m}"
if nm in globals():

View File

@ -5,7 +5,8 @@ import numpy as np
def test_model(model, *inputs):
GlobalCounters.reset()
model(*inputs).numpy()
out = model(*inputs)
if isinstance(out, Tensor): out = out.numpy()
# TODO: return event future to still get the time_sum_s without DEBUG=2
print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")
@ -49,15 +50,21 @@ def spec_bert():
tt = Tensor(np.random.randint(0, 2, (1, 384)).astype(np.float32))
test_model(mdl, x, am, tt)
def spec_mrcnn():
from models.mask_rcnn import MaskRCNN, ResNet
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
mdl.load_from_pretrained()
x = Tensor.randn(3, 224, 224)
test_model(mdl, [x])
if __name__ == "__main__":
# inference only for now
Tensor.training = False
Tensor.no_grad = True
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert").split(","):
for m in getenv("MODEL", "resnet,retinanet,unet3d,rnnt,bert,mrcnn").split(","):
nm = f"spec_{m}"
if nm in globals():
print(f"testing {m}")
globals()[nm]()

1273
models/mask_rcnn.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -27,14 +27,15 @@ class BasicBlock:
class Bottleneck:
# NOTE: the original implementation places stride at the first convolution (self.conv1), this is the v1.5 variant
# NOTE: stride_in_1x1=False, this is the v1.5 variant
expansion = 4
def __init__(self, in_planes, planes, stride=1, groups=1, base_width=64):
def __init__(self, in_planes, planes, stride=1, stride_in_1x1=False, groups=1, base_width=64):
width = int(planes * (base_width / 64.0)) * groups
self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, bias=False)
# NOTE: the original implementation places stride at the first convolution (self.conv1), control with stride_in_1x1
self.conv1 = nn.Conv2d(in_planes, width, kernel_size=1, stride=stride if stride_in_1x1 else 1, bias=False)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=stride, groups=groups, bias=False)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, stride=1 if stride_in_1x1 else stride, groups=groups, bias=False)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(width, self.expansion*planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
@ -54,9 +55,8 @@ class Bottleneck:
return out
class ResNet:
def __init__(self, num, num_classes, groups=1, width_per_group=64):
def __init__(self, num, num_classes=None, groups=1, width_per_group=64, stride_in_1x1=False):
self.num = num
self.block = {
18: BasicBlock,
34: BasicBlock,
@ -79,30 +79,41 @@ class ResNet:
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, bias=False, padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1)
self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2)
self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2)
self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2)
self.fc = nn.Linear(512 * self.block.expansion, num_classes)
self.layer1 = self._make_layer(self.block, 64, self.num_blocks[0], stride=1, stride_in_1x1=stride_in_1x1)
self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2, stride_in_1x1=stride_in_1x1)
self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2, stride_in_1x1=stride_in_1x1)
self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2, stride_in_1x1=stride_in_1x1)
self.fc = nn.Linear(512 * self.block.expansion, num_classes) if num_classes is not None else None
def _make_layer(self, block, planes, num_blocks, stride):
def _make_layer(self, block, planes, num_blocks, stride, stride_in_1x1):
strides = [stride] + [1] * (num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
if block == Bottleneck:
layers.append(block(self.in_planes, planes, stride, stride_in_1x1, self.groups, self.base_width))
else:
layers.append(block(self.in_planes, planes, stride, self.groups, self.base_width))
self.in_planes = planes * block.expansion
return layers
def forward(self, x):
is_feature_only = self.fc is None
if is_feature_only: features = []
out = self.bn1(self.conv1(x)).relu()
out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2)
out = out.sequential(self.layer1)
if is_feature_only: features.append(out)
out = out.sequential(self.layer2)
if is_feature_only: features.append(out)
out = out.sequential(self.layer3)
if is_feature_only: features.append(out)
out = out.sequential(self.layer4)
out = out.mean([2,3])
out = self.fc(out).log_softmax()
return out
if is_feature_only: features.append(out)
if not is_feature_only:
out = out.mean([2,3])
out = self.fc(out).log_softmax()
return out
return features
def __call__(self, x):
return self.forward(x)

View File

@ -243,6 +243,9 @@ class TestOps(unittest.TestCase):
def test_log(self):
helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
helper_test_op([()], lambda x: torch.log(x), Tensor.log)
def test_log2(self):
helper_test_op([(45,65)], lambda x: torch.log2(x), Tensor.log2)
helper_test_op([()], lambda x: torch.log2(x), Tensor.log2)
def test_exp(self):
helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp)
helper_test_op([()], lambda x: torch.exp(x), Tensor.exp)

View File

@ -7,7 +7,7 @@ import operator
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
from math import ceil, pi, prod, sqrt
from math import ceil, pi, prod, sqrt, log
from tinygrad.lazy import Device, LazyBuffer
from tinygrad.ops import LoadOps
@ -481,6 +481,7 @@ class Tensor:
def contiguous(self): return mlops.Contiguous.apply(self)
def log(self): return mlops.Log.apply(self)
def log2(self): return mlops.Log.apply(self)/log(2)
def exp(self): return mlops.Exp.apply(self)
def relu(self): return mlops.Relu.apply(self)
def sin(self): return mlops.Sin.apply(self)