7 | 7 | from onnxconverter_common import float16
8 | 8 | import onnx
9 | 9 | import numpy as np
| 10 | +from onnx import shape_inference
| 11 | +import onnxruntime as ort
10 | 12 |
11 | 13 | class IMFDecoder(nn.Module):
12 | 14 |     def __init__(self, model):
@@ -118,7 +120,6 @@ def convert_model_to_32bit(model, output_path):

118 | 120 |     onnx.save(model, output_path)
119 | 121 |     print(f"Converted model saved to {output_path}")
120 | 122 |
121 | | -
122 | 123 | def export_to_onnx(model, x_current, x_reference, file_name):
123 | 124 |     try:
124 | 125 |         print("Model structure before tracing:")
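The body of `convert_model_to_32bit` is collapsed by the hunk above; only its final `onnx.save` call is visible. As a rough sketch of what such an int64-to-int32 pass typically does (an assumption for illustration, not the author's implementation; the helper name `convert_int64_to_int32` and the saturating clip are hypothetical), it narrows int64 initializers and retargets `Cast` nodes. A complete pass would also need to rewrite the element types declared on graph inputs, outputs, and `value_info`:

```python
import numpy as np
import onnx
from onnx import numpy_helper

def convert_int64_to_int32(model, output_path):
    # Narrow every int64 initializer to int32, clipping first so
    # out-of-range values saturate instead of wrapping around.
    for init in model.graph.initializer:
        if init.data_type == onnx.TensorProto.INT64:
            arr = numpy_helper.to_array(init)
            arr32 = np.clip(arr, np.iinfo(np.int32).min,
                            np.iinfo(np.int32).max).astype(np.int32)
            init.CopyFrom(numpy_helper.from_array(arr32, init.name))
    # Retarget Cast nodes that produce int64 so they produce int32 instead.
    for node in model.graph.node:
        if node.op_type == "Cast":
            for attr in node.attribute:
                if attr.name == "to" and attr.i == onnx.TensorProto.INT64:
                    attr.i = onnx.TensorProto.INT32
    onnx.save(model, output_path)
```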
@@ -167,18 +168,57 @@ def export_to_onnx(model, x_current, x_reference, file_name):

167 | 168 |         # Load the ONNX model
168 | 169 |         onnx_model = onnx.load(file_name)
169 | 170 |
170 | | -        # Convert int64 to int32
171 | | -        print("Converting int64 to int32...")
| 171 | +        # Check the model
| 172 | +        print("\nChecking the model...")
| 173 | +        onnx.checker.check_model(onnx_model)
| 174 | +        print("Model checked successfully")
| 175 | +
| 176 | +        # Print model input and output shapes
| 177 | +        print("\nModel Input and Output Shapes:")
| 178 | +        for inp in onnx_model.graph.input:
| 179 | +            print(f"Input: {inp.name}, Shape: {[dim.dim_value for dim in inp.type.tensor_type.shape.dim]}")
| 180 | +        for out in onnx_model.graph.output:
| 181 | +            print(f"Output: {out.name}, Shape: {[dim.dim_value for dim in out.type.tensor_type.shape.dim]}")
| 182 | +
| 183 | +        # Perform shape inference
| 184 | +        print("\nPerforming shape inference...")
| 185 | +        inferred_model = shape_inference.infer_shapes(onnx_model)
| 186 | +        onnx.save(inferred_model, file_name)
| 187 | +        print("Shape inference completed and model saved")
| 188 | +
| 189 | +        # Convert int64 to int32
| 190 | +        print("\nConverting int64 to int32...")
172 | 191 |         web_compatible_file = file_name.replace('.onnx', '_web.onnx')
173 | | -        onnx_model = convert_model_to_32bit(onnx_model,web_compatible_file)
174 | | -
175 | | -
| 192 | +        convert_model_to_32bit(onnx_model, web_compatible_file)
| 193 | +
| 194 | +        # Validate the converted model
| 195 | +        print("\nValidating the converted model...")
| 196 | +        onnx.checker.check_model(onnx.load(web_compatible_file))
| 197 | +        print("Converted model validated successfully")
| 198 | +
| 199 | +        # Test the model with ONNX Runtime
| 200 | +        print("\nTesting the model with ONNX Runtime...")
| 201 | +        ort_session = ort.InferenceSession(web_compatible_file)
| 202 | +
| 203 | +        # Prepare inputs (assuming x_current and x_reference are PyTorch tensors)
| 204 | +        ort_inputs = {
| 205 | +            'x_current': x_current.detach().cpu().numpy(),
| 206 | +            'x_reference': x_reference.detach().cpu().numpy()
| 207 | +        }
| 208 | +
| 209 | +        # Run inference
| 210 | +        ort_outputs = ort_session.run(None, ort_inputs)
| 211 | +        print("ONNX Runtime inference successful")
| 212 | +
| 213 | +        print(f"\nConverted and validated model saved to {web_compatible_file}")
| 214 | +        print("This model should now be compatible with WONNX")
176 | 215 |
177 | 216 |     except Exception as e:
178 | | -        print(f"Error during ONNX export: {str(e)}")
| 217 | +        print(f"Error during ONNX export and validation: {str(e)}")
179 | 218 |         import traceback
180 | 219 |         traceback.print_exc()
181 | 220 |
| 221 | +
182 | 222 | # Load your model
183 | 223 | model = IMFModel()
184 | 224 | model.eval()
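The diff ends before the exported model is actually produced from these last two lines. For completeness, a hypothetical invocation of `export_to_onnx` (the input shapes and the output filename are placeholders, not taken from this repo):

```python
import torch

# Placeholder inputs; the real shapes depend on what IMFModel expects.
x_current = torch.randn(1, 3, 256, 256)
x_reference = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    export_to_onnx(model, x_current, x_reference, "imf.onnx")
```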