@@ -40,8 +40,8 @@ def print_usage(msg=""):
40
40
"""
41
41
Usage: Report statistics on compiler and runtime characteristics of ONNX ops.
42
42
make-report.py -[vh] [-c <compile_log>] [-r <run_log>] [-l <num>]
43
- [-s <stats>] [--sort <val>] [--supported] [-u <val>] [-p <op regexp >]
44
- [-w <num>]
43
+ [-p <plot_file_name>] [-s <stats>] [--sort <val>] [--supported] [-u <val>]
44
+ [-f <op regexp>] [-m <num>] [-w <num>]
45
45
46
46
Compile-time statistics are collected from a `onnx-mlir` compiler output
47
47
with the `--opt-report` option equal to `Simd` or other supported sub-options.
@@ -57,11 +57,23 @@ def print_usage(msg=""):
57
57
`onnx-mlir --debug-only=lowering-to-krnl` and look at the compiler output.
58
58
Use `-l 3` to correlate the node name printed here with compiler output.
59
59
60
- Parameters:
60
+ Parameters on inputs :
61
61
-c/--compile <file_log>: File name containing the compile-time statistics
62
62
or runtime signature statistics.
63
63
-r/--runtime <file_log>: File name containing the runtime time statistics.
64
64
65
+ Parameters to focus analysis:
66
+ -f/--focus <regexp>: Focus only on ops that match the regexp pattern.
67
+ -m/--min <num>: Focus on operations with at least <num>% of exec time.
68
+ --supported: Focus only on ops that are supported. Namely, the report
69
+ will skip ops for which compile-time statistics list
70
+ the 'unsupported' keyword in its printout.
71
+ For SIMD/parallel statistics, this includes all ops that
72
+ have currently no support for it.
73
+ -w/--warmup <num>: If multiple runtime statistics are given, ignore the first
74
+ <num> stats. Default is zero.
75
+
76
+ Parameters on what to print:
65
77
-s/--stats <name>: Print specific statistics:
66
78
simd: Print simd optimization stats.
67
79
Default if a compile time file is given.
@@ -74,19 +86,15 @@ def print_usage(msg=""):
74
86
1: Also count reasons for success/failure.
75
87
2: Also list metrics.
76
88
3: Also list node name.
77
- -f/--focus <regexp>: Focus only on ops that match the regexp pattern.
78
- -m/--min <num>: Focus on operations with at least <num>% of exec time.
79
- -supported: Focus only on ops that are supported. Namely, the report
80
- will skip ops for which compile-time statistics list
81
- the 'unsupported' keyword in its printout.
82
- For SIMD/parallel statistics, this include all ops that
83
- have currently no support for it.
89
+ -p/--plot <name>: Print a "<name>.jpg" plot of the output.
90
+
91
+ Parameters on how to print:
84
92
-u/--unit <str>: Time in second ('s', default), millisecond ('ms') or
85
93
microsecond ('us').
86
94
--sort <str>: Sort output by op 'name', occurrence 'num' or `time`.
95
+
96
+ Help:
87
97
-v/--verbose: Run in verbose mode (see error and warnings).
88
- -w/--warmup <num>: If multiple runtime statistics are given, ignore the first
89
- <num> stats. Default is zero.
90
98
-h/--help: Print usage.
91
99
"""
92
100
)
@@ -117,6 +125,11 @@ def print_usage(msg=""):
117
125
time_unit = 1 # seconds.
118
126
min_percent_reporting = 0.0 # percentage.
119
127
128
+ # For plot info
129
+ plot_names = np .array ([])
130
+ plot_values = np .array ([])
131
+ plot_x_axis = "time"
132
+
120
133
121
134
# Basic pattern for reports: "==" <stat name> "==," <op name> "," <node name> ","
122
135
def common_report_str (stat_name ):
@@ -438,7 +451,7 @@ def get_percent(n, d):
438
451
def get_sorting_key(count, name, time):
    """Return the key under which an op's report entry is sorted.

    The report prints its entries in ascending key order, so numeric keys are
    negated to list the largest count or time first.

    Args:
        count: Number of occurrences of the op.
        name: Name of the op.
        time: Cumulative execution time of the op.

    Returns:
        `-count` when the global `sorting_preference` is "num", `name` when it
        is "name", and `-time` for any other value.
    """
    global sorting_preference
    if sorting_preference == "num":
        # Return right away: merely assigning the key here would fall through
        # to the time-based key below and silently ignore the "num" preference.
        return -count
    if sorting_preference == "name":
        return name
    return -time
