We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Thank you for your work! I am trying to provide a demo that can run on CPU:
""" Copyright (C) 2019 NVIDIA Corporation. All rights reserved. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). """ from __future__ import print_function, division import argparse import torch import torch.nn as nn from torch.autograd import Variable import numpy as np from torchvision import transforms from PIL import Image import os import yaml from reIDmodel import ft_netAB import matplotlib.pyplot as plt # 命令行参数 parser = argparse.ArgumentParser(description='Demo') parser.add_argument('--which_epoch', default=100000, type=int, help='which epoch to load model') parser.add_argument('--name', default='E0.5new_reid0.5_w30000', type=str, help='model name') parser.add_argument('--img_dirs', default='', type=str, help='directory for input images, separate by comma') opt = parser.parse_args() # 数据预处理 data_transforms = transforms.Compose([ transforms.Resize((256, 128), interpolation=3), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # 水平翻转 def fliplr(img): '''水平翻转图像''' inv_idx = torch.arange(img.size(3) - 1, -1, -1).long() # N x C x H x W img_flip = img.index_select(3, inv_idx) return img_flip # 特征归一化 def norm(x): # 计算全局 L2 范数并归一化 global_norm = torch.norm(x, p=2) # p=2 表示 L2 范数 normalized_global = x / global_norm return normalized_global # 加载模型 def load_network(network): save_path = os.path.join('./outputs', opt.name, 'checkpoints/id_%08d.pt' % opt.which_epoch) state_dict = torch.load(save_path, map_location=torch.device('cpu')) network.load_state_dict(state_dict['a'], strict=False) return network # 从图像中提取特征 def extract_feature(model, img_path): img = Image.open(img_path).convert('RGB') img = data_transforms(img) img = img.unsqueeze(0) # 添加批次维度 # 提取特征 with torch.no_grad(): n, c, h, w = img.size() ff = torch.FloatTensor(n, 1024).zero_() for i in range(2): if i == 1: img = fliplr(img) input_img = Variable(img) f, x = model(input_img) x[0] = norm(x[0]) x[1] = norm(x[1]) f = torch.cat((x[0], x[1]), dim=1) # 使用512维特征 ff = ff + f # 归一化特征向量 ff[:, 0:512] = norm(ff[:, 0:512]) ff[:, 512:1024] = norm(ff[:, 512:1024]) return ff def compute_similarity(features_list): """计算特征之间的余弦相似度""" features = torch.cat(features_list, dim=0) features_np = features.numpy() print('特征形状: {}'.format(features_np.shape)) # 计算余弦相似度 cosine_similarity = np.matmul(features_np, features_np.transpose()) return cosine_similarity, features_np def main(): # 打印程序信息 print('---------- 行人重识别推理Demo (CPU版本) ----------') try: # 加载配置文件 config_path = os.path.join('./outputs', opt.name, 'config.yaml') with open(config_path, 'r') as stream: config = yaml.safe_load(stream) print(f'成功加载配置文件: {config_path}') except Exception as e: print(f'加载配置文件失败: {str(e)}') print('尝试创建默认配置...') # 如果无法加载配置文件,使用默认配置 config = { 'ID_class': 751, # Market-1501 默认类别数 'norm_id': 1, 'ID_stride': 2, 'pool': 'avg' } # 创建模型 print('创建模型...') model_structure = ft_netAB(config['ID_class'], norm=config['norm_id'], stride=config['ID_stride'], pool=config['pool']) try: model = load_network(model_structure) print('成功加载模型权重') except Exception as e: print(f'加载模型权重失败: {str(e)}') print('请确保模型文件存在于指定路径') return # 移除最终的fc层和分类器层 model.model.fc = nn.Sequential() model.classifier1.classifier = nn.Sequential() model.classifier2.classifier = nn.Sequential() # 设置为评估模式 model = model.eval() # 处理图像目录 if not opt.img_dirs: print('请提供图像目录!使用 --img_dirs 参数,多个目录用逗号分隔') return import glob img_paths = glob.glob(opt.img_dirs + '/*.png') img_paths.sort() if len(img_paths) == 0: print('未找到图像文件!') return print(f'找到 {len(img_paths)} 张图像') # 提取特征 features_list = [] for i, img_path in enumerate(img_paths): print(f'处理图像 {i + 1}/{len(img_paths)}: {img_path}') try: feature = extract_feature(model, img_path) features_list.append(feature) except Exception as e: print(f'处理图像 {img_path} 失败: {str(e)}') if len(features_list) == 0: print('没有成功提取任何特征!') return # 计算相似度 similarity_matrix, features = compute_similarity(features_list) n = len(img_paths) # 在函数末尾添加热图 plt.figure(figsize=(n, n)) plt.imshow(similarity_matrix, cmap='viridis') plt.colorbar(label='cosine similarity') for i in range(n): for j in range(n): plt.text(j, i, f'{similarity_matrix[i, j]:.2f}', ha='center', va='center', color='white' if similarity_matrix[i, j] < 0.7 else 'black') plt.xticks(range(n), [os.path.basename(path) for path in img_paths], rotation=90) plt.yticks(range(n), [os.path.basename(path) for path in img_paths]) plt.tight_layout() plt.savefig('similarity_heatmap.png', dpi=300) plt.close() print("相似度热图已保存为 similarity_heatmap.png") try: plt.show() except Exception as e: print(f'无法显示图像: {str(e)}') print('余弦相似度矩阵:') print(similarity_matrix) if __name__ == '__main__': main()
The text was updated successfully, but these errors were encountered:
No branches or pull requests
Thank you for your work!
I am trying to provide a demo that can run on CPU:
The text was updated successfully, but these errors were encountered: