Skip to content

Commit 088b208

Browse files
committed
onnx validation
1 parent 54a159b commit 088b208

File tree

1 file changed

+47
-7
lines changed

1 file changed

+47
-7
lines changed

onnxconv.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from onnxconverter_common import float16
88
import onnx
99
import numpy as np
10+
from onnx import shape_inference
11+
import onnxruntime as ort
1012

1113
class IMFDecoder(nn.Module):
1214
def __init__(self, model):
@@ -118,7 +120,6 @@ def convert_model_to_32bit(model, output_path):
118120
onnx.save(model, output_path)
119121
print(f"Converted model saved to {output_path}")
120122

121-
122123
def export_to_onnx(model, x_current, x_reference, file_name):
123124
try:
124125
print("Model structure before tracing:")
@@ -167,18 +168,57 @@ def export_to_onnx(model, x_current, x_reference, file_name):
167168
# Load the ONNX model
168169
onnx_model = onnx.load(file_name)
169170

170-
# Convert int64 to int32
171-
print("Converting int64 to int32...")
171+
# Check the model
172+
print("\nChecking the model...")
173+
onnx.checker.check_model(onnx_model)
174+
print("Model checked successfully")
175+
176+
# Print model input and output shapes
177+
print("\nModel Input and Output Shapes:")
178+
for input in onnx_model.graph.input:
179+
print(f"Input: {input.name}, Shape: {[dim.dim_value for dim in input.type.tensor_type.shape.dim]}")
180+
for output in onnx_model.graph.output:
181+
print(f"Output: {output.name}, Shape: {[dim.dim_value for dim in output.type.tensor_type.shape.dim]}")
182+
183+
# Perform shape inference
184+
print("\nPerforming shape inference...")
185+
inferred_model = shape_inference.infer_shapes(onnx_model)
186+
onnx.save(inferred_model, file_name)
187+
print("Shape inference completed and model saved")
188+
189+
# Convert int64 to int32
190+
print("\nConverting int64 to int32...")
172191
web_compatible_file = file_name.replace('.onnx', '_web.onnx')
173-
onnx_model = convert_model_to_32bit(onnx_model,web_compatible_file)
174-
175-
192+
convert_model_to_32bit(onnx_model, web_compatible_file)
193+
194+
# Validate the converted model
195+
print("\nValidating the converted model...")
196+
onnx.checker.check_model(onnx.load(web_compatible_file))
197+
print("Converted model validated successfully")
198+
199+
# Test the model with ONNX Runtime
200+
print("\nTesting the model with ONNX Runtime...")
201+
ort_session = ort.InferenceSession(web_compatible_file)
202+
203+
# Prepare inputs (assuming x_current and x_reference are PyTorch tensors)
204+
ort_inputs = {
205+
'x_current': x_current.numpy(),
206+
'x_reference': x_reference.numpy()
207+
}
208+
209+
# Run inference
210+
ort_outputs = ort_session.run(None, ort_inputs)
211+
print("ONNX Runtime inference successful")
212+
213+
print(f"\nConverted and validated model saved to {web_compatible_file}")
214+
print("This model should now be compatible with WONNX")
176215

177216
except Exception as e:
178-
print(f"Error during ONNX export: {str(e)}")
217+
print(f"Error during ONNX export and validation: {str(e)}")
179218
import traceback
180219
traceback.print_exc()
181220

221+
182222
# Load your model
183223
model = IMFModel()
184224
model.eval()

0 commit comments

Comments
 (0)