1
+ import torch
2
+ import torch .onnx
3
+ from IMF .model import IMFModel
4
+
5
+
6
+
7
+ def trace_handler (module , input , output ):
8
+ print (f"\n Module: { module .__class__ .__name__ } " )
9
+ for idx , inp in enumerate (input ):
10
+ print_tensor_info (inp , f" Input[{ idx } ]" )
11
+ if isinstance (output , torch .Tensor ):
12
+ print_tensor_info (output , " Output" )
13
+ else :
14
+ for idx , out in enumerate (output ):
15
+ print_tensor_info (out , f" Output[{ idx } ]" )
16
+
17
+
18
+ def print_tensor_info (tensor , name , indent = 0 ):
19
+ indent_str = ' ' * indent
20
+ print (f"{ indent_str } { name } :" )
21
+ if isinstance (tensor , torch .Tensor ):
22
+ print (f"{ indent_str } Shape: { tensor .shape } " )
23
+ print (f"{ indent_str } Dtype: { tensor .dtype } " )
24
+ print (f"{ indent_str } Device: { tensor .device } " )
25
+ print (f"{ indent_str } Requires grad: { tensor .requires_grad } " )
26
+ elif isinstance (tensor , (list , tuple )):
27
+ print (f"{ indent_str } Type: { type (tensor ).__name__ } , Length: { len (tensor )} " )
28
+ for idx , item in enumerate (tensor ):
29
+ print_tensor_info (item , f"{ name } [{ idx } ]" , indent = indent + 2 )
30
+ else :
31
+ print (f"{ indent_str } Type: { type (tensor ).__name__ } " )
32
+
33
+ def print_model_structure (model ):
34
+ print ("Model Structure:" )
35
+ for name , module in model .named_modules ():
36
+ print (f"{ name } : { module .__class__ .__name__ } " )
37
+ if hasattr (module , 'weight' ):
38
+ print (f" Weight shape: { module .weight .shape } " )
39
+ if hasattr (module , 'bias' ) and module .bias is not None :
40
+ print (f" Bias shape: { module .bias .shape } " )
41
+
42
+ def export_to_onnx (model , x_current , x_reference , file_name ):
43
+ try :
44
+ print ("Model structure before tracing:" )
45
+ print_model_structure (model )
46
+
47
+ print ("\n Input tensor information:" )
48
+ print_tensor_info (x_current , "x_current" )
49
+ print_tensor_info (x_reference , "x_reference" )
50
+
51
+ hooks = []
52
+ for name , module in model .named_modules ():
53
+ hooks .append (module .register_forward_hook (trace_handler ))
54
+
55
+ # Use torch.jit.trace to create a traced version of the model
56
+ print ("\n Tracing model..." )
57
+ traced_model = torch .jit .trace (model , (x_current , x_reference ))
58
+ print ("Model traced successfully" )
59
+
60
+ for hook in hooks :
61
+ hook .remove ()
62
+
63
+ print ("\n Model structure after tracing:" )
64
+ print_model_structure (traced_model )
65
+
66
+ print ("\n Exporting to ONNX..." )
67
+ torch .onnx .export (
68
+ model ,
69
+ (x_current , x_reference ),
70
+ "imf_model.onnx" ,
71
+ export_params = True ,
72
+ opset_version = 11 ,
73
+ do_constant_folding = True ,
74
+ input_names = ['x_current' , 'x_reference' ],
75
+ output_names = ['output' ],
76
+ dynamic_axes = {
77
+ 'x_current' : {0 : 'batch_size' },
78
+ 'x_reference' : {0 : 'batch_size' },
79
+ 'output' : {0 : 'batch_size' }
80
+ },
81
+ verbose = True
82
+ )
83
+
84
+
85
+ print (f"Model exported successfully to { file_name } " )
86
+ except Exception as e :
87
+ print (f"Error during ONNX export: { str (e )} " )
88
+ import traceback
89
+ traceback .print_exc ()
90
+
91
+ # Load your model
92
+ model = IMFModel ()
93
+ model .eval ()
94
+
95
+ # Load the checkpoint
96
+ checkpoint = torch .load ("./checkpoints/checkpoint.pth" , map_location = lambda storage , loc : storage )
97
+ state_dict = checkpoint ['model_state_dict' ]
98
+
99
+ # # Adjust the weights in the state_dict
100
+ # for key in state_dict.keys():
101
+ # if 'csonv.weight' in key and state_dict[key].dim() == 5:
102
+ # state_dict[key] = state_dict[key].squeeze(0)
103
+ # Create dummy input tensors
104
+ x_current = torch .randn (1 , 3 , 256 , 256 )
105
+ x_reference = torch .randn (1 , 3 , 256 , 256 )
106
+
107
+ # Export the model
108
+ export_to_onnx (model , x_current , x_reference , "imf_model.onnx" )
0 commit comments