Skip to content

Commit 92044d5

Browse files
committed
Fix mypy issues and add test
1 parent 355b4d9 commit 92044d5

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

monai/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@
9999
str2bool,
100100
str2list,
101101
to_tuple_of_dictionaries,
102+
unsqueeze_left,
103+
unsqueeze_right,
102104
zip_with,
103105
)
104106
from .module import (

monai/utils/component_store.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def _my_func(a, b):
4949
result = func(7, 6)
5050
5151
"""
52+
5253
_Component = namedtuple("_Component", ("description", "value")) # internal value pair
5354

5455
def __init__(self, name: str, description: str) -> None:

tests/test_squeeze_unsqueeze.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import numpy as np
17+
import torch
18+
from parameterized import parameterized
19+
20+
from monai.utils import unsqueeze_left, unsqueeze_right
21+
22+
RIGHT_CASES = [(np.random.rand(3, 4), 5, (3, 4, 1, 1, 1)), (torch.rand(3, 4), 5, (3, 4, 1, 1, 1))]
23+
24+
LEFT_CASES = [(np.random.rand(3, 4), 5, (1, 1, 1, 3, 4)), (torch.rand(3, 4), 5, (1, 1, 1, 3, 4))]
25+
26+
ALL_CASES = [
27+
(np.random.rand(3, 4), 2, (3, 4)),
28+
(np.random.rand(3, 4), 0, (3, 4)),
29+
(np.random.rand(3, 4), -1, (3, 4)),
30+
(np.array(3), 4, (1, 1, 1, 1)),
31+
(np.array(3), 0, ()),
32+
(torch.rand(3, 4), 2, (3, 4)),
33+
(torch.rand(3, 4), 0, (3, 4)),
34+
(torch.rand(3, 4), -1, (3, 4)),
35+
(torch.tensor(3), 4, (1, 1, 1, 1)),
36+
(torch.tensor(3), 0, ()),
37+
]
38+
39+
40+
class TestUnsqueeze(unittest.TestCase):
41+
@parameterized.expand(RIGHT_CASES + ALL_CASES)
42+
def test_unsqueeze_right(self, arr, ndim, shape):
43+
self.assertEqual(unsqueeze_right(arr, ndim).shape, shape)
44+
45+
@parameterized.expand(LEFT_CASES + ALL_CASES)
46+
def test_unsqueeze_left(self, arr, ndim, shape):
47+
self.assertEqual(unsqueeze_left(arr, ndim).shape, shape)

0 commit comments

Comments
 (0)