Skip to content
This repository was archived by the owner on Mar 19, 2024. It is now read-only.

Commit d77d6c8

Browse files
iseessel and facebook-github-bot
authored and committed
Instance Retrieval Improvements (#380)
Summary: Pull Request resolved: #380 Various Instance Retrieval improvements: 1. Add support for Manifold 2. Cleanup noisy logs and add helpful logging. 3. Add DEBUG_MODE support for the Revisited Datasets. 4. Add ability to save results/logs/features. 5. Fix ROI crop bug. 6. Fix typo in benchmark_workflow.py causing benchmarks to fail. 7. Add a bunch of json configs to track and group multiple experiments. Reviewed By: prigoyal Differential Revision: D29995282 fbshipit-source-id: 2382963f39c6c61aa417b690a39754d4b30b3fe2
1 parent 79161e9 commit d77d6c8

File tree

6 files changed

+214
-61
lines changed

6 files changed

+214
-61
lines changed

tools/instance_retrieval_test.py

Lines changed: 145 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7+
import os
78
import sys
89
import uuid
910
from argparse import Namespace
@@ -13,12 +14,15 @@
1314
import torch
1415
import torch.nn.functional as F
1516
import torchvision
16-
from classy_vision.generic.util import copy_model_to_gpu
17+
from classy_vision.generic.util import copy_model_to_gpu, load_checkpoint
1718
from fvcore.common.file_io import PathManager
1819
from hydra.experimental import compose, initialize_config_module
1920
from vissl.config import AttrDict
2021
from vissl.models import build_model
21-
from vissl.utils.checkpoint import init_model_from_consolidated_weights
22+
from vissl.utils.checkpoint import (
23+
init_model_from_consolidated_weights,
24+
get_checkpoint_folder,
25+
)
2226
from vissl.utils.env import set_env_vars
2327
from vissl.utils.hydra_config import convert_to_attrdict, is_hydra_available, print_cfg
2428
from vissl.utils.instance_retrieval_utils.data_util import (
@@ -36,7 +40,7 @@
3640
)
3741
from vissl.utils.instance_retrieval_utils.pca import load_pca, train_and_save_pca
3842
from vissl.utils.instance_retrieval_utils.rmac import get_rmac_descriptors
39-
from vissl.utils.io import cleanup_dir, load_file, makedir, save_file
43+
from vissl.utils.io import load_file, makedir, save_file
4044
from vissl.utils.logger import setup_logging, shutdown_logging
4145

4246

@@ -53,7 +57,7 @@ def build_retrieval_model(cfg):
5357
if PathManager.exists(cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE):
5458
init_weights_path = cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE
5559
logging.info(f"Initializing model from: {init_weights_path}")
56-
weights = torch.load(init_weights_path, map_location=torch.device("cuda"))
60+
weights = load_checkpoint(init_weights_path, device=torch.device("cuda"))
5761
skip_layers = cfg.MODEL.WEIGHTS_INIT.get("SKIP_LAYERS", [])
5862
replace_prefix = cfg.MODEL.WEIGHTS_INIT.get("REMOVE_PREFIX", None)
5963
append_prefix = cfg.MODEL.WEIGHTS_INIT.get("APPEND_PREFIX", None)
@@ -77,14 +81,16 @@ def build_retrieval_model(cfg):
7781

7882

7983
def gem_pool_and_save_features(features, p, add_bias, gem_out_fname):
    """
    GeM-pool and L2-normalize `features`, optionally caching the result on disk.

    Args:
        features: activations to pool. Only `.shape` is read here (for
            logging); the object is otherwise handed straight to `gem`.
        p: pooling power forwarded to `gem`.
        add_bias: forwarded to `gem`.
        gem_out_fname: path of the cache file. When falsy (None / ""),
            nothing is read from or written to disk.

    Returns:
        The features loaded from `gem_out_fname` when that file already
        exists, otherwise the freshly pooled and normalized features.
    """
    # Reuse a previous run's result when a cache path is given and present.
    if gem_out_fname and PathManager.exists(gem_out_fname):
        logging.info("Loading train GeM features...")
        return load_file(gem_out_fname)

    logging.info(f"GeM pooling features: {features.shape}")
    # Forward the caller's `add_bias` instead of silently forcing it to True,
    # so the parameter in the signature actually has an effect.
    features = l2n(gem(features, p=p, add_bias=add_bias))

    # Persist only when feature saving is enabled (a path was provided).
    if gem_out_fname:
        save_file(features, gem_out_fname, verbose=False)
        logging.info(f"Saved GeM features to: {gem_out_fname}")
    return features
8995

9096

@@ -103,8 +109,12 @@ def get_train_features(
103109
def process_train_image(i, out_dir):
104110
if i % LOG_FREQUENCY == 0:
105111
logging.info(f"Train Image: {i}"),
106-
fname_out = f"{out_dir}/{i}.npy"
107-
if PathManager.exists(fname_out):
112+
113+
fname_out = None
114+
if out_dir:
115+
fname_out = f"{out_dir}/{i}.npy"
116+
117+
if fname_out and PathManager.exists(fname_out):
108118
feat = load_file(fname_out)
109119
train_features.append(feat)
110120
else:
@@ -123,19 +133,33 @@ def process_train_image(i, out_dir):
123133
# we can perform: rmac | gem pooling | l2 norm
124134
if cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "rmac":
125135
descriptors = get_rmac_descriptors(activation_map, spatial_levels)
136+
elif cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "l2_norm":
137+
# we simply L2 normalize the features otherwise
138+
descriptors = F.normalize(activation_map, p=2, dim=0)
126139
else:
127140
descriptors = activation_map
128-
save_file(descriptors.data.numpy(), fname_out)
141+
142+
if fname_out:
143+
save_file(descriptors.data.numpy(), fname_out, verbose=False)
129144
train_features.append(descriptors.data.numpy())
130145

131146
num_images = train_dataset.get_num_images()
132-
out_dir = f"{temp_dir}/{train_dataset_name}_S{resize_img}_features_train"
133-
makedir(out_dir)
147+
148+
out_dir = None
149+
if temp_dir:
150+
out_dir = f"{temp_dir}/{train_dataset_name}_S{resize_img}_features_train"
151+
makedir(out_dir)
152+
153+
logging.info(f"Getting features for train images: {num_images}")
134154
for i in range(num_images):
135155
process_train_image(i, out_dir)
136156

137157
if cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "gem":
138-
gem_out_fname = f"{out_dir}/{train_dataset_name}_GeM.npy"
158+
159+
gem_out_fname = None
160+
if out_dir:
161+
gem_out_fname = f"{out_dir}/{train_dataset_name}_GeM.npy"
162+
139163
train_features = torch.tensor(np.concatenate(train_features))
140164
train_features = gem_pool_and_save_features(
141165
train_features,
@@ -165,10 +189,12 @@ def process_eval_image(
165189
img = image_helper.load_and_prepare_instre_image(fname_in)
166190
else:
167191
img = image_helper.load_and_prepare_image(fname_in, roi=roi)
192+
168193
v = torch.autograd.Variable(img.unsqueeze(0))
169194
vc = v.cuda()
170195
# the model output is a list always.
171196
activation_map = model(vc)[0].cpu()
197+
172198
# process the features: rmac | l2 norm
173199
if cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "rmac":
174200
descriptors = get_rmac_descriptors(activation_map, spatial_levels, pca=pca)
@@ -177,7 +203,9 @@ def process_eval_image(
177203
descriptors = F.normalize(activation_map, p=2, dim=0)
178204
else:
179205
descriptors = activation_map
180-
save_file(descriptors.data.numpy(), fname_out)
206+
207+
if fname_out:
208+
save_file(descriptors.data.numpy(), fname_out, verbose=False)
181209
return descriptors.data.numpy()
182210

183211

@@ -195,15 +223,23 @@ def get_dataset_features(
195223
features_dataset = []
196224
num_images = eval_dataset.get_num_images()
197225
logging.info(f"Getting features for dataset images: {num_images}")
198-
db_fname_out_dir = "{}/{}_S{}_db".format(temp_dir, eval_dataset_name, resize_img)
226+
227+
db_fname_out_dir = None
228+
if temp_dir:
229+
db_fname_out_dir = f"{temp_dir}/{eval_dataset_name}_S{resize_img}_db"
230+
199231
makedir(db_fname_out_dir)
200232

201233
for idx in range(num_images):
202234
if idx % LOG_FREQUENCY == 0:
203235
logging.info(f"Eval Dataset Image: {idx}"),
204236
db_fname_in = eval_dataset.get_filename(idx)
205-
db_fname_out = f"{db_fname_out_dir}/{idx}.npy"
206-
if PathManager.exists(db_fname_out):
237+
238+
db_fname_out = None
239+
if db_fname_out_dir:
240+
db_fname_out = f"{db_fname_out_dir}/{idx}.npy"
241+
242+
if db_fname_out and PathManager.exists(db_fname_out):
207243
db_feature = load_file(db_fname_out)
208244
else:
209245
db_feature = process_eval_image(
@@ -221,7 +257,9 @@ def get_dataset_features(
221257

222258
if cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "gem":
223259
# GeM pool the features and apply the PCA
224-
gem_out_fname = f"{db_fname_out_dir}/{eval_dataset_name}_GeM.npy"
260+
gem_out_fname = None
261+
if db_fname_out_dir:
262+
gem_out_fname = f"{db_fname_out_dir}/{eval_dataset_name}_GeM.npy"
225263
features_dataset = torch.tensor(np.concatenate(features_dataset))
226264
features_dataset = gem_pool_and_save_features(
227265
features_dataset,
@@ -231,7 +269,7 @@ def get_dataset_features(
231269
)
232270
features_dataset = pca.apply(features_dataset)
233271
features_dataset = np.vstack(features_dataset)
234-
logging.info(f"features dataset: {features_dataset.shape}")
272+
logging.info(f"Dataset Features Size: {features_dataset.shape}")
235273
return features_dataset
236274

237275

@@ -248,19 +286,30 @@ def get_queries_features(
248286
):
249287
features_queries = []
250288
num_queries = eval_dataset.get_num_query_images()
251-
if cfg.IMG_RETRIEVAL.DEBUG_MODE:
252-
num_queries = 50
289+
290+
num_queries = (
291+
num_queries
292+
if cfg.IMG_RETRIEVAL.NUM_QUERY_SAMPLES == -1
293+
else cfg.IMG_RETRIEVAL.NUM_QUERY_SAMPLES
294+
)
295+
253296
logging.info(f"Getting features for queries: {num_queries}")
254-
q_fname_out_dir = "{}/{}_S{}_q".format(temp_dir, eval_dataset_name, resize_img)
255-
makedir(q_fname_out_dir)
297+
q_fname_out_dir = None
298+
if temp_dir:
299+
q_fname_out_dir = f"{temp_dir}/{eval_dataset_name}_S{resize_img}_q"
300+
makedir(q_fname_out_dir)
256301

257302
for idx in range(num_queries):
258303
if idx % LOG_FREQUENCY == 0:
259304
logging.info(f"Eval Query: {idx}"),
260305
q_fname_in = eval_dataset.get_query_filename(idx)
261306
roi = eval_dataset.get_query_roi(idx)
262-
q_fname_out = f"{q_fname_out_dir}/{idx}.npy"
263-
if PathManager.exists(q_fname_out):
307+
308+
q_fname_out = None
309+
if q_fname_out_dir:
310+
q_fname_out = f"{q_fname_out_dir}/{idx}.npy"
311+
312+
if q_fname_out and PathManager.exists(q_fname_out):
264313
query_feature = load_file(q_fname_out)
265314
else:
266315
query_feature = process_eval_image(
@@ -278,7 +327,9 @@ def get_queries_features(
278327

279328
if cfg.IMG_RETRIEVAL.FEATS_PROCESSING_TYPE == "gem":
280329
# GeM pool the features and apply the PCA
281-
gem_out_fname = f"{q_fname_out_dir}/{eval_dataset_name}_GeM.npy"
330+
gem_out_fname = None
331+
if q_fname_out_dir:
332+
gem_out_fname = f"{q_fname_out_dir}/{eval_dataset_name}_GeM.npy"
282333
features_queries = torch.tensor(np.concatenate(features_queries))
283334
features_queries = gem_pool_and_save_features(
284335
features_queries,
@@ -288,7 +339,7 @@ def get_queries_features(
288339
)
289340
features_queries = pca.apply(features_queries)
290341
features_queries = np.vstack(features_queries)
291-
logging.info(f"features queries: {features_queries.shape}")
342+
logging.info(f"Queries Features Size: {features_queries.shape}")
292343
return features_queries
293344

294345

@@ -324,11 +375,15 @@ def get_train_dataset(cfg, root_dataset_path, train_dataset_name, eval_binary_pa
324375
train_data_path = f"{root_dataset_path}/{train_dataset_name}"
325376
assert PathManager.exists(train_data_path), f"Unknown path: {train_data_path}"
326377

327-
num_samples = 10 if cfg.IMG_RETRIEVAL.DEBUG_MODE else None
378+
num_samples = (
379+
None
380+
if cfg.IMG_RETRIEVAL.NUM_TRAINING_SAMPLES == -1
381+
else cfg.IMG_RETRIEVAL.NUM_TRAINING_SAMPLES
382+
)
328383

329384
if is_revisited_dataset(train_dataset_name):
330385
train_dataset = RevisitedInstanceRetrievalDataset(
331-
train_dataset_name, root_dataset_path
386+
train_dataset_name, root_dataset_path, num_samples=num_samples
332387
)
333388
elif is_whiten_dataset(train_dataset_name):
334389
train_dataset = WhiteningTrainingImageDataset(
@@ -349,11 +404,15 @@ def get_eval_dataset(cfg, root_dataset_path, eval_dataset_name, eval_binary_path
349404
eval_data_path = f"{root_dataset_path}/{eval_dataset_name}"
350405
assert PathManager.exists(eval_data_path), f"Unknown path: {eval_data_path}"
351406

352-
num_samples = 20 if cfg.IMG_RETRIEVAL.DEBUG_MODE else None
407+
num_samples = (
408+
None
409+
if cfg.IMG_RETRIEVAL.NUM_DATABASE_SAMPLES == -1
410+
else cfg.IMG_RETRIEVAL.NUM_DATABASE_SAMPLES
411+
)
353412

354413
if is_revisited_dataset(eval_dataset_name):
355414
eval_dataset = RevisitedInstanceRetrievalDataset(
356-
eval_dataset_name, root_dataset_path
415+
eval_dataset_name, root_dataset_path, num_samples=num_samples
357416
)
358417
elif is_instre_dataset(eval_dataset_name):
359418
eval_dataset = InstreDataset(eval_data_path, num_samples=num_samples)
@@ -374,8 +433,12 @@ def instance_retrieval_test(args, cfg):
374433
resize_img = cfg.IMG_RETRIEVAL.RESIZE_IMG
375434
eval_binary_path = cfg.IMG_RETRIEVAL.EVAL_BINARY_PATH
376435
root_dataset_path = cfg.IMG_RETRIEVAL.DATASET_PATH
377-
temp_dir = f"{cfg.IMG_RETRIEVAL.TEMP_DIR}/{str(uuid.uuid4())}"
378-
logging.info(f"Temp directory: {temp_dir}")
436+
save_features = cfg.IMG_RETRIEVAL.SAVE_FEATURES
437+
438+
temp_dir = None
439+
if save_features:
440+
temp_dir = os.path.join(get_checkpoint_folder(cfg), "features")
441+
logging.info(f"Temp directory: {temp_dir}")
379442

380443
############################################################################
381444
# Step 1: Prepare the train/eval datasets, create model and load weights
@@ -422,8 +485,11 @@ def instance_retrieval_test(args, cfg):
422485
)
423486
########################################################################
424487
# Train PCA on the train features
425-
pca_out_fname = f"{temp_dir}/{train_dataset_name}_S{resize_img}_PCA.pickle"
426-
if PathManager.exists(pca_out_fname):
488+
pca_out_fname = None
489+
if temp_dir:
490+
491+
pca_out_fname = f"{temp_dir}/{train_dataset_name}_S{resize_img}_PCA.pickle"
492+
if pca_out_fname and PathManager.exists(pca_out_fname):
427493
logging.info("Loading PCA...")
428494
pca = load_pca(pca_out_fname)
429495
else:
@@ -462,30 +528,65 @@ def instance_retrieval_test(args, cfg):
462528
)
463529

464530
############################################################################
465-
# Step 5: Compute similarity and score
531+
# Step 5: Compute similarity, score, and save results
466532
logging.info("Calculating similarity and score...")
467533
sim = features_queries.dot(features_dataset.T)
468534
logging.info(f"Similarity tensor: {sim.shape}")
469-
eval_dataset.score(sim, temp_dir)
535+
results = eval_dataset.score(sim, temp_dir)
470536

471537
############################################################################
472-
# Step 6: cleanup the temp directory
473-
logging.info(f"Cleaning up temp directory: {temp_dir}")
474-
cleanup_dir(temp_dir)
538+
# Step 6: save results and cleanup the temp directory
539+
if cfg.IMG_RETRIEVAL.SAVE_RETRIEVAL_RANKINGS_SCORES:
540+
# Save the rankings
541+
sim = sim.T
542+
ranks = np.argsort(-sim, axis=0)
543+
save_file(
544+
ranks.T.tolist(), os.path.join(get_checkpoint_folder(cfg), "rankings.json")
545+
)
546+
547+
# Save the similarity scores
548+
save_file(
549+
sim.tolist(),
550+
os.path.join(get_checkpoint_folder(cfg), "similarity_scores.json"),
551+
)
552+
553+
# Save the result metrics
554+
save_file(results, os.path.join(get_checkpoint_folder(cfg), "metrics.json"))
475555

476556
logging.info("All done!!")
477557

478558

559+
def validate_and_infer_config(config: AttrDict):
    """
    Fill in derived IMG_RETRIEVAL settings on `config` and return it.

    The config object is modified in place:
      * In DEBUG_MODE, every sample-count option still at its -1 sentinel is
        replaced by a small debug limit; explicitly set values are kept.
      * For the OXFORD / PARIS eval datasets, SAVE_FEATURES is forced on.

    Args:
        config: configuration exposing an `IMG_RETRIEVAL` section.

    Returns:
        The same `config` object that was passed in.
    """
    retrieval = config.IMG_RETRIEVAL

    if retrieval.DEBUG_MODE:
        # Data limits for the number of training, query, and database samples.
        debug_limits = (
            ("NUM_TRAINING_SAMPLES", 10),
            ("NUM_QUERY_SAMPLES", 10),
            ("NUM_DATABASE_SAMPLES", 50),
        )
        for option, limit in debug_limits:
            if getattr(retrieval, option) == -1:
                setattr(retrieval, option, limit)

    dataset_name = retrieval.EVAL_DATASET_NAME
    if dataset_name == "OXFORD" or dataset_name == "PARIS":
        # InstanceRetrievalDataset#score requires the features to be saved.
        retrieval.SAVE_FEATURES = True

    return config
576+
577+
479578
def main(args: Namespace, config: AttrDict):
    """
    Entry point: prepare the config, environment and logging, then run the
    instance retrieval test.

    Args:
        args: parsed command-line arguments, forwarded unchanged to
            `instance_retrieval_test`.
        config: experiment configuration; it is first passed through
            `validate_and_infer_config`.

    Any exception raised while printing the config or running the test is
    propagated to the caller after the logging streams have been shut down.
    """
    config = validate_and_infer_config(config)
    # setup the environment variables
    set_env_vars(local_rank=0, node_id=0, cfg=config)

    # setup the logging
    checkpoint_folder = get_checkpoint_folder(config)
    setup_logging(__name__, output_dir=checkpoint_folder)

    try:
        # print the config
        print_cfg(config)

        instance_retrieval_test(args, config)
    finally:
        # close the logging streams including the filehandlers, even when the
        # run fails, so the log written to the checkpoint folder is not left
        # with open handlers.
        shutdown_logging()

0 commit comments

Comments
 (0)