4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
import logging
7
+ import os
7
8
import sys
8
9
import uuid
9
10
from argparse import Namespace
13
14
import torch
14
15
import torch .nn .functional as F
15
16
import torchvision
16
- from classy_vision .generic .util import copy_model_to_gpu
17
+ from classy_vision .generic .util import copy_model_to_gpu , load_checkpoint
17
18
from fvcore .common .file_io import PathManager
18
19
from hydra .experimental import compose , initialize_config_module
19
20
from vissl .config import AttrDict
20
21
from vissl .models import build_model
21
- from vissl .utils .checkpoint import init_model_from_consolidated_weights
22
+ from vissl .utils .checkpoint import (
23
+ init_model_from_consolidated_weights ,
24
+ get_checkpoint_folder ,
25
+ )
22
26
from vissl .utils .env import set_env_vars
23
27
from vissl .utils .hydra_config import convert_to_attrdict , is_hydra_available , print_cfg
24
28
from vissl .utils .instance_retrieval_utils .data_util import (
36
40
)
37
41
from vissl .utils .instance_retrieval_utils .pca import load_pca , train_and_save_pca
38
42
from vissl .utils .instance_retrieval_utils .rmac import get_rmac_descriptors
39
- from vissl .utils .io import cleanup_dir , load_file , makedir , save_file
43
+ from vissl .utils .io import load_file , makedir , save_file
40
44
from vissl .utils .logger import setup_logging , shutdown_logging
41
45
42
46
@@ -53,7 +57,7 @@ def build_retrieval_model(cfg):
53
57
if PathManager .exists (cfg .MODEL .WEIGHTS_INIT .PARAMS_FILE ):
54
58
init_weights_path = cfg .MODEL .WEIGHTS_INIT .PARAMS_FILE
55
59
logging .info (f"Initializing model from: { init_weights_path } " )
56
- weights = torch . load (init_weights_path , map_location = torch .device ("cuda" ))
60
+ weights = load_checkpoint (init_weights_path , device = torch .device ("cuda" ))
57
61
skip_layers = cfg .MODEL .WEIGHTS_INIT .get ("SKIP_LAYERS" , [])
58
62
replace_prefix = cfg .MODEL .WEIGHTS_INIT .get ("REMOVE_PREFIX" , None )
59
63
append_prefix = cfg .MODEL .WEIGHTS_INIT .get ("APPEND_PREFIX" , None )
@@ -77,14 +81,16 @@ def build_retrieval_model(cfg):
77
81
78
82
79
83
def gem_pool_and_save_features(features, p, add_bias, gem_out_fname):
    """
    GeM-pool and L2-normalize the features, optionally caching them on disk.

    Args:
        features: features to pool; they are handed to `gem` as-is and only
            their `.shape` is read here (for logging).
        p: power forwarded to `gem`.
        add_bias: forwarded to `gem` as its `add_bias` argument.
        gem_out_fname: path of the cache file. When it is None (or empty),
            nothing is loaded from or written to disk.

    Returns:
        The content of `gem_out_fname` when that file already exists,
        otherwise the freshly pooled and normalized features.
    """
    # Reuse previously computed features when a cache file is available.
    if gem_out_fname and PathManager.exists(gem_out_fname):
        logging.info("Loading train GeM features...")
        return load_file(gem_out_fname)

    logging.info(f"GeM pooling features: { features .shape } ")
    # Honour the caller's `add_bias` choice: it used to be ignored in favour
    # of a hard-coded `True`.
    features = l2n(gem(features, p=p, add_bias=add_bias))

    # Saving is optional: without an output path the features only live in
    # memory.
    if gem_out_fname:
        save_file(features, gem_out_fname, verbose=False)
        logging.info(f"Saved GeM features to: { gem_out_fname } ")
    return features
