mirror of https://github.com/commaai/tinygrad.git
[MLPerf][UNet3D] Add DICE loss + metrics (#4204)
* add DICE loss and metrics
* update dice to include reference implementation's link
* remove unused imports
* remove unnecessary test file and update pred + label for metrics and losses test
* add tests to CI + add exclusion of mlperf_unet3d

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
parent
cd801a15f3
commit
3644077a42
|
@ -228,6 +228,12 @@ jobs:
|
|||
- if: ${{ matrix.task == 'onnx' }}
|
||||
name: Test MLPerf optimizers
|
||||
run: GPU=1 python -m pytest -n=auto test/external/external_test_optim.py --durations=20
|
||||
- if: ${{ matrix.task == 'onnx' }}
|
||||
name: Test MLPerf losses
|
||||
run: GPU=1 python -m pytest -n=auto test/external/external_test_losses.py --durations=20
|
||||
- if: ${{ matrix.task == 'onnx' }}
|
||||
name: Test MLPerf metrics
|
||||
run: GPU=1 python -m pytest -n=auto test/external/external_test_metrics.py --durations=20
|
||||
- if: ${{ matrix.task == 'onnx' }}
|
||||
name: Test THREEFRY
|
||||
run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
from examples.mlperf.metrics import dice_score
|
||||
|
||||
def dice_ce_loss(pred, tgt):
  """Return the average of the soft DICE loss and the cross-entropy loss.

  `pred` is permuted with (0, 2, 3, 4, 1), so it must be 5-D with the class
  axis at position 1. `tgt` has axis 1 squeezed before it is used as the
  sparse class labels for the cross entropy.
  """
  # move the class axis last before computing the sparse cross entropy
  channels_last = pred.permute(0, 2, 3, 4, 1)
  labels = tgt.squeeze(1)
  cross_entropy = channels_last.sparse_categorical_crossentropy(labels)

  # soft DICE: no argmax and no one-hot encoding of the prediction
  soft_dice = dice_score(pred, tgt, argmax=False, to_one_hot_x=False)
  dice_loss = (1.0 - soft_dice).mean()

  return (dice_loss + cross_entropy) / 2
|
|
@ -1,7 +1,6 @@
|
|||
import re
|
||||
import string
|
||||
from collections import Counter
|
||||
import numpy as np
|
||||
|
||||
def levenshtein(a, b):
|
||||
n, m = len(a), len(b)
|
||||
|
@ -29,21 +28,22 @@ def word_error_rate(x, y):
|
|||
scores += levenshtein(h_list, r_list)
|
||||
return float(scores) / words, float(scores), words
|
||||
|
||||
def one_hot(arr, num_classes=3):
  """One-hot encode a 4-D array of integer class labels.

  Args:
    arr: integer labels, an ndarray or any array-like (e.g. nested lists).
      The final transpose uses five axes, so `arr` must be 4-D.
    num_classes: number of classes, i.e. the size of the new class axis.

  Returns:
    A float32 ndarray with the class axis inserted at position 1: an input
    of shape (a, b, c, d) yields (a, num_classes, b, c, d).
  """
  # coerce once so array-likes without a `.shape` attribute work too
  arr = np.asarray(arr)
  # each row of the identity matrix is the one-hot vector of its row index
  encoded = np.eye(num_classes, dtype=np.float32)[arr.reshape(-1)]
  encoded = encoded.reshape(arr.shape + (num_classes,))
  # move the trailing class axis to position 1
  return encoded.transpose((0, 4, 1, 2, 3))
|
||||
def one_hot(x):
  """One-hot encode `x` into 3 classes and move the class axis to position 1.

  After `one_hot(3)` and `squeeze(1)` the tensor must be 5-D, because the
  final permute names five axes; axis 4 (the class axis) ends up at axis 1.
  """
  encoded = x.one_hot(3)
  encoded = encoded.squeeze(1)
  return encoded.permute(0, 4, 1, 2, 3)
|
||||
|
||||
def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6):
|
||||
def dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6, argmax=True, to_one_hot_x=True):
|
||||
channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape)))
|
||||
prediction = prediction.argmax(axis=channel_axis)
|
||||
prediction, target= one_hot(prediction)[:, 1:], one_hot(target)[:, 1:]
|
||||
intersection = np.sum(prediction * target, axis=reduce_axis)
|
||||
target_sum = np.sum(target, axis=reduce_axis)
|
||||
prediction_sum = np.sum(prediction, axis=reduce_axis)
|
||||
if argmax: prediction = prediction.argmax(axis=channel_axis)
|
||||
else: prediction = prediction.softmax(axis=channel_axis)
|
||||
if to_one_hot_x: prediction = one_hot(prediction)
|
||||
target = one_hot(target)
|
||||
prediction, target = prediction[:, 1:], target[:, 1:]
|
||||
assert prediction.shape == target.shape, f"prediction ({prediction.shape}) and target ({target.shape}) shapes do not match"
|
||||
intersection = (prediction * target).sum(axis=reduce_axis)
|
||||
target_sum = target.sum(axis=reduce_axis)
|
||||
prediction_sum = prediction.sum(axis=reduce_axis)
|
||||
result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
|
||||
return result[0]
|
||||
return result
|
||||
|
||||
def normalize_string(s):
|
||||
s = "".join(c for c in s.lower() if c not in string.punctuation)
|
||||
|
|
|
@ -64,7 +64,7 @@ def eval_unet3d():
|
|||
# UNet3D
|
||||
from extra.models.unet3d import UNet3D
|
||||
from extra.datasets.kits19 import iterate, sliding_window_inference
|
||||
from examples.mlperf.metrics import get_dice_score
|
||||
from examples.mlperf.metrics import dice_score
|
||||
mdl = UNet3D()
|
||||
mdl.load_from_pretrained()
|
||||
s = 0
|
||||
|
@ -74,7 +74,7 @@ def eval_unet3d():
|
|||
pred, label = sliding_window_inference(mdl, image, label)
|
||||
et = time.perf_counter()
|
||||
print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
|
||||
s += get_dice_score(pred, label).mean()
|
||||
s += dice_score(Tensor(pred), Tensor(label)).mean().item()
|
||||
print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
|
||||
st = time.perf_counter()
|
||||
|
||||
|
|
|
@ -34,4 +34,5 @@ exclude = [
|
|||
"openpilot/",
|
||||
"tinygrad/runtime/autogen",
|
||||
"test/external/mlperf_resnet",
|
||||
"test/external/mlperf_unet3d",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
from tinygrad import Tensor
|
||||
from test.external.mlperf_unet3d.dice import DiceCELoss
|
||||
from examples.mlperf.losses import dice_ce_loss
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
class ExternalTestLosses(unittest.TestCase):
  """Compares tinygrad loss functions against their PyTorch counterparts."""

  def _test_losses(self, tinygrad_metrics, orig_metrics, pred, label):
    """Evaluate both callables on the same numpy inputs and compare with atol=1e-4."""
    actual = tinygrad_metrics(Tensor(pred), Tensor(label))
    actual = actual.numpy()
    expected = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label))
    expected = expected.numpy()
    np.testing.assert_allclose(actual, expected, atol=1e-4)

  def test_dice_ce(self):
    """dice_ce_loss must agree with DiceCELoss on random scores and all-ones labels."""
    volume = (128, 128, 128)
    pred = np.random.rand(1, 3, *volume).astype(np.float32)
    label = np.ones((1, 1, *volume)).astype(np.uint8)
    # positional args: to_onehot_y, use_softmax, layout, include_background
    reference = DiceCELoss(True, True, "NCDHW", False)
    self._test_losses(dice_ce_loss, reference, pred, label)
