4
4
import torch .nn as nn
5
5
from PIL import Image
6
6
from torchvision import transforms
7
-
7
+ from onnxconverter_common import float16
8
+ import onnx
8
9
9
10
class IMFDecoder (nn .Module ):
10
11
def __init__ (self , model ):
@@ -63,7 +64,22 @@ def print_model_structure(model):
63
64
if hasattr (module , 'bias' ) and module .bias is not None :
64
65
print (f" Bias shape: { module .bias .shape } " )
65
66
66
- # Adjusted export_to_onnx function
67
+
68
+ def convert_int64_to_int32 (model ):
69
+ for tensor in model .graph .initializer :
70
+ if tensor .data_type == onnx .TensorProto .INT64 :
71
+ tensor .data_type = onnx .TensorProto .INT32
72
+ tensor .int64_data = tensor .int64_data .astype (np .int32 )
73
+
74
+ for node in model .graph .node :
75
+ for attr in node .attribute :
76
+ if attr .type == onnx .AttributeProto .INT :
77
+ attr .i = int (attr .i )
78
+ elif attr .type == onnx .AttributeProto .INTS :
79
+ attr .ints [:] = [int (i ) for i in attr .ints ]
80
+
81
+ return model
82
+
67
83
def export_to_onnx (model , x_current , x_reference , file_name ):
68
84
try :
69
85
print ("Model structure before tracing:" )
@@ -97,7 +113,7 @@ def export_to_onnx(model, x_current, x_reference, file_name):
97
113
opset_version = 11 ,
98
114
do_constant_folding = True ,
99
115
input_names = ['x_current' , 'x_reference' ],
100
- output_names = ['f_r' , 't_r' , 't_c' ], # Adjusted output names
116
+ output_names = ['f_r' , 't_r' , 't_c' ],
101
117
dynamic_axes = {
102
118
'x_current' : {0 : 'batch_size' },
103
119
'x_reference' : {0 : 'batch_size' },
@@ -108,6 +124,23 @@ def export_to_onnx(model, x_current, x_reference, file_name):
108
124
verbose = True
109
125
)
110
126
print (f"Model exported successfully to { file_name } " )
127
+
128
+ # Load the ONNX model
129
+ onnx_model = onnx .load (file_name )
130
+
131
+ # Convert int64 to int32
132
+ print ("Converting int64 to int32..." )
133
+ onnx_model = convert_int64_to_int32 (onnx_model )
134
+
135
+ # Optionally, convert float32 to float16 to reduce model size
136
+ print ("Converting float32 to float16..." )
137
+ onnx_model = float16 .convert_float_to_float16 (onnx_model )
138
+
139
+ # Save the converted model
140
+ web_compatible_file = file_name .replace ('.onnx' , '_web.onnx' )
141
+ onnx .save (onnx_model , web_compatible_file )
142
+ print (f"Web-compatible model saved as { web_compatible_file } " )
143
+
111
144
except Exception as e :
112
145
print (f"Error during ONNX export: { str (e )} " )
113
146
import traceback
@@ -128,9 +161,9 @@ def export_to_onnx(model, x_current, x_reference, file_name):
128
161
# Load images and preprocess
129
162
def load_image (image_path ):
130
163
transform = transforms .Compose ([
131
- transforms .Resize ((256 , 256 )), # Adjust as per your model's requirements
164
+ transforms .Resize ((256 , 256 )),
132
165
transforms .ToTensor (),
133
- transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], # Adjust as per your model's requirements
166
+ transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ],
134
167
std = [0.229 , 0.224 , 0.225 ])
135
168
])
136
169
image = Image .open (image_path ).convert ('RGB' )
@@ -141,4 +174,4 @@ def load_image(image_path):
141
174
x_reference = load_image ("x_reference.png" )
142
175
143
176
# Export the model
144
- export_to_onnx (encoder_model , x_current , x_reference , "imf_encoder.onnx" )
177
+ export_to_onnx (encoder_model , x_current , x_reference , "imf_encoder.onnx" )
0 commit comments