@@ -1,9 +1,23 @@
 import torch
 import torch.onnx
-from IMF.model import IMFModel
-
-
-
+from model import IMFModel
+import torch.nn as nn
+from PIL import Image
+from torchvision import transforms
+
+# Define the IMFEncoder class
+class IMFEncoder(nn.Module):
+    def __init__(self, model):
+        super(IMFEncoder, self).__init__()
+        self.model = model
+
+    def forward(self, x_current, x_reference):
+        f_r = self.model.dense_feature_encoder(x_reference)
+        t_r = self.model.latent_token_encoder(x_reference)
+        t_c = self.model.latent_token_encoder(x_current)
+        return f_r, t_r, t_c  # Fixed indentation here
+
+# Define the trace handler and utility functions
 def trace_handler(module, input, output):
     print(f"\nModule: {module.__class__.__name__}")
     for idx, inp in enumerate(input):
@@ -14,7 +28,6 @@ def trace_handler(module, input, output):
     for idx, out in enumerate(output):
         print_tensor_info(out, f" Output[{idx}]")
 
-
 def print_tensor_info(tensor, name, indent=0):
     indent_str = ' ' * indent
     print(f"{indent_str}{name}:")
@@ -39,6 +52,7 @@ def print_model_structure(model):
         if hasattr(module, 'bias') and module.bias is not None:
             print(f" Bias shape: {module.bias.shape}")
 
+# Adjusted export_to_onnx function
 def export_to_onnx(model, x_current, x_reference, file_name):
     try:
         print("Model structure before tracing:")
@@ -67,21 +81,21 @@ def export_to_onnx(model, x_current, x_reference, file_name):
         torch.onnx.export(
             model,
             (x_current, x_reference),
-            "imf_model.onnx",
+            file_name,
             export_params=True,
             opset_version=11,
             do_constant_folding=True,
             input_names=['x_current', 'x_reference'],
-            output_names=['output'],
+            output_names=['f_r', 't_r', 't_c'],  # Adjusted output names
             dynamic_axes={
                 'x_current': {0: 'batch_size'},
                 'x_reference': {0: 'batch_size'},
-                'output': {0: 'batch_size'}
+                'f_r': {0: 'batch_size'},
+                't_r': {0: 'batch_size'},
+                't_c': {0: 'batch_size'}
             },
             verbose=True
         )
-
-
         print(f"Model exported successfully to {file_name}")
     except Exception as e:
         print(f"Error during ONNX export: {str(e)}")
@@ -93,16 +107,27 @@ def export_to_onnx(model, x_current, x_reference, file_name):
 model.eval()
 
 # Load the checkpoint
-checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location=lambda storage, loc: storage)
-state_dict = checkpoint['model_state_dict']
-
-# # Adjust the weights in the state_dict
-# for key in state_dict.keys():
-#     if 'csonv.weight' in key and state_dict[key].dim() == 5:
-#         state_dict[key] = state_dict[key].squeeze(0)
-# Create dummy input tensors
-x_current = torch.randn(1, 3, 256, 256)
-x_reference = torch.randn(1, 3, 256, 256)
+checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location='cpu')
+model.load_state_dict(checkpoint['model_state_dict'])
+
+# Create the IMFEncoder instance
+encoder_model = IMFEncoder(model)
+encoder_model.eval()
+
+# Load images and preprocess
+def load_image(image_path):
+    transform = transforms.Compose([
+        transforms.Resize((256, 256)),  # Adjust as per your model's requirements
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Adjust as per your model's requirements
+                             std=[0.229, 0.224, 0.225])
+    ])
+    image = Image.open(image_path).convert('RGB')
+    image = transform(image).unsqueeze(0)  # Add batch dimension
+    return image
+
+x_current = load_image("x_current.png")
+x_reference = load_image("x_reference.png")
 
 # Export the model
-export_to_onnx(model, x_current, x_reference, "imf_model.onnx")
+export_to_onnx(encoder_model, x_current, x_reference, "imf_encoder.onnx")
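
After the export, the resulting graph can be sanity-checked against the PyTorch encoder. The following is a minimal sketch, assuming onnxruntime is installed, that the file and tensor names from the script above are reused, and that each of the three outputs is a single tensor:

# Sketch: compare the exported ONNX graph with the PyTorch encoder.
# Assumes onnxruntime is available and reuses encoder_model, x_current and
# x_reference from the export script above.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("imf_encoder.onnx", providers=["CPUExecutionProvider"])
onnx_outputs = session.run(
    None,  # None returns every output: f_r, t_r, t_c
    {"x_current": x_current.numpy(), "x_reference": x_reference.numpy()},
)

with torch.no_grad():
    torch_outputs = encoder_model(x_current, x_reference)

for name, ort_out, torch_out in zip(["f_r", "t_r", "t_c"], onnx_outputs, torch_outputs):
    max_diff = np.max(np.abs(ort_out - torch_out.numpy()))
    print(f"{name}: max abs difference {max_diff:.3e}")

If dense_feature_encoder returns several feature maps rather than one tensor, the export flattens them into additional ONNX outputs and the comparison above needs to be adjusted accordingly.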