Skip to content

Commit 7c1e9f6

Browse files
committed
Added hubconf.py to allow models to be downloaded from torch.hub
1 parent dfc524c commit 7c1e9f6

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

hubconf.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
2+
dependencies = ['torch', 'torchvision']
3+
4+
import torch
5+
from model import network
6+
7+
8+
AVAILABLE_MODELS = {
9+
"VGG16": [ 64, 128, 256, 512],
10+
"ResNet18": [32, 64, 128, 256, 512],
11+
"ResNet50": [32, 64, 128, 256, 512, 1024, 2048],
12+
"ResNet101": [32, 64, 128, 256, 512, 1024, 2048],
13+
"ResNet152": [32, 64, 128, 256, 512, 1024, 2048],
14+
}
15+
16+
17+
def get_trained_model(backbone : str = "ResNet18", fc_output_dim : int = 32) -> torch.nn.Module:
18+
"""Return a model trained with CosPlace on San Francisco eXtra Large.
19+
20+
Args:
21+
backbone (str): which torchvision backbone to use. Must be VGG16 or a ResNet.
22+
fc_output_dim (int): the output dimension of the last fc layer, equivalent to
23+
the descriptors dimension. Must be between 32 and 2048, depending on model's availability.
24+
25+
Return:
26+
model (torch.nn.Module): a trained model.
27+
"""
28+
print(f"Returning CosPlace model with backbone: {backbone} with features dimension {fc_output_dim}")
29+
if backbone not in AVAILABLE_MODELS:
30+
raise ValueError(f"Parameter `backbone` is set to {backbone} but it must be one of {list(AVAILABLE_MODELS.keys())}")
31+
try:
32+
fc_output_dim = int(fc_output_dim)
33+
except:
34+
raise ValueError(f"Parameter `fc_output_dim` must be an integer, but it is set to {fc_output_dim}")
35+
if fc_output_dim not in AVAILABLE_MODELS[backbone]:
36+
raise ValueError(f"Parameter `fc_output_dim` is set to {fc_output_dim}, but for backbone {backbone} "
37+
f"it must be one of {list(AVAILABLE_MODELS[backbone])}")
38+
model = network.GeoLocalizationNet(backbone, fc_output_dim)
39+
model.load_state_dict(
40+
torch.hub.load_state_dict_from_url(
41+
f'https://github.com/gmberton/CosPlace/releases/download/v0.1.0/{backbone}_{fc_output_dim}_cosplace.pth',
42+
map_location=torch.device('cpu'))
43+
)
44+
return model

0 commit comments

Comments
 (0)