diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index e4926d9682..087a24f9d7 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -12,6 +12,7 @@ from __future__ import annotations from .adversarial_loss import PatchAdversarialLoss +from .aucm_loss import AUCMLoss from .barlow_twins import BarlowTwinsLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss diff --git a/monai/losses/aucm_loss.py b/monai/losses/aucm_loss.py new file mode 100644 index 0000000000..c6fa24e0cb --- /dev/null +++ b/monai/losses/aucm_loss.py @@ -0,0 +1,207 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.nn.modules.loss import _Loss + +from monai.utils import LossReduction + + +class AUCMLoss(_Loss): + """ + AUC-Margin loss with squared-hinge surrogate loss for optimizing AUROC. + + The loss optimizes the Area Under the ROC Curve (AUROC) by using margin-based constraints + on positive and negative predictions. It supports two versions: 'v1' includes class prior + information, while 'v2' removes this dependency for better generalization. + + Reference: + Yuan, Zhuoning, Yan, Yan, Sonka, Milan, and Yang, Tianbao. + "Large-scale robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification." + Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021. + https://arxiv.org/abs/2012.03173 + + Implementation based on: https://github.com/Optimization-AI/LibAUC/blob/1.4.0/libauc/losses/auc.py + + Example: + >>> import torch + >>> from monai.losses import AUCMLoss + >>> loss_fn = AUCMLoss() + >>> input = torch.randn(32, 1, requires_grad=True) + >>> target = torch.randint(0, 2, (32, 1)).float() + >>> loss = loss_fn(input, target) + """ + + def __init__( + self, + margin: float = 1.0, + imratio: float | None = None, + version: str = "v1", + reduction: LossReduction | str = LossReduction.MEAN, + ) -> None: + """ + Args: + margin: margin for squared-hinge surrogate loss (default: ``1.0``). + imratio: the ratio of the number of positive samples to the number of total samples in the training dataset. + If this value is not given, it will be automatically calculated with mini-batch samples. + This value is ignored when ``version`` is set to ``'v2'``. + version: whether to include prior class information in the objective function (default: ``'v1'``). + 'v1' includes class prior, 'v2' removes this dependency. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + Note: This loss is computed at the batch level and always returns a scalar. + The reduction parameter is accepted for API consistency but has no effect. + + Raises: + ValueError: When ``version`` is not one of ["v1", "v2"]. + ValueError: When ``imratio`` is not in [0, 1]. + + Example: + >>> import torch + >>> from monai.losses import AUCMLoss + >>> loss_fn = AUCMLoss(version='v2') + >>> input = torch.randn(32, 1, requires_grad=True) + >>> target = torch.randint(0, 2, (32, 1)).float() + >>> loss = loss_fn(input, target) + """ + super().__init__(reduction=LossReduction(reduction).value) + if version not in ["v1", "v2"]: + raise ValueError(f"version should be 'v1' or 'v2', got {version}") + if imratio is not None and not (0.0 <= imratio <= 1.0): + raise ValueError(f"imratio must be in [0, 1], got {imratio}") + self.margin = margin + self.imratio = imratio + self.version = version + self.a = nn.Parameter(torch.tensor(0.0)) + self.b = nn.Parameter(torch.tensor(0.0)) + self.alpha = nn.Parameter(torch.tensor(0.0)) + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Args: + input: the shape should be B1HW[D], where the channel dimension is 1 for binary classification. + target: the shape should be B1HW[D], with values 0 or 1. + + Returns: + torch.Tensor: scalar AUCM loss. + + Raises: + ValueError: When input or target have incorrect shapes. + ValueError: When input or target have fewer than 2 dimensions. + ValueError: When target contains non-binary values. + """ + if input.ndim < 2 or target.ndim < 2: + raise ValueError("Input and target must have at least 2 dimensions (B, C, ...)") + if input.shape[1] != 1: + raise ValueError(f"Input should have 1 channel for binary classification, got {input.shape[1]}") + if target.shape[1] != 1: + raise ValueError(f"Target should have 1 channel, got {target.shape[1]}") + if input.shape != target.shape: + raise ValueError(f"Input and target shapes do not match: {input.shape} vs {target.shape}") + + input = input.flatten() + target = target.flatten() + + if input.numel() == 0: + raise ValueError("Input and target must contain at least one element.") + + if not torch.all((target == 0) | (target == 1)): + raise ValueError("Target must contain only binary values (0 or 1)") + + pos_mask = (target == 1).float() + neg_mask = (target == 0).float() + + mean_pos_sq = (input - self.a) ** 2 + mean_neg_sq = (input - self.b) ** 2 + + # Note: + # v1 uses global expectations (normalized by total number of samples), + # following the original LibAUC implementation. + # v2 uses class-conditional expectations (normalized by number of samples + # in each class), implemented via non-zero averaging. + # These behaviors differ and should not be unified. + if self.version == "v1": + p = float(self.imratio) if self.imratio is not None else float(pos_mask.mean().item()) + p1 = 1.0 - p + + mean_pos = self._global_mean(mean_pos_sq, pos_mask) + mean_neg = self._global_mean(mean_neg_sq, neg_mask) + + interaction = self._global_mean(p * input * neg_mask - p1 * input * pos_mask, pos_mask + neg_mask) + + loss = ( + p1 * mean_pos + + p * mean_neg + + 2 * self.alpha * (p * p1 * self.margin + interaction) + - p * p1 * self.alpha**2 + ) + + else: # v2 + mean_pos = self._class_mean(mean_pos_sq, pos_mask) + mean_neg = self._class_mean(mean_neg_sq, neg_mask) + + mean_input_pos = self._class_mean(input, pos_mask) + mean_input_neg = self._class_mean(input, neg_mask) + + loss = ( + mean_pos + mean_neg + 2 * self.alpha * (self.margin + mean_input_neg - mean_input_pos) - self.alpha**2 + ) + + return loss + + def _global_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Compute the global mean of a masked tensor. + + This computes the mean over all elements, where values outside the mask + are zeroed out. The result is normalized by the total number of elements, + not by the number of masked elements. + + This corresponds to a global expectation: + E[mask * tensor] + + Args: + tensor: Input tensor. + mask: Binary mask tensor of the same shape as ``tensor``. + + Returns: + Scalar tensor representing the global mean. + """ + masked = tensor * mask + if masked.numel() == 0: + return torch.zeros((), dtype=tensor.dtype, device=tensor.device) + return masked.mean() + + def _class_mean(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Compute the class-conditional mean of a masked tensor. + + This computes the mean over only the masked (non-zero) elements, i.e., + the result is normalized by the number of masked elements. + + This corresponds to a class-conditional expectation: + E[tensor | mask] + + Args: + tensor: Input tensor. + mask: Binary mask tensor of the same shape as ``tensor``. + + Returns: + Scalar tensor representing the class-conditional mean. + Returns 0 if no elements are selected by the mask. + """ + denom = mask.sum() + if denom.item() == 0: + return torch.zeros((), dtype=tensor.dtype, device=tensor.device) + return (tensor * mask).sum() / denom diff --git a/tests/losses/test_aucm_loss.py b/tests/losses/test_aucm_loss.py new file mode 100644 index 0000000000..a8b6639487 --- /dev/null +++ b/tests/losses/test_aucm_loss.py @@ -0,0 +1,211 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import AUCMLoss +from tests.test_utils import test_script_save + +TEST_CASES = [ + [{"version": "v1"}, {"input": torch.tensor([[1.0], [2.0]]), "target": torch.tensor([[1.0], [0.0]])}, 2.375000], + [{"version": "v2"}, {"input": torch.tensor([[1.0], [2.0]]), "target": torch.tensor([[1.0], [0.0]])}, 9.500000], + # ------------------------------------------------------------------ + # Explicit imratio coverage for v1 + # ------------------------------------------------------------------ + [ + {"version": "v1", "imratio": 0.25}, + {"input": torch.tensor([[0.0], [1.0], [2.0], [3.0]]), "target": torch.tensor([[1.0], [1.0], [0.0], [0.0]])}, + 1.687500, + ], + [ + {"version": "v1", "imratio": 0.5}, + {"input": torch.tensor([[0.0], [1.0], [2.0], [3.0]]), "target": torch.tensor([[1.0], [1.0], [0.0], [0.0]])}, + 3.625000, + ], + [ + {"version": "v1", "imratio": 0.75}, + {"input": torch.tensor([[0.0], [1.0], [2.0], [3.0]]), "target": torch.tensor([[1.0], [1.0], [0.0], [0.0]])}, + 5.437500, + ], + # ------------------------------------------------------------------ + # imratio ignored in v2 + # ------------------------------------------------------------------ + [ + {"version": "v2", "imratio": 0.25}, + {"input": torch.tensor([[0.0], [1.0], [2.0], [3.0]]), "target": torch.tensor([[1.0], [1.0], [0.0], [0.0]])}, + 14.500000, + ], + [ + {"version": "v2", "imratio": 0.75}, + {"input": torch.tensor([[0.0], [1.0], [2.0], [3.0]]), "target": torch.tensor([[1.0], [1.0], [0.0], [0.0]])}, + 14.500000, + ], + # ------------------------------------------------------------------ + # Margin coverage for v1 + # ------------------------------------------------------------------ + [ + {"version": "v1", "margin": 0.5}, + {"input": torch.tensor([[2.0], [0.5], [-1.0], [-0.5]]), "target": torch.tensor([[1.0], [1.0], [0.0], [0.0]])}, + -0.687500, + ], + [ + {"version": "v1", "margin": 2.0}, + {"input": torch.tensor([[2.0], [0.5], [-1.0], [-0.5]]), "target": torch.tensor([[1.0], [1.0], [0.0], [0.0]])}, + 0.062500, + ], + # ------------------------------------------------------------------ + # Combined imratio + margin coverage + # ------------------------------------------------------------------ + [ + {"version": "v1", "imratio": 0.25, "margin": 0.5}, + {"input": torch.tensor([[2.0], [0.5], [-1.0], [-0.5]]), "target": torch.tensor([[1.0], [1.0], [0.0], [0.0]])}, + -0.687500, + ], + [ + {"version": "v2", "imratio": 0.25, "margin": 0.5}, + {"input": torch.tensor([[2.0], [0.5], [-1.0], [-0.5]]), "target": torch.tensor([[1.0], [1.0], [0.0], [0.0]])}, + -2.750000, + ], + # ------------------------------------------------------------------ + # Margin coverage for v2 + # ------------------------------------------------------------------ + [ + {"version": "v2", "margin": 0.5}, + {"input": torch.tensor([[2.0], [0.5], [-1.0], [-0.5]]), "target": torch.tensor([[1.0], [1.0], [0.0], [0.0]])}, + -2.750000, + ], + [ + {"version": "v2", "margin": 2.0}, + {"input": torch.tensor([[2.0], [0.5], [-1.0], [-0.5]]), "target": torch.tensor([[1.0], [1.0], [0.0], [0.0]])}, + 0.250000, + ], + # ------------------------------------------------------------------ + # Blank / degenerate inputs + # ------------------------------------------------------------------ + [{"version": "v1"}, {"input": torch.zeros((4, 1)), "target": torch.tensor([[1.0], [0.0], [1.0], [0.0]])}, 0.375000], + # ------------------------------------------------------------------ + # Higher-dimensional tensors + # ------------------------------------------------------------------ + [ + {"version": "v1"}, + {"input": torch.tensor([[[[2.0, -1.0], [0.5, -0.5]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]])}, + -0.437500, + ], + [ + {"version": "v2"}, + {"input": torch.tensor([[[[2.0, -1.0], [0.5, -0.5]]]]), "target": torch.tensor([[[[1.0, 0.0], [1.0, 0.0]]]])}, + -1.750000, + ], +] + +BAD_ARGS = [[{"version": "invalid"}], [{"imratio": -0.1}], [{"imratio": 1.5}], [{"reduction": "invalid"}]] + + +SHAPE_ERROR_CASES = [ + [torch.randn(32), torch.randint(0, 2, (32, 1)).float()], + [torch.randn(32, 2), torch.randint(0, 2, (32, 1)).float()], + [torch.randn(32, 1), torch.randint(0, 2, (32, 2)).float()], + [torch.randn(32, 1), torch.randint(0, 2, (16, 1)).float()], +] + + +class TestAUCMLoss(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_forward_values(self, input_param, input_data, expected_val): + loss_fn = AUCMLoss(**input_param) + + # ------------------------------------------------------------ + # Set deterministic non-zero internal optimization variables + # to make margin-dependent behavior testable + # ------------------------------------------------------------ + loss_fn.a.data.fill_(0.5) + loss_fn.b.data.fill_(-0.5) + loss_fn.alpha.data.fill_(1.0) + + result = loss_fn.forward(**input_data) + + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5, atol=1e-5) + + @parameterized.expand(BAD_ARGS) + def test_bad_args(self, kwargs): + with self.assertRaises((ValueError, TypeError)): + AUCMLoss(**kwargs) + + @parameterized.expand(SHAPE_ERROR_CASES) + def test_invalid_shapes(self, pred, target): + with self.assertRaises(ValueError): + AUCMLoss()(pred, target) + + @parameterized.expand([("v1",), ("v2",)]) + def test_all_negative_batch(self, version): + pred = torch.zeros((8, 1)) + target = torch.zeros((8, 1)) + + loss = AUCMLoss(version=version)(pred, target) + + self.assertTrue(torch.isfinite(loss)) + + def test_non_binary_target(self): + pred = torch.randn(32, 1) + + target = torch.tensor([[0.5], [1.0], [2.0], [0.0]] * 8) + + with self.assertRaises(ValueError): + AUCMLoss()(pred, target) + + @parameterized.expand([("v1",), ("v2",)]) + def test_backward(self, version): + pred = torch.randn(32, 1, requires_grad=True) + target = torch.randint(0, 2, (32, 1)).float() + + loss = AUCMLoss(version=version)(pred, target) + + loss.backward() + + self.assertIsNotNone(pred.grad) + self.assertTrue(torch.isfinite(pred.grad).all()) + + @parameterized.expand([("v1",), ("v2",)]) + def test_blank_predictions_mixed_targets(self, version): + pred = torch.zeros((4, 1)) + target = torch.tensor([[1.0], [0.0], [1.0], [0.0]]) + + loss = AUCMLoss(version=version)(pred, target) + if version == "v1": + self.assertTrue(torch.isfinite(loss)) + else: + self.assertTrue(torch.isfinite(loss) or torch.isnan(loss)) + + @parameterized.expand([("mean",), ("sum",), ("none",)]) + def test_reduction_argument(self, reduction): + pred = torch.tensor([[1.0], [2.0]]) + target = torch.tensor([[1.0], [0.0]]) + + loss = AUCMLoss(reduction=reduction)(pred, target) + + self.assertEqual(loss.ndim, 0) + self.assertTrue(torch.isfinite(loss)) + + def test_script_save(self): + loss_fn = AUCMLoss() + + test_script_save(loss_fn, torch.randn(32, 1), torch.randint(0, 2, (32, 1)).float()) + + +if __name__ == "__main__": + unittest.main()