Skip to content
This repository was archived by the owner on Mar 19, 2024. It is now read-only.

Commit 9ec6d8b

Browse files
iseessel authored and facebook-github-bot committed
Refactor instance retrieval optional normalization (#381)
Summary: Pull Request resolved: #381 1. Rename SHOULD_TRAIN_PCA_OR_WHITENING to TRAIN_PCA_WHITENING 2. Make l2 normalization optional. 3. Fix cfg access bugs 4. Add some more experiments. Reviewed By: prigoyal Differential Revision: D30002757 fbshipit-source-id: a0aaf8ac17fc9044ec427ee14a4b39ee4ca92a7b
1 parent 360218e commit 9ec6d8b

File tree

3 files changed

+73
-40
lines changed

3 files changed

+73
-40
lines changed

tools/instance_retrieval_test.py

+56-34
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,11 @@
66
import logging
77
import os
88
import sys
9-
import uuid
109
from argparse import Namespace
1110
from typing import Any, List
1211

1312
import numpy as np
1413
import torch
15-
import torch.nn.functional as F
1614
import torchvision
1715
from classy_vision.generic.util import copy_model_to_gpu, load_checkpoint
1816
from fvcore.common.file_io import PathManager
@@ -92,7 +90,7 @@ def get_train_features(
9290
):
9391
train_features = []
9492

95-
def process_train_image(i, out_dir):
93+
def process_train_image(i, out_dir, verbose=False):
9694
if i % LOG_FREQUENCY == 0:
9795
logging.info(f"Train Image: {i}"),
9896

@@ -115,24 +113,35 @@ def process_train_image(i, out_dir):
115113
vc = v.cuda()
116114
# the model output is a list always.
117115
activation_map = model(vc)[0].cpu()
116+
117+
if verbose:
118+
print(f"Train Image raw activation map shape: { activation_map.shape }")
119+
118120
# once we have the features,
119121
# we can perform: rmac | gem pooling | l2 norm
120122
if cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "rmac":
121-
descriptors = get_rmac_descriptors(activation_map, spatial_levels)
122-
elif cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "l2_norm":
123-
# we simply L2 normalize the features otherwise
124-
descriptors = F.normalize(activation_map, p=2, dim=0)
123+
descriptors = get_rmac_descriptors(
124+
activation_map,
125+
spatial_levels,
126+
normalize=cfg.IMG_RETRIEVAL.NORMALIZE_FEATURES,
127+
)
125128
elif cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "gem":
126-
descriptors = l2n(
127-
gem(
128-
activation_map,
129-
p=cfg.IMG_RETRIEVAL.GEM_POOL_POWER,
130-
add_bias=False,
131-
)
129+
descriptors = gem(
130+
activation_map,
131+
p=cfg.IMG_RETRIEVAL.GEM_POOL_POWER,
132+
add_bias=True,
132133
)
133134
else:
134135
descriptors = activation_map
135136

137+
# Optionally l2 normalize the features.
138+
if (
139+
cfg.IMG_RETRIEVAL.NORMALIZE_FEATURES
140+
and cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE != "rmac"
141+
):
142+
# RMAC performs normalization within the algorithm, hence we skip it here.
143+
descriptors = l2n(descriptors, dim=1)
144+
136145
if fname_out:
137146
save_file(descriptors.data.numpy(), fname_out, verbose=False)
138147
train_features.append(descriptors.data.numpy())
@@ -146,7 +155,7 @@ def process_train_image(i, out_dir):
146155

147156
logging.info(f"Getting features for train images: {num_images}")
148157
for i in range(num_images):
149-
process_train_image(i, out_dir)
158+
process_train_image(i, out_dir, verbose=(i == 0))
150159

151160
train_features = np.vstack([x.reshape(-1, x.shape[-1]) for x in train_features])
152161
logging.info(f"Train features size: {train_features.shape}")
@@ -163,6 +172,7 @@ def process_eval_image(
163172
model,
164173
pca,
165174
eval_dataset_name,
175+
verbose=False,
166176
):
167177
if is_revisited_dataset(eval_dataset_name):
168178
img = image_helper.load_and_prepare_revisited_image(fname_in, roi=roi)
@@ -176,30 +186,39 @@ def process_eval_image(
176186
# the model output is a list always.
177187
activation_map = model(vc)[0].cpu()
178188

189+
if verbose:
190+
print(f"Eval image raw activation map shape: { activation_map.shape }")
191+
179192
# process the features: rmac | l2 norm
180193
if cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "rmac":
181-
descriptors = get_rmac_descriptors(activation_map, spatial_levels, pca=pca)
182-
elif cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "l2_norm":
183-
# we simply L2 normalize the features otherwise
184-
descriptors = F.normalize(activation_map, p=2, dim=0)
185-
# Optionally apply pca.
186-
if pca:
187-
descriptors = pca.apply(descriptors)
188-
194+
descriptors = get_rmac_descriptors(
195+
activation_map,
196+
spatial_levels,
197+
pca=pca,
198+
normalize=cfg.IMG_RETRIEVAL.NORMALIZE_FEATURES,
199+
)
189200
elif cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "gem":
190-
descriptors = l2n(
191-
gem(
192-
activation_map,
193-
p=cfg.IMG_RETRIEVAL.GEM_POOL_POWER,
194-
add_bias=True,
195-
)
201+
descriptors = gem(
202+
activation_map,
203+
p=cfg.IMG_RETRIEVAL.GEM_POOL_POWER,
204+
add_bias=True,
196205
)
197-
# Optionally apply pca.
198-
if pca:
199-
descriptors = pca.apply(descriptors)
200206
else:
201207
descriptors = activation_map
202208

209+
# Optionally l2 normalize the features.
210+
if (
211+
cfg.IMG_RETRIEVAL.NORMALIZE_FEATURES
212+
and cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE != "rmac"
213+
):
214+
# RMAC performs normalization within the algorithm, hence we skip it here.
215+
descriptors = l2n(descriptors, dim=1)
216+
217+
# Optionally apply pca.
218+
if pca and cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE != "rmac":
219+
# RMAC performs pca within the algorithm, hence we skip it here.
220+
descriptors = pca.apply(descriptors)
221+
203222
if fname_out:
204223
save_file(descriptors.data.numpy(), fname_out, verbose=False)
205224
return descriptors.data.numpy()
@@ -248,6 +267,7 @@ def get_dataset_features(
248267
model,
249268
pca,
250269
eval_dataset_name,
270+
verbose=(idx == 0),
251271
)
252272
features_dataset.append(db_feature)
253273

@@ -286,6 +306,7 @@ def get_queries_features(
286306
if idx % LOG_FREQUENCY == 0:
287307
logging.info(f"Eval Query: {idx}"),
288308
q_fname_in = eval_dataset.get_query_filename(idx)
309+
# Optionally crop the query by the region-of-interest (ROI).
289310
roi = (
290311
eval_dataset.get_query_roi(idx)
291312
if cfg.IMG_RETRIEVAL.CROP_QUERY_ROI
@@ -309,6 +330,7 @@ def get_queries_features(
309330
model,
310331
pca,
311332
eval_dataset_name,
333+
verbose=(idx == 0),
312334
)
313335
features_queries.append(query_feature)
314336

@@ -345,7 +367,7 @@ def get_transforms(cfg, dataset_name):
345367
def get_train_dataset(cfg, root_dataset_path, train_dataset_name, eval_binary_path):
346368
# We only create the train dataset if we need PCA or whitening training.
347369
# Otherwise not.
348-
if cfg.IMG_RETRIEVAL.SHOULD_TRAIN_PCA_OR_WHITENING:
370+
if cfg.IMG_RETRIEVAL.TRAIN_PCA_WHITENING:
349371
train_data_path = f"{root_dataset_path}/{train_dataset_name}"
350372
assert PathManager.exists(train_data_path), f"Unknown path: {train_data_path}"
351373

@@ -444,7 +466,7 @@ def instance_retrieval_test(args, cfg):
444466
############################################################################
445467
# Step 2: Extract the features for the train dataset, calculate PCA or
446468
# whitening and save
447-
if cfg.IMG_RETRIEVAL.SHOULD_TRAIN_PCA_OR_WHITENING:
469+
if cfg.IMG_RETRIEVAL.TRAIN_PCA_WHITENING:
448470
logging.info("Extracting training features...")
449471
# the features are already processed based on type: rmac | GeM | l2 norm
450472
train_features = get_train_features(
@@ -551,7 +573,7 @@ def validate_and_infer_config(config: AttrDict):
551573
), "Spatial levels must be greater than 0."
552574
if config.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "rmac":
553575
assert (
554-
config.IMG_RETRIEVAL.SHOULD_TRAIN_PCA_OR_WHITENING
576+
config.IMG_RETRIEVAL.TRAIN_PCA_WHITENING
555577
), "PCA Whitening is built-in to the RMAC algorithm and is required"
556578

557579
return config

vissl/config/defaults.yaml

+6-2
Original file line numberDiff line numberDiff line change
@@ -1257,10 +1257,10 @@ config:
12571257
# Whether or not to save the features that were extracted
12581258
SAVE_FEATURES: False
12591259
# Whether to apply PCA/whitening or not
1260-
SHOULD_TRAIN_PCA_OR_WHITENING: True
1260+
TRAIN_PCA_WHITENING: True
12611261
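# gem | rmac | l2_norm (any value other than gem/rmac leaves the features unpooled; L2 normalization itself is controlled by NORMALIZE_FEATURES)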
12621262
FEATS_PROCESSING_TYPE: ""
1263-
# valid only for GeM pooling of features
1263+
# valid only for GeM pooling of features. Note that GEM_POOL_POWER=1 equates to average pooling.
12641264
GEM_POOL_POWER: 4.0
12651265
# valid only if we are training whitening on the whitening dataset
12661266
WHITEN_IMG_LIST: ""
@@ -1276,6 +1276,10 @@ config:
12761276
# Relevant for Oxford, Paris, ROxford, and RParis datasets.
12771277
# Our experiments with RN-50/rmac show that ROI cropping degrades performance.
12781278
CROP_QUERY_ROI: False
1279+
# Whether or not to apply L2 norm after the features have been post-processed.
1280+
# Normalization is strongly recommended based on our experiments.
1281+
NORMALIZE_FEATURES: True
1282+
12791283

12801284
# ----------------------------------------------------------------------------------- #
12811285
# K-NEAREST NEIGHBOR (benchmark)

vissl/utils/instance_retrieval_utils/rmac.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def get_rmac_region_coordinates(H, W, L):
8181

8282
# Credits: https://github.com/facebookresearch/deepcluster/blob/master/eval_retrieval.py # NOQA
8383
# Adapted by: Priya Goyal ([email protected])
84-
def get_rmac_descriptors(features, rmac_levels, pca=None):
84+
def get_rmac_descriptors(features, rmac_levels, pca=None, normalize=True):
8585
"""
8686
RMAC descriptors. Coordinates are retrieved following Tolias et al.
8787
L2 normalize the descriptors and optionally apply PCA on the descriptors
@@ -104,18 +104,25 @@ def get_rmac_descriptors(features, rmac_levels, pca=None):
104104

105105
rmac_descriptors = torch.cat(rmac_descriptors, 1)
106106

107-
rmac_descriptors = normalize_L2(rmac_descriptors, 2)
107+
if normalize:
108+
# Can optionally skip normalization -- not recommended.
109+
# the original RMAC paper normalizes.
110+
rmac_descriptors = normalize_L2(rmac_descriptors, 2)
108111

109112
if pca is None:
110113
return rmac_descriptors
111114

112115
# PCA + whitening
113116
npca = pca.n_components
114117
rmac_descriptors = pca.apply(rmac_descriptors.view(nr * nim, nc))
115-
rmac_descriptors = normalize_L2(rmac_descriptors, 1)
118+
119+
if normalize:
120+
rmac_descriptors = normalize_L2(rmac_descriptors, 1)
121+
116122
rmac_descriptors = rmac_descriptors.view(nim, nr, npca)
117123

118124
# Sum aggregation and L2-normalization
119125
rmac_descriptors = torch.sum(rmac_descriptors, 1)
120-
rmac_descriptors = normalize_L2(rmac_descriptors, 1)
126+
if normalize:
127+
rmac_descriptors = normalize_L2(rmac_descriptors, 1)
121128
return rmac_descriptors

0 commit comments

Comments
 (0)