
Replace asserts with exceptions #5587

Merged
merged 60 commits into main from replace_asserts_with_exceptions on Mar 15, 2022
Changes from all commits (60 commits)
dfb2862
replace most asserts with exceptions
jdsgomes Mar 10, 2022
40d0528
fix formating issues
jdsgomes Mar 10, 2022
13bfd80
fix linting and remove more asserts
jdsgomes Mar 11, 2022
45ecd61
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 11, 2022
f522368
fix regresion
jdsgomes Mar 11, 2022
23bd022
Merge branch 'replace_asserts_with_exceptions' of github.com:jdsgomes…
jdsgomes Mar 11, 2022
6a87e4d
fix regresion
jdsgomes Mar 11, 2022
e179358
fix bug
jdsgomes Mar 11, 2022
30b1714
apply ufmt
jdsgomes Mar 11, 2022
488d2af
apply ufmt
jdsgomes Mar 11, 2022
38d2d01
fix tests
jdsgomes Mar 11, 2022
7d42574
fix format
jdsgomes Mar 11, 2022
dc6856b
fix None check
jdsgomes Mar 11, 2022
2c56adc
fix detection models tests
jdsgomes Mar 11, 2022
aebca6d
non scriptable any
jdsgomes Mar 11, 2022
d54b582
add more checks for None values
jdsgomes Mar 13, 2022
36d2174
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 13, 2022
98c2702
fix retinanet test
jdsgomes Mar 13, 2022
bdab5f4
Merge branch 'replace_asserts_with_exceptions' of github.com:jdsgomes…
jdsgomes Mar 13, 2022
4900653
fix retinanet test
jdsgomes Mar 13, 2022
d5ccbf1
Update references/classification/transforms.py
jdsgomes Mar 14, 2022
de2f4b7
Update references/classification/transforms.py
jdsgomes Mar 14, 2022
275012a
Update references/optical_flow/transforms.py
jdsgomes Mar 14, 2022
fddd2ac
Update references/optical_flow/transforms.py
jdsgomes Mar 14, 2022
7e60b46
Update references/optical_flow/transforms.py
jdsgomes Mar 14, 2022
6c2e94f
make value checks more pythonic:
jdsgomes Mar 14, 2022
0a78c6b
fix merge
jdsgomes Mar 14, 2022
cb95c97
Update references/optical_flow/transforms.py
jdsgomes Mar 14, 2022
ff8f557
make value checks more pythonic
jdsgomes Mar 14, 2022
0598990
Merge branch 'replace_asserts_with_exceptions' of github.com:jdsgomes…
jdsgomes Mar 14, 2022
abafdb2
make more checks pythonic
jdsgomes Mar 14, 2022
5b30ce3
fix bug
jdsgomes Mar 14, 2022
ade3364
appy ufmt
jdsgomes Mar 14, 2022
2f4ecc1
fix tracing issues
jdsgomes Mar 14, 2022
981617b
fib typos
jdsgomes Mar 14, 2022
fec7d4b
fix lint
jdsgomes Mar 14, 2022
bdd913b
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 14, 2022
ca59cd7
remove unecessary f-strings
jdsgomes Mar 14, 2022
3391f00
Merge branch 'replace_asserts_with_exceptions' of github.com:jdsgomes…
jdsgomes Mar 14, 2022
81ac57c
fix bug
jdsgomes Mar 14, 2022
7affc95
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 14, 2022
8dc76e2
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 15, 2022
e68c1be
Update torchvision/datasets/mnist.py
jdsgomes Mar 15, 2022
d92a0f9
Update torchvision/datasets/mnist.py
jdsgomes Mar 15, 2022
4a30fc9
Update torchvision/ops/boxes.py
jdsgomes Mar 15, 2022
e4f214d
Update torchvision/ops/poolers.py
jdsgomes Mar 15, 2022
9e9ca6d
Update torchvision/utils.py
jdsgomes Mar 15, 2022
b234e08
address PR comments
jdsgomes Mar 15, 2022
1a45e1e
Update torchvision/io/_video_opt.py
jdsgomes Mar 15, 2022
8437088
Update torchvision/models/detection/generalized_rcnn.py
jdsgomes Mar 15, 2022
cff417a
Update torchvision/models/feature_extraction.py
jdsgomes Mar 15, 2022
1d9e3d3
Update torchvision/models/optical_flow/raft.py
jdsgomes Mar 15, 2022
ce06c29
address PR comments
jdsgomes Mar 15, 2022
2b1870f
addressing further pr comments
jdsgomes Mar 15, 2022
851adb2
fix bug
jdsgomes Mar 15, 2022
a915f1f
remove unecessary else
jdsgomes Mar 15, 2022
f41e115
apply ufmt
jdsgomes Mar 15, 2022
ee21d2e
last pr comment
jdsgomes Mar 15, 2022
d000238
replace RuntimeErrors
jdsgomes Mar 15, 2022
d77739b
Merge branch 'main' into replace_asserts_with_exceptions
jdsgomes Mar 15, 2022
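The change applied throughout the hunks below follows one pattern: an `assert` (which is stripped when Python runs with `-O` and can only raise `AssertionError`) becomes an explicit `if ... raise` check, using `TypeError` for wrong argument types or dtypes and `ValueError` for invalid values or shapes. A minimal sketch of the conversion, using a made-up function rather than code taken from the diff:

# Before: the check disappears under `python -O`, and callers only ever see AssertionError.
def set_ratio_with_assert(ratio):
    assert isinstance(ratio, float), "ratio must be a float"
    assert 0 < ratio <= 1, "ratio must be in (0, 1]"

# After: always enforced, and the exception type tells the caller what kind of mistake was made.
def set_ratio(ratio):
    if not isinstance(ratio, float):
        raise TypeError(f"ratio should be of type float, instead got {type(ratio)}")
    if not 0 < ratio <= 1:
        raise ValueError(f"ratio should be in (0, 1], instead got {ratio}")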
16 changes: 12 additions & 4 deletions references/classification/transforms.py
@@ -21,8 +21,14 @@ class RandomMixup(torch.nn.Module):

def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert alpha > 0, "Alpha param can't be zero."

if num_classes < 1:
raise ValueError(
f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
)

if alpha <= 0:
raise ValueError("Alpha param can't be zero.")

self.num_classes = num_classes
self.p = p
@@ -99,8 +105,10 @@ class RandomCutmix(torch.nn.Module):

def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert alpha > 0, "Alpha param can't be zero."
if num_classes < 1:
raise ValueError("Please provide a valid positive value for the num_classes.")
if alpha <= 0:
raise ValueError("Alpha param can't be zero.")

self.num_classes = num_classes
self.p = p
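A quick sketch of how the reworked checks surface to callers; the import path and test name are invented for illustration, only `RandomMixup` and its arguments come from the hunk above:

import pytest
from transforms import RandomMixup  # hypothetical import of references/classification/transforms.py

def test_random_mixup_argument_validation():
    # Both invalid arguments now raise ValueError instead of AssertionError,
    # and the checks survive `python -O`.
    with pytest.raises(ValueError):
        RandomMixup(num_classes=0)
    with pytest.raises(ValueError):
        RandomMixup(num_classes=10, alpha=0.0)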
3 changes: 2 additions & 1 deletion references/detection/coco_eval.py
@@ -12,7 +12,8 @@

class CocoEvaluator:
def __init__(self, coco_gt, iou_types):
assert isinstance(iou_types, (list, tuple))
if not isinstance(iou_types, (list, tuple)):
raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
coco_gt = copy.deepcopy(coco_gt)
self.coco_gt = coco_gt

5 changes: 4 additions & 1 deletion references/detection/coco_utils.py
@@ -126,7 +126,10 @@ def _has_valid_annotation(anno):
return True
return False

assert isinstance(dataset, torchvision.datasets.CocoDetection)
if not isinstance(dataset, torchvision.datasets.CocoDetection):
raise TypeError(
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
22 changes: 14 additions & 8 deletions references/optical_flow/transforms.py
@@ -7,16 +7,21 @@ class ValidateModelInput(torch.nn.Module):
# Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
def forward(self, img1, img2, flow, valid_flow_mask):

assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None)
assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None)
if not all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None):
raise TypeError("This method expects all input arguments to be of type torch.Tensor.")
if not all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None):
raise TypeError("This method expects the tensors img1, img2 and flow of be of dtype torch.float32.")

assert img1.shape == img2.shape
if img1.shape != img2.shape:
raise ValueError("img1 and img2 should have the same shape.")
h, w = img1.shape[-2:]
if flow is not None:
assert flow.shape == (2, h, w)
if flow is not None and flow.shape != (2, h, w):
raise ValueError(f"flow.shape should be (2, {h}, {w}) instead of {flow.shape}")
if valid_flow_mask is not None:
assert valid_flow_mask.shape == (h, w)
assert valid_flow_mask.dtype == torch.bool
if valid_flow_mask.shape != (h, w):
raise ValueError(f"valid_flow_mask.shape should be ({h}, {w}) instead of {valid_flow_mask.shape}")
if valid_flow_mask.dtype != torch.bool:
raise TypeError("valid_flow_mask should be of dtype torch.bool instead of {valid_flow_mask.dtype}")

return img1, img2, flow, valid_flow_mask

@@ -109,7 +114,8 @@ class RandomErasing(T.RandomErasing):
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
self.max_erase = max_erase
assert self.max_erase > 0
if self.max_erase <= 0:
raise ValueError("max_raise should be greater than 0")

def forward(self, img1, img2, flow, valid_flow_mask):
if torch.rand(1) > self.p:
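Illustrative sketch of the rewritten `ValidateModelInput`: a shape mismatch between the two images now surfaces as a `ValueError` with a descriptive message (the import path is assumed, not part of the diff):

import pytest
import torch
from transforms import ValidateModelInput  # hypothetical import of references/optical_flow/transforms.py

validate = ValidateModelInput()
img1 = torch.rand(3, 64, 64)
img2 = torch.rand(3, 32, 32)  # deliberately mismatched spatial size

with pytest.raises(ValueError, match="same shape"):
    validate(img1, img2, None, None)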
5 changes: 4 additions & 1 deletion references/optical_flow/utils.py
@@ -71,7 +71,10 @@ def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)

def __getattr__(self, attr):
6 changes: 5 additions & 1 deletion references/segmentation/coco_utils.py
@@ -68,7 +68,11 @@ def _has_valid_annotation(anno):
# if more than 1k pixels occupied in the image
return sum(obj["area"] for obj in anno) > 1000

assert isinstance(dataset, torchvision.datasets.CocoDetection)
if not isinstance(dataset, torchvision.datasets.CocoDetection):
raise TypeError(
f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
)

ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
5 changes: 4 additions & 1 deletion references/segmentation/utils.py
@@ -118,7 +118,10 @@ def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)

def __getattr__(self, attr):
3 changes: 2 additions & 1 deletion references/similarity/sampler.py
@@ -47,7 +47,8 @@ def __init__(self, groups, p, k):
self.groups = create_groups(groups, self.k)

# Ensures there are enough classes to sample from
assert len(self.groups) >= p
if len(self.groups) < p:
raise ValueError("There are not enought classes to sample from")

def __iter__(self):
# Shuffle samples within groups
5 changes: 4 additions & 1 deletion references/video_classification/utils.py
@@ -76,7 +76,10 @@ def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)

def __getattr__(self, attr):
6 changes: 3 additions & 3 deletions test/test_backbone_utils.py
@@ -144,16 +144,16 @@ def test_build_fx_feature_extractor(self, model_name):
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
)
# Check must specify return nodes
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
self._create_feature_extractor(model)
# Check return_nodes and train_return_nodes / eval_return nodes
# mutual exclusivity
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
self._create_feature_extractor(
model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
)
# Check train_return_nodes / eval_return nodes must both be specified
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
# Check invalid node name raises ValueError
with pytest.raises(ValueError):
2 changes: 1 addition & 1 deletion test/test_models.py
@@ -767,7 +767,7 @@ def test_detection_model_validation(model_fn):

# validate type
targets = [{"boxes": 0.0}]
with pytest.raises(ValueError):
with pytest.raises(TypeError):
model(x, targets=targets)

# validate boxes shape
4 changes: 2 additions & 2 deletions test/test_ops.py
@@ -138,13 +138,13 @@ def test_autocast(self, x_dtype, rois_dtype):

def _helper_boxes_shape(self, func):
# test boxes as Tensor[N, 5]
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
func(a, boxes, output_size=(2, 2))

# test boxes as List[Tensor[N, 4]]
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
ops.roi_pool(a, [boxes], output_size=(2, 2))
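A sketch of the behavior the updated ops tests expect: boxes with the wrong trailing dimension are now rejected with `ValueError` rather than an `AssertionError` (the exact error message is whatever torchvision produces, not quoted from the diff):

import torch
from torchvision import ops

a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
bad_boxes = torch.tensor([[0.0, 0.0, 3.0]])  # List[Tensor[N, 4]] is expected, this is [N, 3]

try:
    ops.roi_pool(a, [bad_boxes], output_size=(2, 2))
except ValueError as exc:
    print(f"rejected as expected: {exc}")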
3 changes: 2 additions & 1 deletion torchvision/datasets/kinetics.py
@@ -118,7 +118,8 @@ def __init__(
print("Using legacy structure")
self.split_folder = root
self.split = "unknown"
assert not download, "Cannot download the videos using legacy_structure."
if download:
raise ValueError("Cannot download the videos using legacy_structure.")
else:
self.split_folder = path.join(root, split)
self.split = verify_str_arg(split, arg="split", valid_values=["train", "val"])
21 changes: 14 additions & 7 deletions torchvision/datasets/mnist.py
@@ -442,11 +442,14 @@ def _check_exists(self) -> bool:

def _load_data(self):
data = read_sn3_pascalvincent_tensor(self.images_file)
assert data.dtype == torch.uint8
assert data.ndimension() == 3
if data.dtype != torch.uint8:
raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
if data.ndimension() != 3:
raise ValueError("data should have 3 dimensions instead of {data.ndimension()}")

targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
assert targets.ndimension() == 2
if targets.ndimension() != 2:
raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")

if self.what == "test10k":
data = data[0:10000, :, :].clone()
@@ -530,13 +533,17 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso

def read_label_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert x.dtype == torch.uint8
assert x.ndimension() == 1
if x.dtype != torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 1:
raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
return x.long()


def read_image_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
assert x.dtype == torch.uint8
assert x.ndimension() == 3
if x.dtype != torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 3:
raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
return x
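The MNIST readers show the convention used across the PR: dtype problems raise `TypeError`, shape problems raise `ValueError`, so callers can react to the two failure modes separately. A small sketch (the file path is a placeholder, and a missing file would of course raise its own I/O error first):

from torchvision.datasets.mnist import read_image_file

try:
    images = read_image_file("train-images-idx3-ubyte")  # placeholder path
except TypeError as exc:
    print(f"unexpected dtype in the file: {exc}")
except ValueError as exc:
    print(f"unexpected number of dimensions: {exc}")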
10 changes: 4 additions & 6 deletions torchvision/datasets/samplers/clip_sampler.py
@@ -52,12 +52,10 @@ def __init__(
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
assert (
len(dataset) % group_size == 0
), "dataset length must be a multiplier of group size dataset length: %d, group size: %d" % (
len(dataset),
group_size,
)
if len(dataset) % group_size != 0:
raise ValueError(
f"dataset length must be a multiplier of group size dataset length: {len(dataset)}, group size: {group_size}"
)
self.dataset = dataset
self.group_size = group_size
self.num_replicas = num_replicas
1 change: 0 additions & 1 deletion torchvision/datasets/sbd.py
@@ -92,7 +92,6 @@ def __init__(

self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
assert len(self.images) == len(self.masks)

self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target

3 changes: 2 additions & 1 deletion torchvision/datasets/video_utils.py
@@ -38,7 +38,8 @@ def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> tor
`step` between windows. The distance between each element
in a window is given by `dilation`.
"""
assert tensor.dim() == 1
if tensor.dim() != 1:
raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}")
o_stride = tensor.stride(0)
numel = tensor.numel()
new_stride = (step * o_stride, dilation * o_stride)
10 changes: 3 additions & 7 deletions torchvision/io/_video_opt.py
@@ -67,13 +67,9 @@ def __init__(self) -> None:

def _validate_pts(pts_range: Tuple[int, int]) -> None:

if pts_range[1] > 0:
assert (
pts_range[0] <= pts_range[1]
), """Start pts should not be smaller than end pts, got
start pts: {:d} and end pts: {:d}""".format(
pts_range[0],
pts_range[1],
if pts_range[0] > pts_range[1] > 0:
raise ValueError(
f"Start pts should not be smaller than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
)


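The `_validate_pts` rewrite folds the old two-step check into a single chained comparison: `start > end > 0` evaluates as `start > end and end > 0`, which is exactly the situation the old assert guarded against. A throwaway equivalence check over a small range of integer pts values:

def old_invalid(start, end):
    # old logic: only checked when end > 0, and flagged start > end
    return end > 0 and start > end

def new_invalid(start, end):
    # new logic: Python chained comparison
    return start > end > 0

print(all(old_invalid(s, e) == new_invalid(s, e) for s in range(-5, 6) for e in range(-5, 6)))  # True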
12 changes: 8 additions & 4 deletions torchvision/models/detection/_utils.py
@@ -159,8 +159,10 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
return targets

def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
assert isinstance(boxes, (list, tuple))
assert isinstance(rel_codes, torch.Tensor)
if not isinstance(boxes, (list, tuple)):
raise TypeError(f"This function expects boxes of type list or tuple, instead got {type(boxes)}")
if not isinstance(rel_codes, torch.Tensor):
raise TypeError(f"This function expects rel_codes of type torch.Tensor, instead got {type(rel_codes)}")
boxes_per_image = [b.size(0) for b in boxes]
concat_boxes = torch.cat(boxes, dim=0)
box_sum = 0
@@ -333,7 +335,8 @@ def __init__(self, high_threshold: float, low_threshold: float, allow_low_qualit
"""
self.BELOW_LOW_THRESHOLD = -1
self.BETWEEN_THRESHOLDS = -2
assert low_threshold <= high_threshold
if low_threshold > high_threshold:
raise ValueError("low_threshold should be <= high_threshold")
self.high_threshold = high_threshold
self.low_threshold = low_threshold
self.allow_low_quality_matches = allow_low_quality_matches
@@ -371,7 +374,8 @@ def __call__(self, match_quality_matrix: Tensor) -> Tensor:
matches[between_thresholds] = self.BETWEEN_THRESHOLDS

if self.allow_low_quality_matches:
assert all_matches is not None
if all_matches is None:
raise ValueError("all_matches should not be None")
self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

return matches
10 changes: 5 additions & 5 deletions torchvision/models/detection/anchor_utils.py
@@ -45,8 +45,6 @@ def __init__(
if not isinstance(aspect_ratios[0], (list, tuple)):
aspect_ratios = (aspect_ratios,) * len(sizes)

assert len(sizes) == len(aspect_ratios)

self.sizes = sizes
self.aspect_ratios = aspect_ratios
self.cell_anchors = [
@@ -86,7 +84,9 @@ def num_anchors_per_location(self):
def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None

if cell_anchors is None:
ValueError("cell_anchors should not be None")

if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
raise ValueError(
@@ -164,8 +164,8 @@ def __init__(
clip: bool = True,
):
super().__init__()
if steps is not None:
assert len(aspect_ratios) == len(steps)
if steps is not None and len(aspect_ratios) != len(steps):
raise ValueError("aspect_ratios and steps should have the same length")
self.aspect_ratios = aspect_ratios
self.steps = steps
self.clip = clip
15 changes: 12 additions & 3 deletions torchvision/models/detection/faster_rcnn.py
@@ -187,8 +187,14 @@ def __init__(
"same for all the levels)"
)

assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None)))
assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None)))
if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
raise TypeError(
f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
)
if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
raise TypeError(
f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
)

if num_classes is not None:
if box_predictor is not None:
@@ -299,7 +305,10 @@ def __init__(self, in_channels, num_classes):

def forward(self, x):
if x.dim() == 4:
assert list(x.shape[2:]) == [1, 1]
if list(x.shape[2:]) != [1, 1]:
raise ValueError(
f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}"
)
x = x.flatten(start_dim=1)
scores = self.cls_score(x)
bbox_deltas = self.bbox_pred(x)