Skip to content

Commit 6f9cf6b

Browse files
authored
6765 update GeneralizedDiceLoss (#6775)
Fixes #6765 ### Description as discussed in #6765, when `batch=True` the loss should still return 1 aggregated value instead of C channels. #5466 is not actually achievable with this formulation. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li <[email protected]>
1 parent 3b56e7f commit 6f9cf6b

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

monai/losses/dice.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def __init__(
268268
smooth_dr: a small constant added to the denominator to avoid nan.
269269
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
270270
Defaults to False, intersection over union is computed from each item in the batch.
271+
If True, the class-weighted intersection and union areas are first summed across the batches.
271272
272273
Raises:
273274
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -360,8 +361,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
360361
max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1)
361362
w = w + infs * max_values
362363

363-
numer = 2.0 * (intersection * w) + self.smooth_nr
364-
denom = (denominator * w) + self.smooth_dr
364+
final_reduce_dim = 0 if self.batch else 1
365+
numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
366+
denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
365367
f: torch.Tensor = 1.0 - (numer / denom)
366368

367369
if self.reduction == LossReduction.MEAN.value:

tests/test_generalized_dice_loss.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@
4848
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
4949
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
5050
},
51-
0.435035,
51+
0.469964,
5252
],
5353
[ # shape: (2, 2, 3), (2, 1, 3)
5454
{"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
5555
{
5656
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
5757
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
5858
},
59-
0.3837,
59+
0.414507,
6060
],
6161
[ # shape: (2, 2, 3), (2, 1, 3)
6262
{
@@ -71,7 +71,7 @@
7171
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
7272
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
7373
},
74-
1.5348,
74+
0.829015,
7575
],
7676
[ # shape: (2, 2, 3), (2, 1, 3)
7777
{
@@ -86,7 +86,7 @@
8686
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
8787
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
8888
},
89-
[[[0.210949], [0.295351]], [[0.599976], [0.428522]]],
89+
[[[0.273476]], [[0.555539]]],
9090
],
9191
[ # shape: (2, 2, 3), (2, 1, 3)
9292
{"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8},
@@ -114,7 +114,7 @@
114114
"input": torch.tensor([[[0.0, 10.0, 10.0, 10.0], [10.0, 0.0, 0.0, 0.0]]]),
115115
"target": torch.tensor([[[1, 1, 0, 0]]]),
116116
},
117-
0.26669,
117+
0.250023,
118118
],
119119
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
120120
{"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
@@ -136,7 +136,7 @@
136136
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
137137
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
138138
},
139-
-8.55485,
139+
-0.097833,
140140
],
141141
]
142142

0 commit comments

Comments
 (0)