Skip to content

A simple inference demo (CPU only) #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
shellyfung opened this issue May 8, 2025 · 0 comments
Open

A simple inference demo (CPU only) #81

shellyfung opened this issue May 8, 2025 · 0 comments

Comments

@shellyfung
Copy link

shellyfung commented May 8, 2025

Thank you for your work!
Here is a demo script that runs inference on the CPU only:

"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

from __future__ import print_function, division

import argparse
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from torchvision import transforms
from PIL import Image
import os
import yaml
from reIDmodel import ft_netAB
import matplotlib.pyplot as plt

# Command-line arguments.
# NOTE: parsed at import time; load_network() and main() read the module-level
# `opt`, so importing this file from another script also parses sys.argv.
parser = argparse.ArgumentParser(description='Demo')
parser.add_argument('--which_epoch', default=100000, type=int, help='which epoch to load model')
parser.add_argument('--name', default='E0.5new_reid0.5_w30000', type=str, help='model name')
parser.add_argument('--img_dirs', default='', type=str, help='directory for input images, separate by comma')
opt = parser.parse_args()

# Test-time preprocessing: resize to 256 x 128 (H x W) with bicubic
# interpolation, convert to a [0, 1] tensor, then normalize each channel with
# the ImageNet mean/std values.
# The enum replaces the bare integer code 3 (PIL's BICUBIC), which some
# torchvision releases flag with a deprecation warning.
data_transforms = transforms.Compose([
    transforms.Resize((256, 128), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


# 水平翻转
def fliplr(img):
    """Mirror a 4-D image batch horizontally.

    Args:
        img: Tensor laid out as N x C x H x W.

    Returns:
        A new tensor with the width axis (dim 3) reversed; ``img`` is untouched.
    """
    width_axis = 3
    return img.flip(width_axis)


# 特征归一化
def norm(x, dim=None, eps=1e-12):
    """L2-normalize a feature tensor.

    Args:
        x: Floating-point tensor to normalize.
        dim: If None (the default), divide by the L2 norm of the whole tensor.
            That is what this script's single-image pipeline uses. Pass an int
            (e.g. 1) to normalize each slice along that dimension
            independently, e.g. each row of an N x D batch.
        eps: Lower bound applied to the norm so an all-zero input returns
            zeros instead of NaN (0 / 0).

    Returns:
        Tensor with the same shape as ``x``.
    """
    if dim is None:
        divisor = torch.norm(x, p=2)
    else:
        divisor = torch.norm(x, p=2, dim=dim, keepdim=True)
    return x / divisor.clamp_min(eps)



# 加载模型
def load_network(network, save_path=None):
    """Load the checkpoint's 'a' state dict into ``network`` on the CPU.

    Args:
        network: Module whose parameters are overwritten in place.
        save_path: Checkpoint file. Defaults to
            ./outputs/<opt.name>/checkpoints/id_<opt.which_epoch:08d>.pt
            built from the parsed command-line options.

    Returns:
        The same ``network`` object, for chaining.

    Raises:
        Whatever ``torch.load`` raises for a missing or unreadable file, and
        KeyError if the checkpoint has no 'a' entry.
    """
    if save_path is None:
        save_path = os.path.join('./outputs', opt.name, 'checkpoints/id_%08d.pt' % opt.which_epoch)
    state_dict = torch.load(save_path, map_location=torch.device('cpu'))
    result = network.load_state_dict(state_dict['a'], strict=False)
    # strict=False tolerates key mismatches silently; report them so a wrong
    # checkpoint does not quietly leave layers at their initial weights.
    missing = list(result.missing_keys)
    unexpected = list(result.unexpected_keys)
    if missing:
        print(f'警告: 模型中有 {len(missing)} 个参数未在权重文件中找到: {missing[:5]}')
    if unexpected:
        print(f'警告: 权重文件中有 {len(unexpected)} 个参数未被模型使用: {unexpected[:5]}')
    return network


# 从图像中提取特征
def extract_feature(model, img_path):
    """Extract a flip-augmented, L2-normalized descriptor for one image.

    The image is preprocessed with ``data_transforms`` and run through ``model``
    twice, once as-is and once horizontally flipped. Each forward pass must
    return ``(f, x)``. ``f`` is ignored and ``x[0]`` / ``x[1]`` are the two
    feature branches, assumed to be (1, D_i) tensors (TODO confirm against
    ft_netAB). Each branch is normalized, the branches are concatenated, and
    the two passes are summed.

    Args:
        model: Network, expected to already be in eval mode.
        img_path: Path to an image file readable by PIL.

    Returns:
        Tensor of shape (1, D0 + D1), i.e. 1 x 1024 for two 512-d branches.
        The D0 and D1 parts are each L2-normalized separately.
    """
    # Decode inside a with-block so the underlying file handle is released.
    with Image.open(img_path) as pil_img:
        img = data_transforms(pil_img.convert('RGB'))
    img = img.unsqueeze(0)  # add the batch dimension -> 1 x C x H x W

    with torch.no_grad():
        ff = None
        split = None
        for flipped in (False, True):
            batch = fliplr(img) if flipped else img
            _, parts = model(batch)
            first = norm(parts[0])
            second = norm(parts[1])
            f = torch.cat((first, second), dim=1)
            if ff is None:
                # Size everything from the model output instead of hard-coding 1024 / 512.
                split = first.size(1)
                ff = f
            else:
                ff = ff + f

        # Re-normalize each branch of the summed descriptor independently.
        ff[:, :split] = norm(ff[:, :split])
        ff[:, split:] = norm(ff[:, split:])

    return ff


def compute_similarity(features_list):
    """Compute the pairwise cosine similarity between image features.

    Each row is L2-normalized before the dot product. Descriptors whose
    sub-parts were normalized separately (and so have norm > 1) therefore
    still give values in [-1, 1] with a diagonal of 1. All-zero rows produce
    0 rather than NaN.

    Args:
        features_list: List of 2-D CPU tensors with the same number of columns,
            concatenated along dim 0 (one row per image).

    Returns:
        (cosine_similarity, features_np): an N x N numpy array and the
        un-normalized N x D numpy feature matrix.
    """
    features = torch.cat(features_list, dim=0)
    features_np = features.numpy()
    print('特征形状: {}'.format(features_np.shape))

    row_norms = np.linalg.norm(features_np, axis=1, keepdims=True)
    unit_rows = features_np / np.maximum(row_norms, 1e-12)
    cosine_similarity = np.matmul(unit_rows, unit_rows.transpose())

    return cosine_similarity, features_np


def _load_config():
    """Return ./outputs/<opt.name>/config.yaml as a dict, with missing keys filled from defaults."""
    # Fallback values (751 = default class count for Market-1501). They are
    # also merged under a config that does load, so a partial or empty file
    # does not cause a KeyError / TypeError later.
    defaults = {
        'ID_class': 751,
        'norm_id': 1,
        'ID_stride': 2,
        'pool': 'avg'
    }
    config_path = os.path.join('./outputs', opt.name, 'config.yaml')
    try:
        with open(config_path, 'r') as stream:
            loaded = yaml.safe_load(stream) or {}
        print(f'成功加载配置文件: {config_path}')
    except Exception as e:
        print(f'加载配置文件失败: {str(e)}')
        print('尝试创建默认配置...')
        loaded = {}
    return {**defaults, **loaded}


def _collect_image_paths(img_dirs):
    """Return sorted, de-duplicated image paths directly inside each comma-separated directory."""
    import glob
    found = set()
    for directory in img_dirs.split(','):
        directory = directory.strip()
        if not directory:
            continue
        for ext in ('png', 'jpg', 'jpeg'):
            # glob.escape stops characters such as '[' in directory names being read as wildcards.
            found.update(glob.glob(os.path.join(glob.escape(directory), '*.' + ext)))
    return sorted(found)


def _save_heatmap(similarity_matrix, img_paths, out_path='similarity_heatmap.png'):
    """Draw the N x N similarity matrix with per-cell values, save it, then try to show it.

    ``img_paths`` must list exactly the images whose rows are in
    ``similarity_matrix``, in the same order; their basenames label the axes.
    """
    n = len(img_paths)
    names = [os.path.basename(path) for path in img_paths]
    # Grow with the number of images, but keep tiny plots legible and cap very
    # large ones (at dpi=300 a 100-inch canvas would be 30000 px per side).
    side = min(max(n, 4), 40)
    fig = plt.figure(figsize=(side, side))
    plt.imshow(similarity_matrix, cmap='viridis')
    plt.colorbar(label='cosine similarity')
    for i in range(n):
        for j in range(n):
            value = similarity_matrix[i, j]
            plt.text(j, i, f'{value:.2f}',
                     ha='center', va='center',
                     color='white' if value < 0.7 else 'black')
    plt.xticks(range(n), names, rotation=90)
    plt.yticks(range(n), names)
    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    print(f"相似度热图已保存为 {out_path}")
    try:
        # show() must run before close(); a closed figure has nothing to display.
        plt.show()
    except Exception as e:
        print(f'无法显示图像: {str(e)}')
    finally:
        plt.close(fig)


def main():
    """Run the CPU demo end to end.

    Loads the config and re-ID weights and extracts features for every image
    found in ``--img_dirs``. Then prints the pairwise similarity matrix and
    saves it as a heatmap.
    """
    print('---------- 行人重识别推理Demo (CPU版本) ----------')

    config = _load_config()

    print('创建模型...')
    model_structure = ft_netAB(config['ID_class'], norm=config['norm_id'],
                               stride=config['ID_stride'], pool=config['pool'])

    try:
        model = load_network(model_structure)
        print('成功加载模型权重')
    except Exception as e:
        print(f'加载模型权重失败: {str(e)}')
        print('请确保模型文件存在于指定路径')
        return

    # Replace the final fc layer and both classifier layers with empty
    # nn.Sequential modules, which pass their input through unchanged.
    model.model.fc = nn.Sequential()
    model.classifier1.classifier = nn.Sequential()
    model.classifier2.classifier = nn.Sequential()
    model = model.eval()

    if not opt.img_dirs:
        print('请提供图像目录!使用 --img_dirs 参数,多个目录用逗号分隔')
        return

    img_paths = _collect_image_paths(opt.img_dirs)
    if not img_paths:
        print('未找到图像文件!')
        return
    print(f'找到 {len(img_paths)} 张图像')

    # Keep features and their source paths in lockstep so that images which
    # fail to load do not shift the heatmap labels or break its indexing.
    features_list = []
    ok_paths = []
    for i, img_path in enumerate(img_paths):
        print(f'处理图像 {i + 1}/{len(img_paths)}: {img_path}')
        try:
            feature = extract_feature(model, img_path)
        except Exception as e:
            print(f'处理图像 {img_path} 失败: {str(e)}')
            continue
        features_list.append(feature)
        ok_paths.append(img_path)

    if not features_list:
        print('没有成功提取任何特征!')
        return

    similarity_matrix, _ = compute_similarity(features_list)
    _save_heatmap(similarity_matrix, ok_paths)

    print('余弦相似度矩阵:')
    print(similarity_matrix)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant