29
29
30
30
def _prepare_op_inputs (inputs , run_backward , dtype , ctx ):
31
31
kwargs_list = []
32
+ args_list = []
32
33
33
34
for inp in inputs :
34
35
kwargs = {}
35
36
for key , value in inp .items ():
36
- if key in PARAMS_OF_TYPE_NDARRAY :
37
+ if key in PARAMS_OF_TYPE_NDARRAY and key == 'args' :
38
+ args_list .append (get_mx_ndarray (ctx = ctx , in_tensor = value ,
39
+ dtype = dtype ,
40
+ initializer = nd .normal ,
41
+ attach_grad = run_backward ))
42
+ elif key in PARAMS_OF_TYPE_NDARRAY :
37
43
kwargs [key ] = get_mx_ndarray (ctx = ctx , in_tensor = value ,
38
44
dtype = dtype ,
39
45
initializer = nd .normal ,
40
46
attach_grad = run_backward )
41
47
else :
42
48
kwargs [key ] = value
43
49
kwargs_list .append (kwargs )
44
-
45
- return kwargs_list
50
+ return args_list , kwargs_list
46
51
47
52
48
53
def _run_nd_operator_performance_test (op , inputs , run_backward , warmup , runs , kwargs_list , profiler ):
@@ -60,17 +65,28 @@ def _run_nd_operator_performance_test(op, inputs, run_backward, warmup, runs, kw
60
65
raise ValueError ("Incorrect input for profiler. Valid input - 'python' or 'native'" )
61
66
62
67
# Warm up, ignore the profiler output
63
- _ , _ = benchmark_helper_func (op , warmup , ** kwargs_list [0 ])
68
+ if not args_list :
69
+ _ , _ = benchmark_helper_func (op , warmup , [], ** kwargs_list [0 ])
70
+ else :
71
+ _ , _ = benchmark_helper_func (op , warmup , args_list [0 ], ** kwargs_list [0 ])
64
72
65
73
# Run Benchmarks
66
74
op_benchmark_result = {op .__name__ : []}
67
75
logging .info ("Begin Benchmark - {name}" .format (name = op .__name__ ))
68
- for idx , kwargs in enumerate (kwargs_list ):
69
- _ , profiler_output = benchmark_helper_func (op , runs , ** kwargs )
76
+ if not args_list :
77
+ for idx , kwargs in enumerate (kwargs_list ):
78
+ _ , profiler_output = benchmark_helper_func (op , runs , [], ** kwargs )
79
+
80
+ # Add inputs used for profiling this operator into result
81
+ profiler_output ["inputs" ] = inputs [idx ]
82
+ op_benchmark_result [op .__name__ ].append (profiler_output )
83
+ else :
84
+ for idx , (args ,kwargs ) in enumerate (zip (args_list ,kwargs_list )):
85
+ _ , profiler_output = benchmark_helper_func (op , runs , args , ** kwargs )
70
86
71
- # Add inputs used for profiling this operator into result
72
- profiler_output ["inputs" ] = inputs [idx ]
73
- op_benchmark_result [op .__name__ ].append (profiler_output )
87
+ # Add inputs used for profiling this operator into result
88
+ profiler_output ["inputs" ] = inputs [idx ]
89
+ op_benchmark_result [op .__name__ ].append (profiler_output )
74
90
logging .info ("Complete Benchmark - {name}" .format (name = op .__name__ ))
75
91
return op_benchmark_result
76
92
@@ -110,15 +126,15 @@ def run_performance_test(ops, inputs, run_backward=True,
110
126
List of dictionary of benchmark results. key -> name of the operator, Value is benchmark results.
111
127
112
128
"""
113
- kwargs_list = _prepare_op_inputs (inputs , run_backward , dtype , ctx )
129
+ args_list , kwargs_list = _prepare_op_inputs (inputs , run_backward , dtype , ctx )
114
130
115
131
if not isinstance (ops , list ):
116
132
ops = [ops ]
117
133
118
134
op_benchmark_result = []
119
135
for op in ops :
120
136
if hasattr (mx .nd , op .__name__ ):
121
- benchmark_result = _run_nd_operator_performance_test (op , inputs , run_backward , warmup , runs , kwargs_list , profiler )
137
+ benchmark_result = _run_nd_operator_performance_test (op , inputs , run_backward , warmup , runs , args_list , kwargs_list , profiler )
122
138
else :
123
139
raise ValueError ("Unknown NDArray operator provided to benchmark. - " , op .__name__ )
124
140
op_benchmark_result .append (benchmark_result )
0 commit comments