@@ -450,18 +463,25 @@ def make_report(stat_message):
450
463
global has_timing , time_unit , error_missing_time
451
464
global report_level , supported_only , verbose , min_percent_reporting
452
465
global sorting_preference
466
+ global plot_names , plot_values , plot_x_axis
453
467
454
468
# Gather statistics in a dictionary so that we may sort the entries.
455
469
sorted_output = {}
470
+ sorted_plot_output = {}
456
471
for op in op_count_dict :
457
472
count = op_count_dict [op ]
458
473
count_time_str = str (count )
459
474
time = 0
475
+ plot_name_val = []
460
476
if op in op_time_dict :
461
477
time = np .sum (op_time_dict [op ])
462
478
count_time_str += ", {:.7f}" .format (time * time_unit / count )
463
479
count_time_str += ", {:.7f}" .format (time * time_unit )
464
480
count_time_str += ", {:.1f}%" .format (get_percent (time , tot_time ))
481
+ # Keep plot_name -> plot_val in sorted_plot_output.
482
+ plot_name = "{} ({})" .format (op , count )
483
+ plot_val = time * time_unit
484
+ plot_name_val = [plot_name , plot_val ]
465
485
if get_percent (time , tot_time ) < min_percent_reporting :
466
486
continue
467
487
output = " " + op + ", " + count_time_str
@@ -495,22 +515,34 @@ def make_report(stat_message):
495
515
for key in sorted (sorted_det_output ):
496
516
output += sorted_det_output [key ]
497
517
498
- # add output to sorted_output
518
+ # Add output to sorted_output.
499
519
output_key = get_sorting_key (count , op , time )
500
520
if output_key in sorted_output :
501
521
sorted_output [output_key ] += "\n " + output
502
522
else :
503
523
sorted_output [output_key ] = output
524
+ # Add plot name/val tuple to sorted_plot_output.
525
+ if len (plot_name_val ) == 2 :
526
+ if output_key in sorted_plot_output :
527
+ curr_list = sorted_plot_output [output_key ]
528
+ curr_list .append (plot_name_val )
529
+ sorted_plot_output [output_key ] = curr_list
530
+ else :
531
+ sorted_plot_output [output_key ] = [plot_name_val ]
504
532
505
533
# Print legend and stats.
506
534
num_desc = "num"
507
535
if has_timing :
508
536
if time_unit == 1 :
509
- unit_str = "(s) "
537
+ unit_str = "s "
510
538
elif time_unit == 1000 :
511
- unit_str = "(ms) "
539
+ unit_str = "ms "
512
540
elif time_unit == 1000 * 1000 :
513
- unit_str = "(us)"
541
+ unit_str = "us"
542
+ plot_x_axis = "execution time ({}) out of total time of {:.2f}{}" .format (
543
+ unit_str , tot_time * time_unit , unit_str
544
+ )
545
+ unit_str = "(" + unit_str + ")"
514
546
num_desc += ", average time " + unit_str
515
547
num_desc += ", cumulative time " + unit_str
516
548
num_desc += ", percent of total "
@@ -535,6 +567,10 @@ def make_report(stat_message):
535
567
print ("Statistics start" + stat_details )
536
568
for key in sorted (sorted_output ):
537
569
print (sorted_output [key ])
570
+ if key in sorted_plot_output :
571
+ for t in sorted_plot_output [key ]:
572
+ plot_names = np .append (plot_names , t [0 ])
573
+ plot_values = np .append (plot_values , t [1 ])
538
574
print ("Statistics end" + stat_details )
539
575
540
576
# Report spurious node name if any.
@@ -543,6 +579,42 @@ def make_report(stat_message):
543
579
print ("> Run with `-v` for detailed list of errors." )
544
580
545
581
582
+ ################################################################################
583
+ # Print plot.
584
+
585
+
586
def output_plot(runtime_file_name, plot_file_name):
    """Save a horizontal bar chart of the gathered statistics as a JPEG file.

    The data comes from the module-level `plot_names` (bar labels),
    `plot_values` (bar lengths) and `plot_x_axis` (x-axis label). Nothing is
    plotted when `plot_names` is empty or when matplotlib cannot be imported;
    a message is printed instead.

    Args:
        runtime_file_name: Name of the runtime log; only used in the plot title.
        plot_file_name: Output file name without extension; the plot is saved
            as "<plot_file_name>.jpg".
    """
    global plot_names, plot_values, plot_x_axis
    if len(plot_names) == 0:
        print("\n > No info to plot, skip")
        return

    # Guard only the import, so that an ImportError is reported as a missing
    # matplotlib only when that is actually the cause.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("\n > Could not import matplotlib, please add if you want to plot.")
        return

    # Create the horizontal bar graph. barh draws its first entry at the
    # bottom, so the arrays are reversed to keep the first entry at the top.
    fig, ax = plt.subplots()
    try:
        bars = ax.barh(np.flip(plot_names), np.flip(plot_values), color="blue")

        # Adding the data values next to the bars.
        for bar in bars:
            width = bar.get_width()
            text = " {:.1f}".format(width)
            ax.text(width, bar.get_y() + bar.get_height() / 2, text, va="center")

        # Setting the axes limits to make room for annotations.
        ax.set_xlim(0, max(plot_values) * 1.2)

        plt.xlabel(plot_x_axis)
        plt.title("execution time summary for " + runtime_file_name)
        output_file_name = "{}.jpg".format(plot_file_name)
        plt.savefig(output_file_name, bbox_inches="tight")
    finally:
        # Release the figure whether or not plotting succeeded.
        plt.close(fig)
    print('\n > output plot printed in "{}"'.format(output_file_name))
