[MLPerf][UNet3D] Add DICE loss + metrics (#4204)

* add DICE loss and metrics

* update dice to include reference implementation's link

* remove unused imports

* remove unnecessary test file and update pred + label for metrics and losses test

* add tests to CI + add exclusion of mlperf_unet3d

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Francis Lata 2024-04-17 20:09:33 -04:00 committed by GitHub
parent cd801a15f3
commit 3644077a42
8 changed files with 162 additions and 15 deletions


@@ -228,6 +228,12 @@ jobs:
     - if: ${{ matrix.task == 'onnx' }}
       name: Test MLPerf optimizers
       run: GPU=1 python -m pytest -n=auto test/external/external_test_optim.py --durations=20
+    - if: ${{ matrix.task == 'onnx' }}
+      name: Test MLPerf losses
+      run: GPU=1 python -m pytest -n=auto test/external/external_test_losses.py --durations=20
+    - if: ${{ matrix.task == 'onnx' }}
+      name: Test MLPerf metrics
+      run: GPU=1 python -m pytest -n=auto test/external/external_test_metrics.py --durations=20
     - if: ${{ matrix.task == 'onnx' }}
       name: Test THREEFRY
       run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py

examples/mlperf/losses.py Normal file

@@ -0,0 +1,6 @@
from examples.mlperf.metrics import dice_score

def dice_ce_loss(pred, tgt):
  ce = pred.permute(0, 2, 3, 4, 1).sparse_categorical_crossentropy(tgt.squeeze(1))
  dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
  return (dice + ce) / 2
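The loss takes raw NCDHW logits and integer labels with a singleton channel axis. A minimal usage sketch, not part of this commit, with shapes borrowed from the new external_test_losses.py below:

import numpy as np
from tinygrad import Tensor
from examples.mlperf.losses import dice_ce_loss

# logits over 3 classes and integer class labels, both NCDHW as in the tests
pred = Tensor(np.random.rand(1, 3, 128, 128, 128).astype(np.float32))
label = Tensor(np.ones((1, 1, 128, 128, 128)).astype(np.uint8))
loss = dice_ce_loss(pred, label)  # scalar: average of the soft-DICE loss and the cross-entropy
print(loss.item())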

examples/mlperf/metrics.py

@@ -1,7 +1,6 @@
 import re
 import string
 from collections import Counter
-import numpy as np
 
 def levenshtein(a, b):
   n, m = len(a), len(b)
@@ -29,21 +28,22 @@ def word_error_rate(x, y):
     scores += levenshtein(h_list, r_list)
   return float(scores) / words, float(scores), words
 
-def one_hot(arr, num_classes=3):
-  res = np.eye(num_classes)[np.array(arr).reshape(-1)]
-  arr = res.reshape(list(arr.shape) + [num_classes])
-  arr = arr.transpose((0, 4, 1, 2, 3)).astype(np.float32)
-  return arr
+def one_hot(x):
+  return x.one_hot(3).squeeze(1).permute(0, 4, 1, 2, 3)
 
-def get_dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6):
+def dice_score(prediction, target, channel_axis=1, smooth_nr=1e-6, smooth_dr=1e-6, argmax=True, to_one_hot_x=True):
   channel_axis, reduce_axis = 1, tuple(range(2, len(prediction.shape)))
-  prediction = prediction.argmax(axis=channel_axis)
-  prediction, target= one_hot(prediction)[:, 1:], one_hot(target)[:, 1:]
-  intersection = np.sum(prediction * target, axis=reduce_axis)
-  target_sum = np.sum(target, axis=reduce_axis)
-  prediction_sum = np.sum(prediction, axis=reduce_axis)
+  if argmax: prediction = prediction.argmax(axis=channel_axis)
+  else: prediction = prediction.softmax(axis=channel_axis)
+  if to_one_hot_x: prediction = one_hot(prediction)
+  target = one_hot(target)
+  prediction, target = prediction[:, 1:], target[:, 1:]
+  assert prediction.shape == target.shape, f"prediction ({prediction.shape}) and target ({target.shape}) shapes do not match"
+  intersection = (prediction * target).sum(axis=reduce_axis)
+  target_sum = target.sum(axis=reduce_axis)
+  prediction_sum = prediction.sum(axis=reduce_axis)
   result = (2.0 * intersection + smooth_nr) / (target_sum + prediction_sum + smooth_dr)
-  return result[0]
+  return result
 
 def normalize_string(s):
   s = "".join(c for c in s.lower() if c not in string.punctuation)


@@ -64,7 +64,7 @@ def eval_unet3d():
   # UNet3D
   from extra.models.unet3d import UNet3D
   from extra.datasets.kits19 import iterate, sliding_window_inference
-  from examples.mlperf.metrics import get_dice_score
+  from examples.mlperf.metrics import dice_score
   mdl = UNet3D()
   mdl.load_from_pretrained()
   s = 0
@@ -74,7 +74,7 @@ def eval_unet3d():
     pred, label = sliding_window_inference(mdl, image, label)
     et = time.perf_counter()
     print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
-    s += get_dice_score(pred, label).mean()
+    s += dice_score(Tensor(pred), Tensor(label)).mean().item()
     print(f"****** {s:.2f}/{i} {s/i:.5f} Mean DICE score")
     st = time.perf_counter()


@@ -34,4 +34,5 @@ exclude = [
   "openpilot/",
   "tinygrad/runtime/autogen",
   "test/external/mlperf_resnet",
+  "test/external/mlperf_unet3d",
 ]

20 test/external/external_test_losses.py vendored Normal file

@@ -0,0 +1,20 @@
from tinygrad import Tensor
from test.external.mlperf_unet3d.dice import DiceCELoss
from examples.mlperf.losses import dice_ce_loss

import numpy as np
import torch
import unittest

class ExternalTestLosses(unittest.TestCase):
  def _test_losses(self, tinygrad_metrics, orig_metrics, pred, label):
    tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(label)).numpy()
    orig_metrics_res = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
    np.testing.assert_allclose(tinygrad_metrics_res, orig_metrics_res, atol=1e-4)

  def test_dice_ce(self):
    pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
    self._test_losses(dice_ce_loss, DiceCELoss(True, True, "NCDHW", False), pred, label)

if __name__ == '__main__':
  unittest.main()

20 test/external/external_test_metrics.py vendored Normal file

@@ -0,0 +1,20 @@
from tinygrad import Tensor
from test.external.mlperf_unet3d.dice import DiceScore
from examples.mlperf.metrics import dice_score

import numpy as np
import torch
import unittest

class ExternalTestMetrics(unittest.TestCase):
  def _test_metrics(self, tinygrad_metrics, orig_metrics, pred, label):
    tinygrad_metrics_res = tinygrad_metrics(Tensor(pred), Tensor(label)).squeeze().numpy()
    orig_metrics_res = orig_metrics(torch.from_numpy(pred), torch.from_numpy(label)).numpy()
    np.testing.assert_equal(tinygrad_metrics_res, orig_metrics_res)

  def test_dice(self):
    pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
    self._test_metrics(dice_score, DiceScore(), pred, label)

if __name__ == '__main__':
  unittest.main()

94 test/external/mlperf_unet3d/dice.py vendored Normal file

@@ -0,0 +1,94 @@
# https://github.com/mlcommons/training/blob/master/image_segmentation/pytorch/model/losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Dice:
    def __init__(self,
                 to_onehot_y: bool = True,
                 to_onehot_x: bool = False,
                 use_softmax: bool = True,
                 use_argmax: bool = False,
                 include_background: bool = False,
                 layout: str = "NCDHW"):
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.to_onehot_x = to_onehot_x
        self.use_softmax = use_softmax
        self.use_argmax = use_argmax
        self.smooth_nr = 1e-6
        self.smooth_dr = 1e-6
        self.layout = layout

    def __call__(self, prediction, target):
        if self.layout == "NCDHW":
            channel_axis = 1
            reduce_axis = list(range(2, len(prediction.shape)))
        else:
            channel_axis = -1
            reduce_axis = list(range(1, len(prediction.shape) - 1))

        num_pred_ch = prediction.shape[channel_axis]

        if self.use_softmax:
            prediction = torch.softmax(prediction, dim=channel_axis)
        elif self.use_argmax:
            prediction = torch.argmax(prediction, dim=channel_axis)

        if self.to_onehot_y:
            target = to_one_hot(target, self.layout, channel_axis)

        if self.to_onehot_x:
            prediction = to_one_hot(prediction, self.layout, channel_axis)

        if not self.include_background:
            assert num_pred_ch > 1, \
                f"To exclude background the prediction needs more than one channel. Got {num_pred_ch}."
            if self.layout == "NCDHW":
                target = target[:, 1:]
                prediction = prediction[:, 1:]
            else:
                target = target[..., 1:]
                prediction = prediction[..., 1:]

        assert (target.shape == prediction.shape), \
            f"Target and prediction shape do not match. Target: ({target.shape}), prediction: ({prediction.shape})."

        intersection = torch.sum(target * prediction, dim=reduce_axis)
        target_sum = torch.sum(target, dim=reduce_axis)
        prediction_sum = torch.sum(prediction, dim=reduce_axis)

        return (2.0 * intersection + self.smooth_nr) / (target_sum + prediction_sum + self.smooth_dr)


def to_one_hot(array, layout, channel_axis):
    if len(array.shape) >= 5:
        array = torch.squeeze(array, dim=channel_axis)
    array = F.one_hot(array.long(), num_classes=3)
    if layout == "NCDHW":
        array = array.permute(0, 4, 1, 2, 3).float()
    return array


class DiceCELoss(nn.Module):
    def __init__(self, to_onehot_y, use_softmax, layout, include_background):
        super(DiceCELoss, self).__init__()
        self.dice = Dice(to_onehot_y=to_onehot_y, use_softmax=use_softmax, layout=layout,
                         include_background=include_background)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, y_pred, y_true):
        cross_entropy = self.cross_entropy(y_pred, torch.squeeze(y_true, dim=1).long())
        dice = torch.mean(1.0 - self.dice(y_pred, y_true))
        return (dice + cross_entropy) / 2


class DiceScore:
    def __init__(self, to_onehot_y: bool = True, use_argmax: bool = True, layout: str = "NCDHW",
                 include_background: bool = False):
        self.dice = Dice(to_onehot_y=to_onehot_y, to_onehot_x=True, use_softmax=False,
                         use_argmax=use_argmax, layout=layout, include_background=include_background)

    def __call__(self, y_pred, y_true):
        return torch.mean(self.dice(y_pred, y_true), dim=0)
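For reference, both the vendored PyTorch code and the tinygrad dice_score compute the same smoothed DICE coefficient per sample and foreground channel c, with the sums taken over the spatial axes and smooth_nr = smooth_dr = 1e-6:

  dice_c = (2 * sum(p_c * t_c) + smooth_nr) / (sum(p_c) + sum(t_c) + smooth_dr)

DiceCELoss then averages 1 - dice_c over batch and channels and returns the mean of that DICE loss and the cross-entropy term, which dice_ce_loss reproduces on the tinygrad side.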