|
| 1 | + |
| 2 | +dependencies = ['torch', 'torchvision'] |
| 3 | + |
| 4 | +import torch |
| 5 | +from model import network |
| 6 | + |
| 7 | + |
| 8 | +AVAILABLE_MODELS = { |
| 9 | + "VGG16": [ 64, 128, 256, 512], |
| 10 | + "ResNet18": [32, 64, 128, 256, 512], |
| 11 | + "ResNet50": [32, 64, 128, 256, 512, 1024, 2048], |
| 12 | + "ResNet101": [32, 64, 128, 256, 512, 1024, 2048], |
| 13 | + "ResNet152": [32, 64, 128, 256, 512, 1024, 2048], |
| 14 | +} |
| 15 | + |
| 16 | + |
| 17 | +def get_trained_model(backbone : str = "ResNet18", fc_output_dim : int = 32) -> torch.nn.Module: |
| 18 | + """Return a model trained with CosPlace on San Francisco eXtra Large. |
| 19 | + |
| 20 | + Args: |
| 21 | + backbone (str): which torchvision backbone to use. Must be VGG16 or a ResNet. |
| 22 | + fc_output_dim (int): the output dimension of the last fc layer, equivalent to |
| 23 | + the descriptors dimension. Must be between 32 and 2048, depending on model's availability. |
| 24 | + |
| 25 | + Return: |
| 26 | + model (torch.nn.Module): a trained model. |
| 27 | + """ |
| 28 | + print(f"Returning CosPlace model with backbone: {backbone} with features dimension {fc_output_dim}") |
| 29 | + if backbone not in AVAILABLE_MODELS: |
| 30 | + raise ValueError(f"Parameter `backbone` is set to {backbone} but it must be one of {list(AVAILABLE_MODELS.keys())}") |
| 31 | + try: |
| 32 | + fc_output_dim = int(fc_output_dim) |
| 33 | + except: |
| 34 | + raise ValueError(f"Parameter `fc_output_dim` must be an integer, but it is set to {fc_output_dim}") |
| 35 | + if fc_output_dim not in AVAILABLE_MODELS[backbone]: |
| 36 | + raise ValueError(f"Parameter `fc_output_dim` is set to {fc_output_dim}, but for backbone {backbone} " |
| 37 | + f"it must be one of {list(AVAILABLE_MODELS[backbone])}") |
| 38 | + model = network.GeoLocalizationNet(backbone, fc_output_dim) |
| 39 | + model.load_state_dict( |
| 40 | + torch.hub.load_state_dict_from_url( |
| 41 | + f'https://github.com/gmberton/CosPlace/releases/download/v0.1.0/{backbone}_{fc_output_dim}_cosplace.pth', |
| 42 | + map_location=torch.device('cpu')) |
| 43 | + ) |
| 44 | + return model |
0 commit comments