616
+
617
+
546
618
################################################################################
547
619
# Main.
548
620
@@ -554,19 +626,21 @@ def main(argv):
554
626
555
627
compile_file_name = ""
556
628
runtime_file_name = ""
629
+ plot_file_name = ""
557
630
make_stats = ""
558
631
make_legend = ""
559
632
warmup_num = 0
560
633
try :
561
634
opts , args = getopt .getopt (
562
635
argv ,
563
- "c:f:hl:m:r:s:u:vw:" ,
636
+ "c:f:hl:m:p: r:s:u:vw:" ,
564
637
[
565
638
"compile=" ,
566
639
"focus=" ,
567
640
"help" ,
568
641
"level=" ,
569
642
"min=" ,
643
+ "plot=" ,
570
644
"runtime=" ,
571
645
"stats=" ,
572
646
"sort=" ,
@@ -593,6 +667,8 @@ def main(argv):
593
667
print_usage ("detail levels are 0, 1, 2, or 3" )
594
668
elif opt in ("-m" , "--min" ):
595
669
min_percent_reporting = float (arg )
670
+ elif opt in ("-p" , "--plot" ):
671
+ plot_file_name = arg
596
672
elif opt in ("-r" , "--runtime" ):
597
673
runtime_file_name = arg
598
674
elif opt in ("-s" , "--stats" ):
@@ -675,6 +751,9 @@ def main(argv):
675
751
else :
676
752
print_usage ("Command requires an input file name (compile/runtime or both).\n " )
677
753
754
+ if plot_file_name :
755
+ output_plot (runtime_file_name , plot_file_name )
756
+
678
757
679
758
if __name__ == "__main__" :
680
759
main (sys .argv [1 :])
0 commit comments