89
95
90
96
@@ -103,8 +109,12 @@ def get_train_features(
103
109
def process_train_image (i , out_dir ):
104
110
if i % LOG_FREQUENCY == 0 :
105
111
logging .info (f"Train Image: { i } " ),
106
- fname_out = f"{ out_dir } /{ i } .npy"
107
- if PathManager .exists (fname_out ):
112
+
113
+ fname_out = None
114
+ if out_dir :
115
+ fname_out = f"{ out_dir } /{ i } .npy"
116
+
117
+ if fname_out and PathManager .exists (fname_out ):
108
118
feat = load_file (fname_out )
109
119
train_features .append (feat )
110
120
else :
@@ -123,19 +133,33 @@ def process_train_image(i, out_dir):
123
133
# we can perform: rmac | gem pooling | l2 norm
124
134
if cfg .IMG_RETRIEVAL .FEATS_PROCESSING_TYPE == "rmac" :
125
135
descriptors = get_rmac_descriptors (activation_map , spatial_levels )
136
+ elif cfg .IMG_RETRIEVAL .FEATS_PROCESSING_TYPE == "l2_norm" :
137
+ # we simply L2 normalize the features otherwise
138
+ descriptors = F .normalize (activation_map , p = 2 , dim = 0 )
126
139
else :
127
140
descriptors = activation_map
128
- save_file (descriptors .data .numpy (), fname_out )
141
+
142
+ if fname_out :
143
+ save_file (descriptors .data .numpy (), fname_out , verbose = False )
129
144
train_features .append (descriptors .data .numpy ())
130
145
131
146
num_images = train_dataset .get_num_images ()
132
- out_dir = f"{ temp_dir } /{ train_dataset_name } _S{ resize_img } _features_train"
133
- makedir (out_dir )
147
+
148
+ out_dir = None
149
+ if temp_dir :
150
+ out_dir = f"{ temp_dir } /{ train_dataset_name } _S{ resize_img } _features_train"
151
+ makedir (out_dir )
152
+
153
+ logging .info (f"Getting features for train images: { num_images } " )
134
154
for i in range (num_images ):
135
155
process_train_image (i , out_dir )
136
156
137
157
if cfg .IMG_RETRIEVAL .FEATS_PROCESSING_TYPE == "gem" :
138
- gem_out_fname = f"{ out_dir } /{ train_dataset_name } _GeM.npy"
158
+
159
+ gem_out_fname = None
160
+ if out_dir :
161
+ gem_out_fname = f"{ out_dir } /{ train_dataset_name } _GeM.npy"
162
+
139
163
train_features = torch .tensor (np .concatenate (train_features ))
140
164
train_features = gem_pool_and_save_features (
141
165
train_features ,
@@ -165,10 +189,12 @@ def process_eval_image(
165
189
img = image_helper .load_and_prepare_instre_image (fname_in )
166
190
else :
167
191
img = image_helper .load_and_prepare_image (fname_in , roi = roi )
192
+
168
193
v = torch .autograd .Variable (img .unsqueeze (0 ))
169
194
vc = v .cuda ()
170
195
# the model output is a list always.
171
196
activation_map = model (vc )[0 ].cpu ()
197
+
172
198
# process the features: rmac | l2 norm
173
199
if cfg .IMG_RETRIEVAL .FEATS_PROCESSING_TYPE == "rmac" :
174
200
descriptors = get_rmac_descriptors (activation_map , spatial_levels , pca = pca )
@@ -177,7 +203,9 @@ def process_eval_image(
177
203
descriptors = F .normalize (activation_map , p = 2 , dim = 0 )
178
204
else :
179
205
descriptors = activation_map
180
- save_file (descriptors .data .numpy (), fname_out )
206
+
207
+ if fname_out :
208
+ save_file (descriptors .data .numpy (), fname_out , verbose = False )
181
209
return descriptors .data .numpy ()
182
210
183
211
@@ -195,15 +223,23 @@ def get_dataset_features(
195
223
features_dataset = []
196
224
num_images = eval_dataset .get_num_images ()
197
225
logging .info (f"Getting features for dataset images: { num_images } " )
198
- db_fname_out_dir = "{}/{}_S{}_db" .format (temp_dir , eval_dataset_name , resize_img )
226
+
227
+ db_fname_out_dir = None
228
+ if temp_dir :
229
+ db_fname_out_dir = f"{ temp_dir } /{ eval_dataset_name } _S{ resize_img } _db"
230
+
199
231
makedir (db_fname_out_dir )
200
232
201
233
for idx in range (num_images ):
202
234
if idx % LOG_FREQUENCY == 0 :
203
235
logging .info (f"Eval Dataset Image: { idx } " ),
204
236
db_fname_in = eval_dataset .get_filename (idx )
205
- db_fname_out = f"{ db_fname_out_dir } /{ idx } .npy"
206
- if PathManager .exists (db_fname_out ):
237
+
238
+ db_fname_out = None
239
+ if db_fname_out_dir :
240
+ db_fname_out = f"{ db_fname_out_dir } /{ idx } .npy"
241
+
242
+ if db_fname_out and PathManager .exists (db_fname_out ):
207
243
db_feature = load_file (db_fname_out )
208
244
else :
209
245
db_feature = process_eval_image (
@@ -221,7 +257,9 @@ def get_dataset_features(
221
257
222
258
if cfg .IMG_RETRIEVAL .FEATS_PROCESSING_TYPE == "gem" :
223
259
# GeM pool the features and apply the PCA
224
- gem_out_fname = f"{ db_fname_out_dir } /{ eval_dataset_name } _GeM.npy"
260
+ gem_out_fname = None
261
+ if db_fname_out_dir :
262
+ gem_out_fname = f"{ db_fname_out_dir } /{ eval_dataset_name } _GeM.npy"
225
263
features_dataset = torch .tensor (np .concatenate (features_dataset ))
226
264
features_dataset = gem_pool_and_save_features (
227
265
features_dataset ,
@@ -231,7 +269,7 @@ def get_dataset_features(
231
269
)
232
270
features_dataset = pca .apply (features_dataset )
233
271
features_dataset = np .vstack (features_dataset )
234
- logging .info (f"features dataset : { features_dataset .shape } " )
272
+ logging .info (f"Dataset Features Size : { features_dataset .shape } " )
235
273
return features_dataset
236
274
237
275
@@ -248,19 +286,30 @@ def get_queries_features(
248
286
):
249
287
features_queries = []
250
288
num_queries = eval_dataset .get_num_query_images ()
251
- if cfg .IMG_RETRIEVAL .DEBUG_MODE :
252
- num_queries = 50
289
+
290
+ num_queries = (
291
+ num_queries
292
+ if cfg .IMG_RETRIEVAL .NUM_QUERY_SAMPLES == - 1
293
+ else cfg .IMG_RETRIEVAL .NUM_QUERY_SAMPLES
294
+ )
295
+
253
296
logging .info (f"Getting features for queries: { num_queries } " )
254
- q_fname_out_dir = "{}/{}_S{}_q" .format (temp_dir , eval_dataset_name , resize_img )
255
- makedir (q_fname_out_dir )
297
+ q_fname_out_dir = None
298
+ if q_fname_out_dir :
299
+ q_fname_out_dir = f"{ temp_dir } /{ eval_dataset_name } _S{ resize_img } _q"
300
+ makedir (q_fname_out_dir )
256
301
257
302
for idx in range (num_queries ):
258
303
if idx % LOG_FREQUENCY == 0 :
259
304
logging .info (f"Eval Query: { idx } " ),
260
305
q_fname_in = eval_dataset .get_query_filename (idx )
261
306
roi = eval_dataset .get_query_roi (idx )
262
- q_fname_out = f"{ q_fname_out_dir } /{ idx } .npy"
263
- if PathManager .exists (q_fname_out ):
307
+
308
+ q_fname_out = None
309
+ if q_fname_out_dir :
310
+ q_fname_out = f"{ q_fname_out_dir } /{ idx } .npy"
311
+
312
+ if q_fname_out and PathManager .exists (q_fname_out ):
264
313
query_feature = load_file (q_fname_out )
265
314
else :
266
315
query_feature = process_eval_image (
@@ -278,7 +327,9 @@ def get_queries_features(
278
327
279
328
if cfg .IMG_RETRIEVAL .FEATS_PROCESSING_TYPE == "gem" :
280
329
# GeM pool the features and apply the PCA
281
- gem_out_fname = f"{ q_fname_out_dir } /{ eval_dataset_name } _GeM.npy"
330
+ gem_out_fname = None
331
+ if q_fname_out_dir :
332
+ gem_out_fname = f"{ q_fname_out_dir } /{ eval_dataset_name } _GeM.npy"
282
333
features_queries = torch .tensor (np .concatenate (features_queries ))
283
334
features_queries = gem_pool_and_save_features (
284
335
features_queries ,
@@ -288,7 +339,7 @@ def get_queries_features(
288
339
)
289
340
features_queries = pca .apply (features_queries )
290
341
features_queries = np .vstack (features_queries )
291
- logging .info (f"features queries : { features_queries .shape } " )
342
+ logging .info (f"Queries Features Size : { features_queries .shape } " )
292
343
return features_queries
293
344
294
345
@@ -324,11 +375,15 @@ def get_train_dataset(cfg, root_dataset_path, train_dataset_name, eval_binary_pa
324
375
train_data_path = f"{ root_dataset_path } /{ train_dataset_name } "
325
376
assert PathManager .exists (train_data_path ), f"Unknown path: { train_data_path } "
326
377
327
- num_samples = 10 if cfg .IMG_RETRIEVAL .DEBUG_MODE else None
378
+ num_samples = (
379
+ None
380
+ if cfg .IMG_RETRIEVAL .NUM_TRAINING_SAMPLES == - 1
381
+ else cfg .IMG_RETRIEVAL .NUM_TRAINING_SAMPLES
382
+ )
328
383
329
384
if is_revisited_dataset (train_dataset_name ):
330
385
train_dataset = RevisitedInstanceRetrievalDataset (
331
- train_dataset_name , root_dataset_path
386
+ train_dataset_name , root_dataset_path , num_samples = num_samples
332
387
)
333
388
elif is_whiten_dataset (train_dataset_name ):
334
389
train_dataset = WhiteningTrainingImageDataset (
@@ -349,11 +404,15 @@ def get_eval_dataset(cfg, root_dataset_path, eval_dataset_name, eval_binary_path
349
404
eval_data_path = f"{ root_dataset_path } /{ eval_dataset_name } "
350
405
assert PathManager .exists (eval_data_path ), f"Unknown path: { eval_data_path } "
351
406
352
- num_samples = 20 if cfg .IMG_RETRIEVAL .DEBUG_MODE else None
407
+ num_samples = (
408
+ None
409
+ if cfg .IMG_RETRIEVAL .NUM_DATABASE_SAMPLES == - 1
410
+ else cfg .IMG_RETRIEVAL .NUM_DATABASE_SAMPLES
411
+ )
353
412
354
413
if is_revisited_dataset (eval_dataset_name ):
355
414
eval_dataset = RevisitedInstanceRetrievalDataset (
356
- eval_dataset_name , root_dataset_path
415
+ eval_dataset_name , root_dataset_path , num_samples = num_samples
357
416
)
358
417
elif is_instre_dataset (eval_dataset_name ):
359
418
eval_dataset = InstreDataset (eval_data_path , num_samples = num_samples )
@@ -374,8 +433,12 @@ def instance_retrieval_test(args, cfg):
374
433
resize_img = cfg .IMG_RETRIEVAL .RESIZE_IMG
375
434
eval_binary_path = cfg .IMG_RETRIEVAL .EVAL_BINARY_PATH
376
435
root_dataset_path = cfg .IMG_RETRIEVAL .DATASET_PATH
377
- temp_dir = f"{ cfg .IMG_RETRIEVAL .TEMP_DIR } /{ str (uuid .uuid4 ())} "
378
- logging .info (f"Temp directory: { temp_dir } " )
436
+ save_features = cfg .IMG_RETRIEVAL .SAVE_FEATURES
437
+
438
+ temp_dir = None
439
+ if save_features :
440
+ temp_dir = os .path .join (get_checkpoint_folder (cfg ), "features" )
441
+ logging .info (f"Temp directory: { temp_dir } " )
379
442
380
443
############################################################################
381
444
# Step 1: Prepare the train/eval datasets, create model and load weights
@@ -422,8 +485,11 @@ def instance_retrieval_test(args, cfg):
422
485
)
423
486
########################################################################
424
487
# Train PCA on the train features
425
- pca_out_fname = f"{ temp_dir } /{ train_dataset_name } _S{ resize_img } _PCA.pickle"
426
- if PathManager .exists (pca_out_fname ):
488
+ pca_out_fname = None
489
+ if temp_dir :
490
+
491
+ pca_out_fname = f"{ temp_dir } /{ train_dataset_name } _S{ resize_img } _PCA.pickle"
492
+ if pca_out_fname and PathManager .exists (pca_out_fname ):
427
493
logging .info ("Loading PCA..." )
428
494
pca = load_pca (pca_out_fname )
429
495
else :
@@ -462,30 +528,65 @@ def instance_retrieval_test(args, cfg):
462
528
)
463
529
464
530
############################################################################
465
- # Step 5: Compute similarity and score
531
+ # Step 5: Compute similarity, score, and save results
466
532
logging .info ("Calculating similarity and score..." )
467
533
sim = features_queries .dot (features_dataset .T )
468
534
logging .info (f"Similarity tensor: { sim .shape } " )
469
- eval_dataset .score (sim , temp_dir )
535
+ results = eval_dataset .score (sim , temp_dir )
470
536
471
537
############################################################################
472
- # Step 6: cleanup the temp directory
473
- logging .info (f"Cleaning up temp directory: { temp_dir } " )
474
- cleanup_dir (temp_dir )
538
+ # Step 6: save results and cleanup the temp directory
539
+ if cfg .IMG_RETRIEVAL .SAVE_RETRIEVAL_RANKINGS_SCORES :
540
+ # Save the rankings
541
+ sim = sim .T
542
+ ranks = np .argsort (- sim , axis = 0 )
543
+ save_file (
544
+ ranks .T .tolist (), os .path .join (get_checkpoint_folder (cfg ), "rankings.json" )
545
+ )
546
+
547
+ # Save the similarity scores
548
+ save_file (
549
+ sim .tolist (),
550
+ os .path .join (get_checkpoint_folder (cfg ), "similarity_scores.json" ),
551
+ )
552
+
553
+ # Save the result metrics
554
+ save_file (results , os .path .join (get_checkpoint_folder (cfg ), "metrics.json" ))
475
555
476
556
logging .info ("All done!!" )
477
557
478
558
559
def validate_and_infer_config(config: AttrDict):
    """
    Fill in the retrieval options that are derived from other options.

    The `IMG_RETRIEVAL` section of `config` is modified in place:
      * in debug mode, every sample limit still set to -1 is replaced by a
        small debug value;
      * for the OXFORD and PARIS evaluation datasets, feature saving is
        switched on.

    Args:
        config: configuration holding an `IMG_RETRIEVAL` section.

    Returns:
        The same `config` object, after the in-place adjustments.
    """
    retrieval = config.IMG_RETRIEVAL

    if retrieval.DEBUG_MODE:
        # Data limits used in debug mode for the number of training, query,
        # and database samples. Explicitly configured limits are kept.
        debug_limits = (
            ("NUM_TRAINING_SAMPLES", 10),
            ("NUM_QUERY_SAMPLES", 10),
            ("NUM_DATABASE_SAMPLES", 50),
        )
        for option, limit in debug_limits:
            if getattr(retrieval, option) == -1:
                setattr(retrieval, option, limit)

    # InstanceRetrievalDataset#score requires the features to be saved.
    if retrieval.EVAL_DATASET_NAME in ("OXFORD", "PARIS"):
        retrieval.SAVE_FEATURES = True

    return config
576
+
577
+
479
578
def main(args: Namespace, config: AttrDict):
    """
    Run the instance retrieval test for the given configuration.

    Args:
        args: command line arguments, passed through to
            `instance_retrieval_test`.
        config: run configuration; it is first passed through
            `validate_and_infer_config`.
    """
    config = validate_and_infer_config(config)
    # setup the environment variables
    set_env_vars(local_rank=0, node_id=0, cfg=config)

    # setup the logging; the checkpoint folder is passed as the log output dir
    checkpoint_folder = get_checkpoint_folder(config)
    setup_logging(__name__, output_dir=checkpoint_folder)

    try:
        # print the config
        print_cfg(config)

        instance_retrieval_test(args, config)
    finally:
        # close the logging streams including the filehandlers, even when the
        # run above raises
        shutdown_logging()
0 commit comments