|
||||
|
||||
# run the loss tests when this file is executed as a script
if __name__ == "__main__":
  unittest.main()
|
|
@ -0,0 +1,20 @@
|
|||
from tinygrad import Tensor
|
||||
from test.external.mlperf_unet3d.dice import DiceScore
|
||||
from examples.mlperf.metrics import dice_score
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
class ExternalTestMetrics(unittest.TestCase):
  """Compares tinygrad metric functions against their PyTorch counterparts."""

  def _test_metrics(self, tinygrad_metrics, orig_metrics, pred, label):
    """Evaluate both callables on the same numpy inputs and require exactly equal results."""
    actual = tinygrad_metrics(Tensor(pred), Tensor(label))
    # drop size-1 axes before comparing with the reference output
    actual = actual.squeeze().numpy()
    expected = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label))
    expected = expected.numpy()
    np.testing.assert_equal(actual, expected)

  def test_dice(self):
    """dice_score must agree with DiceScore on random scores and all-ones labels."""
    volume = (128, 128, 128)
    pred = np.random.rand(1, 3, *volume).astype(np.float32)
    label = np.ones((1, 1, *volume)).astype(np.uint8)
    reference = DiceScore()
    self._test_metrics(dice_score, reference, pred, label)
|
||||
|
||||
# run the metric tests when this file is executed as a script
if __name__ == "__main__":
  unittest.main()
