
Commit c6eb4da

committed
onnx
1 parent 50c7418 commit c6eb4da

File tree

3 files changed: +109 -1 lines changed


M2Ohb0FAaJU_1.mp4

440 KB
Binary file not shown.

configs/inference.yaml

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ checkpoint_path: "./checkpoints/checkpoint.pth"
 
 input:
   # For video processing
-  video_path: "face.mp4"
+  video_path: "M2Ohb0FAaJU_1.mp4"
   frame_skip: 0
 
   # For single frame processing
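
For context, a minimal sketch of how these config keys might be consumed by the inference script. The loader itself is not part of this commit, so the key lookup and the frame_skip handling below are assumptions:

import cv2
import yaml

# Load the inference config; key names follow configs/inference.yaml above.
with open("configs/inference.yaml") as f:
    cfg = yaml.safe_load(f)

video_path = cfg["input"]["video_path"]   # now "M2Ohb0FAaJU_1.mp4"
frame_skip = cfg["input"]["frame_skip"]   # assumed meaning: frames skipped between processed frames

cap = cv2.VideoCapture(video_path)
frame_idx = 0
while True:
    ok, frame = cap.read()
    if not ok:
        break
    if frame_skip and frame_idx % (frame_skip + 1) != 0:
        frame_idx += 1
        continue
    # ... run inference on `frame` here ...
    frame_idx += 1
cap.release()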

onnxconv.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
import torch
import torch.onnx
from IMF.model import IMFModel


def trace_handler(module, input, output):
    # Forward hook: log each module's input/output tensors as the model runs.
    print(f"\nModule: {module.__class__.__name__}")
    for idx, inp in enumerate(input):
        print_tensor_info(inp, f"  Input[{idx}]")
    if isinstance(output, torch.Tensor):
        print_tensor_info(output, "  Output")
    else:
        for idx, out in enumerate(output):
            print_tensor_info(out, f"  Output[{idx}]")


def print_tensor_info(tensor, name, indent=0):
    indent_str = ' ' * indent
    print(f"{indent_str}{name}:")
    if isinstance(tensor, torch.Tensor):
        print(f"{indent_str}  Shape: {tensor.shape}")
        print(f"{indent_str}  Dtype: {tensor.dtype}")
        print(f"{indent_str}  Device: {tensor.device}")
        print(f"{indent_str}  Requires grad: {tensor.requires_grad}")
    elif isinstance(tensor, (list, tuple)):
        print(f"{indent_str}  Type: {type(tensor).__name__}, Length: {len(tensor)}")
        for idx, item in enumerate(tensor):
            print_tensor_info(item, f"{name}[{idx}]", indent=indent + 2)
    else:
        print(f"{indent_str}  Type: {type(tensor).__name__}")


def print_model_structure(model):
    print("Model Structure:")
    for name, module in model.named_modules():
        print(f"{name}: {module.__class__.__name__}")
        if hasattr(module, 'weight'):
            print(f"  Weight shape: {module.weight.shape}")
        if hasattr(module, 'bias') and module.bias is not None:
            print(f"  Bias shape: {module.bias.shape}")


def export_to_onnx(model, x_current, x_reference, file_name):
    try:
        print("Model structure before tracing:")
        print_model_structure(model)

        print("\nInput tensor information:")
        print_tensor_info(x_current, "x_current")
        print_tensor_info(x_reference, "x_reference")

        # Register forward hooks so every module logs its tensors during tracing.
        hooks = []
        for name, module in model.named_modules():
            hooks.append(module.register_forward_hook(trace_handler))

        # Use torch.jit.trace to create a traced version of the model.
        # The traced module is only inspected below; the original model is
        # what gets passed to torch.onnx.export.
        print("\nTracing model...")
        traced_model = torch.jit.trace(model, (x_current, x_reference))
        print("Model traced successfully")

        for hook in hooks:
            hook.remove()

        print("\nModel structure after tracing:")
        print_model_structure(traced_model)

        print("\nExporting to ONNX...")
        torch.onnx.export(
            model,
            (x_current, x_reference),
            file_name,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=['x_current', 'x_reference'],
            output_names=['output'],
            dynamic_axes={
                'x_current': {0: 'batch_size'},
                'x_reference': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            },
            verbose=True
        )

        print(f"Model exported successfully to {file_name}")
    except Exception as e:
        print(f"Error during ONNX export: {str(e)}")
        import traceback
        traceback.print_exc()


# Load your model
model = IMFModel()
model.eval()

# Load the checkpoint
checkpoint = torch.load("./checkpoints/checkpoint.pth", map_location=lambda storage, loc: storage)
state_dict = checkpoint['model_state_dict']

# # Adjust the weights in the state_dict
# for key in state_dict.keys():
#     if 'csonv.weight' in key and state_dict[key].dim() == 5:
#         state_dict[key] = state_dict[key].squeeze(0)

# Apply the checkpoint weights to the model before export
model.load_state_dict(state_dict)

# Create dummy input tensors
x_current = torch.randn(1, 3, 256, 256)
x_reference = torch.randn(1, 3, 256, 256)

# Export the model
export_to_onnx(model, x_current, x_reference, "imf_model.onnx")
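
Not part of this commit, but a quick way to sanity-check the export is to run the ONNX graph next to the PyTorch model on the same dummy inputs. A minimal sketch to append at the end of the script above, assuming onnx and onnxruntime are installed and that IMFModel returns either a single tensor or a tuple whose first element is comparable:

import numpy as np
import onnx
import onnxruntime as ort

# Structural validation of the exported graph
onnx.checker.check_model(onnx.load("imf_model.onnx"))

# Run the exported model and the original PyTorch model on the same inputs
sess = ort.InferenceSession("imf_model.onnx", providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {
    "x_current": x_current.numpy(),
    "x_reference": x_reference.numpy(),
})[0]

with torch.no_grad():
    torch_out = model(x_current, x_reference)
if isinstance(torch_out, (list, tuple)):  # assumption: compare the first output
    torch_out = torch_out[0]

np.testing.assert_allclose(torch_out.numpy(), onnx_out, rtol=1e-3, atol=1e-5)
print("ONNX Runtime output matches PyTorch within tolerance")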
