Skip to content

Commit 3c71c75

Browse files
committed
ok
1 parent c6eb4da commit 3c71c75

File tree

3 files changed

+60
-22
lines changed

3 files changed

+60
-22
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ recon_epoch_1.png
99
*.png
1010
__pycache__/vit.cpython-311.pyc
1111
__pycache__/helper.cpython-311.pyc
12-
actions-runner/*
12+
actions-runner/*
13+
imf_encoder.onnx

model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,18 @@ def process_tokens(self, t_c, t_r):
477477

478478
return m_c, m_r
479479

480+
481+
class IMFEncoder(nn.Module):
    """Expose only the encoder stages of a wrapped model as one module.

    The wrapped ``model`` must provide two callables,
    ``dense_feature_encoder`` and ``latent_token_encoder``; nothing else
    on it is used by this class.
    """

    def __init__(self, model):
        """Store ``model`` so that ``forward`` can call its encoders.

        Args:
            model: Object exposing ``dense_feature_encoder`` and
                ``latent_token_encoder``.
        """
        super().__init__()
        self.model = model

    def forward(self, x_current, x_reference):
        """Encode the reference and current inputs.

        Args:
            x_current: Input passed to ``latent_token_encoder`` only.
            x_reference: Input passed to both ``dense_feature_encoder``
                and ``latent_token_encoder``.

        Returns:
            A tuple ``(f_r, t_r, t_c)``: the dense-feature encoding of
            ``x_reference``, the latent-token encoding of
            ``x_reference``, and the latent-token encoding of
            ``x_current``, in that order.
        """
        f_r = self.model.dense_feature_encoder(x_reference)
        t_r = self.model.latent_token_encoder(x_reference)
        t_c = self.model.latent_token_encoder(x_current)
        return f_r, t_r, t_c
491+
480492
class MappingNetwork(nn.Module):
481493
def __init__(self, latent_dim, w_dim, depth):
482494
super().__init__()

onnxconv.py

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
11
import torch
22
import torch.onnx
3-
from IMF.model import IMFModel
4-
5-
6-
3+
from model import IMFModel
4+
import torch.nn as nn
5+
from PIL import Image
6+
from torchvision import transforms
7+
8+
# Define the IMFEncoder class
9+
class IMFEncoder(nn.Module):
    """Wrap a model so that only its encoder stages run in ``forward``.

    The wrapped ``model`` must provide ``dense_feature_encoder`` and
    ``latent_token_encoder``. The three values returned by ``forward``
    correspond, in order, to the ``f_r``, ``t_r`` and ``t_c`` output
    names used when this module is exported to ONNX.
    """

    def __init__(self, model):
        """Store ``model`` so that ``forward`` can call its encoders.

        Args:
            model: Object exposing ``dense_feature_encoder`` and
                ``latent_token_encoder``.
        """
        super().__init__()
        self.model = model

    def forward(self, x_current, x_reference):
        """Encode the reference and current inputs.

        Args:
            x_current: Input passed to ``latent_token_encoder`` only.
            x_reference: Input passed to both ``dense_feature_encoder``
                and ``latent_token_encoder``.

        Returns:
            A tuple ``(f_r, t_r, t_c)``: the dense-feature encoding of
            ``x_reference``, the latent-token encoding of
            ``x_reference``, and the latent-token encoding of
            ``x_current``, in that order.
        """
        f_r = self.model.dense_feature_encoder(x_reference)
        t_r = self.model.latent_token_encoder(x_reference)
        t_c = self.model.latent_token_encoder(x_current)
        return f_r, t_r, t_c
19+
20+
# Define the trace handler and utility functions
721
def trace_handler(module, input, output):
822
print(f"\nModule: {module.__class__.__name__}")
923
for idx, inp in enumerate(input):
@@ -14,7 +28,6 @@ def trace_handler(module, input, output):
1428
for idx, out in enumerate(output):
1529
print_tensor_info(out, f" Output[{idx}]")
1630

17-
1831
def print_tensor_info(tensor, name, indent=0):
1932
indent_str = ' ' * indent
2033
print(f"{indent_str}{name}:")
@@ -39,6 +52,7 @@ def print_model_structure(model):
3952
if hasattr(module, 'bias') and module.bias is not None:
4053
print(f" Bias shape: {module.bias.shape}")
4154

55+
# Adjusted export_to_onnx function
4256
def export_to_onnx(model, x_current, x_reference, file_name):
4357
try:
4458
print("Model structure before tracing:")
@@ -67,21 +81,21 @@ def export_to_onnx(model, x_current, x_reference, file_name):
6781
torch.onnx.export(
6882
model,
6983
(x_current, x_reference),
70-
"imf_model.onnx",
84+
file_name,
7185
export_params=True,
7286
opset_version=11,
7387
do_constant_folding=True,
7488
input_names=['x_current', 'x_reference'],
75-
output_names=['output'],
89+
output_names=['f_r', 't_r', 't_c'], # Adjusted output names
7690
dynamic_axes={
7791
'x_current': {0: 'batch_size'},
7892
'x_reference': {0: 'batch_size'},
79-
'output': {0: 'batch_size'}
93+
'f_r': {0: 'batch_size'},
94+
't_r': {0: 'batch_size'},
95+
't_c': {0: 'batch_size'}
8096
},
8197
verbose=True
8298
)
83-
84-
8599
print(f"Model exported successfully to {file_name}")
86100
except Exception as e:
87101
print(f"Error during ONNX export: {str(e)}")
@@ -93,16 +107,27 @@ def export_to_onnx(model, x_current, x_reference, file_name):
93107
model.eval()
94108

95109
# Load the checkpoint
96-
checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location=lambda storage, loc: storage)
97-
state_dict = checkpoint['model_state_dict']
98-
99-
# # Adjust the weights in the state_dict
100-
# for key in state_dict.keys():
101-
# if 'csonv.weight' in key and state_dict[key].dim() == 5:
102-
# state_dict[key] = state_dict[key].squeeze(0)
103-
# Create dummy input tensors
104-
x_current = torch.randn(1, 3, 256, 256)
105-
x_reference = torch.randn(1, 3, 256, 256)
110+
checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location='cpu')
111+
model.load_state_dict(checkpoint['model_state_dict'])
112+
113+
# Create the IMFEncoder instance
114+
encoder_model = IMFEncoder(model)
115+
encoder_model.eval()
116+
117+
# Load images and preprocess
118+
def load_image(image_path, size=(256, 256)):
    """Load an image file and return it as a normalized, batched tensor.

    The image is converted to RGB, resized to ``size``, converted to a
    tensor, normalized with the mean/std constants below, and given a
    leading batch dimension via ``unsqueeze(0)``.

    Args:
        image_path: Path of the image file to open.
        size: Target size handed to ``transforms.Resize``. Defaults to
            ``(256, 256)``, the value that was previously hard-coded.

    Returns:
        The transformed image tensor with a batch dimension prepended.
    """
    transform = transforms.Compose([
        transforms.Resize(size),  # Adjust as per your model's requirements
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Adjust as per your model's requirements
                             std=[0.229, 0.224, 0.225]),
    ])
    # Open inside a context manager so the file handle is released even
    # if decoding or conversion raises; convert() returns a new, fully
    # loaded image that stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0)  # Add batch dimension
128+
129+
x_current = load_image("x_current.png")
130+
x_reference = load_image("x_reference.png")
106131

107132
# Export the model
108-
export_to_onnx(model, x_current, x_reference, "imf_model.onnx")
133+
export_to_onnx(encoder_model, x_current, x_reference, "imf_encoder.onnx")

0 commit comments

Comments
 (0)