|
|
@ -0,0 +1,94 @@
|
|||
# https://github.com/mlcommons/training/blob/master/image_segmentation/pytorch/model/losses.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Dice:
    """Callable computing a smoothed DICE coefficient per sample and per class.

    Options:
        to_onehot_y: one-hot encode the target via `to_one_hot`.
        to_onehot_x: one-hot encode the prediction via `to_one_hot`.
        use_softmax: apply softmax over the channel axis of the prediction.
        use_argmax: apply argmax over the channel axis; only consulted when
            `use_softmax` is False.
        include_background: when False, channel 0 is dropped from both tensors.
        layout: "NCDHW" puts channels on axis 1; any other value is treated
            as channels-last.
    """

    def __init__(self,
                 to_onehot_y: bool = True,
                 to_onehot_x: bool = False,
                 use_softmax: bool = True,
                 use_argmax: bool = False,
                 include_background: bool = False,
                 layout: str = "NCDHW"):
        self.include_background = include_background
        self.to_onehot_y, self.to_onehot_x = to_onehot_y, to_onehot_x
        self.use_softmax, self.use_argmax = use_softmax, use_argmax
        # smoothing terms added to the numerator and the denominator
        self.smooth_nr = 1e-6
        self.smooth_dr = 1e-6
        self.layout = layout

    def __call__(self, prediction, target):
        """Return (2 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr).

        The sums run over every axis except the batch and channel axes.
        """
        channels_first = self.layout == "NCDHW"
        rank = len(prediction.shape)
        if channels_first:
            channel_axis, reduce_axis = 1, list(range(2, rank))
        else:
            channel_axis, reduce_axis = -1, list(range(1, rank - 1))
        # recorded before softmax/argmax, since argmax removes the channel axis
        num_pred_ch = prediction.shape[channel_axis]

        if self.use_softmax:
            prediction = prediction.softmax(dim=channel_axis)
        elif self.use_argmax:
            prediction = prediction.argmax(dim=channel_axis)

        if self.to_onehot_y:
            target = to_one_hot(target, self.layout, channel_axis)
        if self.to_onehot_x:
            prediction = to_one_hot(prediction, self.layout, channel_axis)

        if not self.include_background:
            assert num_pred_ch > 1, \
                f"To exclude background the prediction needs more than one channel. Got {num_pred_ch}."
            # index that keeps channels 1.. on whichever axis holds the channels
            if channels_first:
                foreground = (slice(None), slice(1, None))
            else:
                foreground = (Ellipsis, slice(1, None))
            target, prediction = target[foreground], prediction[foreground]

        assert (target.shape == prediction.shape), \
            f"Target and prediction shape do not match. Target: ({target.shape}), prediction: ({prediction.shape})."

        overlap = (target * prediction).sum(dim=reduce_axis)
        denominator = target.sum(dim=reduce_axis) + prediction.sum(dim=reduce_axis) + self.smooth_dr
        return (2.0 * overlap + self.smooth_nr) / denominator
|
||||
|
||||
|
||||
def to_one_hot(array, layout, channel_axis):
    """One-hot encode `array` into 3 classes.

    An input with 5 or more dimensions first has `channel_axis` squeezed.
    For the "NCDHW" layout the class axis is moved to position 1 and the
    result is cast to float; for any other layout the output of `F.one_hot`
    is returned as is.
    """
    if array.dim() >= 5:
        array = array.squeeze(dim=channel_axis)
    encoded = F.one_hot(array.long(), num_classes=3)
    if layout != "NCDHW":
        return encoded
    return encoded.permute(0, 4, 1, 2, 3).float()
|
||||
|
||||
|
||||
class DiceCELoss(nn.Module):
    """Loss module averaging a DICE loss term and a cross-entropy term."""

    def __init__(self, to_onehot_y, use_softmax, layout, include_background):
        super().__init__()
        dice_options = dict(to_onehot_y=to_onehot_y, use_softmax=use_softmax,
                            layout=layout, include_background=include_background)
        self.dice = Dice(**dice_options)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, y_pred, y_true):
        """Return (mean(1 - dice) + cross_entropy) / 2 for the given scores and labels."""
        # cross entropy takes integer class labels without the size-1 axis 1
        labels = y_true.squeeze(dim=1).long()
        ce_term = self.cross_entropy(y_pred, labels)
        dice_term = (1.0 - self.dice(y_pred, y_true)).mean()
        return (dice_term + ce_term) / 2
|
||||
|
||||
|
||||
class DiceScore:
    """DICE metric: `Dice` with a one-hot encoded prediction and no softmax, averaged over axis 0."""

    def __init__(self, to_onehot_y: bool = True, use_argmax: bool = True, layout: str = "NCDHW",
                 include_background: bool = False):
        self.dice = Dice(
            to_onehot_y=to_onehot_y,
            to_onehot_x=True,
            use_softmax=False,
            use_argmax=use_argmax,
            layout=layout,
            include_background=include_background,
        )

    def __call__(self, y_pred, y_true):
        """Return the DICE coefficients of `y_pred` against `y_true`, averaged over axis 0."""
        per_sample = self.dice(y_pred, y_true)
        return per_sample.mean(dim=0)
|
Loading…
Reference in New Issue