Skip to content

Commit cfc1ee9

Browse files
committed
ok
1 parent 7940ef1 commit cfc1ee9

File tree

3 files changed

+44
-7
lines changed

3 files changed

+44
-7
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ __pycache__/vit.cpython-311.pyc
1111
__pycache__/helper.cpython-311.pyc
1212
actions-runner/*
1313
imf_encoder.onnx
14+
imf_encoder_web.onnx

onnxconv.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import torch.nn as nn
55
from PIL import Image
66
from torchvision import transforms
7-
7+
from onnxconverter_common import float16
8+
import onnx
89

910
class IMFDecoder(nn.Module):
1011
def __init__(self, model):
@@ -63,7 +64,22 @@ def print_model_structure(model):
6364
if hasattr(module, 'bias') and module.bias is not None:
6465
print(f" Bias shape: {module.bias.shape}")
6566

66-
# Adjusted export_to_onnx function
67+
68+
def convert_int64_to_int32(model):
69+
for tensor in model.graph.initializer:
70+
if tensor.data_type == onnx.TensorProto.INT64:
71+
tensor.data_type = onnx.TensorProto.INT32
72+
tensor.int64_data = tensor.int64_data.astype(np.int32)
73+
74+
for node in model.graph.node:
75+
for attr in node.attribute:
76+
if attr.type == onnx.AttributeProto.INT:
77+
attr.i = int(attr.i)
78+
elif attr.type == onnx.AttributeProto.INTS:
79+
attr.ints[:] = [int(i) for i in attr.ints]
80+
81+
return model
82+
6783
def export_to_onnx(model, x_current, x_reference, file_name):
6884
try:
6985
print("Model structure before tracing:")
@@ -97,7 +113,7 @@ def export_to_onnx(model, x_current, x_reference, file_name):
97113
opset_version=11,
98114
do_constant_folding=True,
99115
input_names=['x_current', 'x_reference'],
100-
output_names=['f_r', 't_r', 't_c'], # Adjusted output names
116+
output_names=['f_r', 't_r', 't_c'],
101117
dynamic_axes={
102118
'x_current': {0: 'batch_size'},
103119
'x_reference': {0: 'batch_size'},
@@ -108,6 +124,23 @@ def export_to_onnx(model, x_current, x_reference, file_name):
108124
verbose=True
109125
)
110126
print(f"Model exported successfully to {file_name}")
127+
128+
# Load the ONNX model
129+
onnx_model = onnx.load(file_name)
130+
131+
# Convert int64 to int32
132+
print("Converting int64 to int32...")
133+
onnx_model = convert_int64_to_int32(onnx_model)
134+
135+
# Optionally, convert float32 to float16 to reduce model size
136+
print("Converting float32 to float16...")
137+
onnx_model = float16.convert_float_to_float16(onnx_model)
138+
139+
# Save the converted model
140+
web_compatible_file = file_name.replace('.onnx', '_web.onnx')
141+
onnx.save(onnx_model, web_compatible_file)
142+
print(f"Web-compatible model saved as {web_compatible_file}")
143+
111144
except Exception as e:
112145
print(f"Error during ONNX export: {str(e)}")
113146
import traceback
@@ -128,9 +161,9 @@ def export_to_onnx(model, x_current, x_reference, file_name):
128161
# Load images and preprocess
129162
def load_image(image_path):
130163
transform = transforms.Compose([
131-
transforms.Resize((256, 256)), # Adjust as per your model's requirements
164+
transforms.Resize((256, 256)),
132165
transforms.ToTensor(),
133-
transforms.Normalize(mean=[0.485, 0.456, 0.406], # Adjust as per your model's requirements
166+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
134167
std=[0.229, 0.224, 0.225])
135168
])
136169
image = Image.open(image_path).convert('RGB')
@@ -141,4 +174,4 @@ def load_image(image_path):
141174
x_reference = load_image("x_reference.png")
142175

143176
# Export the model
144-
export_to_onnx(encoder_model, x_current, x_reference, "imf_encoder.onnx")
177+
export_to_onnx(encoder_model, x_current, x_reference, "imf_encoder.onnx")

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,7 @@ onnxruntime
1515
opencv-python
1616
pymatting
1717
decord
18-
mediapipe
18+
mediapipe
19+
20+
onnx
21+
onnxconverter_common

0 commit comments

Comments
 (0)