Skip to content

Commit 2dd1d31

Browse files
authored
Merge branch 'main' into ferdinand.update_llvm_april_2024
2 parents dea4e10 + 1a0fb56 commit 2dd1d31

File tree

1 file changed

+97
-18
lines changed

1 file changed

+97
-18
lines changed

utils/make-report.py

Lines changed: 97 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def print_usage(msg=""):
4040
"""
4141
Usage: Report statistics on compiler and runtime characteristics of ONNX ops.
4242
make-report.py -[vh] [-c <compile_log>] [-r <run_log>] [-l <num>]
43-
[-s <stats>] [--sort <val>] [--supported] [-u <val>] [-p <op regexp>]
44-
[-w <num>]
43+
[-p <plot_file_name>] [-s <stats>] [--sort <val>] [--supported] [-u <val>]
44+
[-f <op regexp>] [-w <num>]
4545
4646
Compile-time statistics are collected from a `onnx-mlir` compiler output
4747
with the `--opt-report` option equal to `Simd` or other supported sub-options.
@@ -57,11 +57,23 @@ def print_usage(msg=""):
5757
`onnx-mlir --debug-only=lowering-to-krnl` and look at the compiler output.
5858
Use `-l 3` to correlate the node name printed here with compiler output.
5959
60-
Parameters:
60+
Parameters on inputs:
6161
-c/--compile <file_log>: File name containing the compile-time statistics
6262
or runtime signature statistics.
6363
-r/--runtime <file_log>: File name containing the runtime time statistics.
6464
65+
Parameters to focus analysis:
66+
-f/--focus <regexp>: Focus only on ops that match the regexp pattern.
67+
-m/--min <num>: Focus on operations with at least <num>% of exec time.
68+
--supported: Focus only on ops that are supported. Namely, the report
69+
will skip ops for which compile-time statistics list
70+
the 'unsupported' keyword in its printout.
71+
For SIMD/parallel statistics, this include all ops that
72+
have currently no support for it.
73+
-w/--warmup <num>: If multiple runtime statistics are given, ignore the first
74+
<num> stats. Default is zero.
75+
76+
Parameters on what to print:
6577
-s/--stats <name>: Print specific statistics:
6678
simd: Print simd optimization stats.
6779
Default if a compile time file is given.
@@ -74,19 +86,15 @@ def print_usage(msg=""):
7486
1: Also count reasons for success/failure.
7587
2: Also list metrics.
7688
3: Also list node name.
77-
-f/--focus <regexp>: Focus only on ops that match the regexp pattern.
78-
-m/--min <num>: Focus on operations with at least <num>% of exec time.
79-
-supported: Focus only on ops that are supported. Namely, the report
80-
will skip ops for which compile-time statistics list
81-
the 'unsupported' keyword in its printout.
82-
For SIMD/parallel statistics, this include all ops that
83-
have currently no support for it.
89+
-p/--plot <name>: Print a "<name>.jpg" plot of the output.
90+
91+
Parameters on how to print:
8492
-u/--unit <str>: Time in second ('s', default), millisecond ('ms') or
8593
microsecond ('us').
8694
--sort <str>: Sort output by op 'name', occurrence 'num' or `time`.
95+
96+
Help:
8797
-v/--verbose: Run in verbose mode (see error and warnings).
88-
-w/--warmup <num>: If multiple runtime statistics are given, ignore the first
89-
<num> stats. Default is zero.
9098
-h/--help: Print usage.
9199
"""
92100
)
@@ -117,6 +125,11 @@ def print_usage(msg=""):
117125
time_unit = 1 # seconds.
118126
min_percent_reporting = 0.0 # percentage.
119127

128+
# For plot info
129+
plot_names = np.array([])
130+
plot_values = np.array([])
131+
plot_x_axis = "time"
132+
120133

121134
# Basic pattern for reports: "==" <stat name> "==," <op name> "," <node name> ","
122135
def common_report_str(stat_name):
@@ -438,7 +451,7 @@ def get_percent(n, d):
438451
def get_sorting_key(count, name, time):
439452
global sorting_preference
440453
if sorting_preference == "num":
441-
key = -count
454+
return -count
442455
if sorting_preference == "name":
443456
return name
444457
return -time
@@ -450,18 +463,25 @@ def make_report(stat_message):
450463
global has_timing, time_unit, error_missing_time
451464
global report_level, supported_only, verbose, min_percent_reporting
452465
global sorting_preference
466+
global plot_names, plot_values, plot_x_axis
453467

454468
# Gather statistics in a dictionary so that we may sort the entries.
455469
sorted_output = {}
470+
sorted_plot_output = {}
456471
for op in op_count_dict:
457472
count = op_count_dict[op]
458473
count_time_str = str(count)
459474
time = 0
475+
plot_name_val = []
460476
if op in op_time_dict:
461477
time = np.sum(op_time_dict[op])
462478
count_time_str += ", {:.7f}".format(time * time_unit / count)
463479
count_time_str += ", {:.7f}".format(time * time_unit)
464480
count_time_str += ", {:.1f}%".format(get_percent(time, tot_time))
481+
# Keep plot_name -> plot_val in sorted_plot_output.
482+
plot_name = "{} ({})".format(op, count)
483+
plot_val = time * time_unit
484+
plot_name_val = [plot_name, plot_val]
465485
if get_percent(time, tot_time) < min_percent_reporting:
466486
continue
467487
output = " " + op + ", " + count_time_str
@@ -495,22 +515,34 @@ def make_report(stat_message):
495515
for key in sorted(sorted_det_output):
496516
output += sorted_det_output[key]
497517

498-
# add output to sorted_output
518+
# Add output to sorted_output.
499519
output_key = get_sorting_key(count, op, time)
500520
if output_key in sorted_output:
501521
sorted_output[output_key] += "\n" + output
502522
else:
503523
sorted_output[output_key] = output
524+
# Add plot name/val tuple to sorted_plot_output.
525+
if len(plot_name_val) == 2:
526+
if output_key in sorted_plot_output:
527+
curr_list = sorted_plot_output[output_key]
528+
curr_list.append(plot_name_val)
529+
sorted_plot_output[output_key] = curr_list
530+
else:
531+
sorted_plot_output[output_key] = [plot_name_val]
504532

505533
# Print legend and stats.
506534
num_desc = "num"
507535
if has_timing:
508536
if time_unit == 1:
509-
unit_str = "(s)"
537+
unit_str = "s"
510538
elif time_unit == 1000:
511-
unit_str = "(ms)"
539+
unit_str = "ms"
512540
elif time_unit == 1000 * 1000:
513-
unit_str = "(us)"
541+
unit_str = "us"
542+
plot_x_axis = "execution time ({}) out of total time of {:.2f}{}".format(
543+
unit_str, tot_time * time_unit, unit_str
544+
)
545+
unit_str = "(" + unit_str + ")"
514546
num_desc += ", average time " + unit_str
515547
num_desc += ", cumulative time " + unit_str
516548
num_desc += ", percent of total "
@@ -535,6 +567,10 @@ def make_report(stat_message):
535567
print("Statistics start" + stat_details)
536568
for key in sorted(sorted_output):
537569
print(sorted_output[key])
570+
if key in sorted_plot_output:
571+
for t in sorted_plot_output[key]:
572+
plot_names = np.append(plot_names, t[0])
573+
plot_values = np.append(plot_values, t[1])
538574
print("Statistics end" + stat_details)
539575

540576
# Report spurious node name if any.
@@ -543,6 +579,42 @@ def make_report(stat_message):
543579
print("> Run with `-v` for detailed list of errors.")
544580

545581

582+
################################################################################
583+
# Print plot.
584+
585+
586+
def output_plot(runtime_file_name, plot_file_name):
    """Save a horizontal bar chart of the per-op execution times as a JPEG.

    The data comes from the module-level `plot_names` / `plot_values` arrays
    and the `plot_x_axis` label, which `make_report` fills in while printing
    the statistics.

    Args:
        runtime_file_name: Name of the runtime log; only used in the plot title.
        plot_file_name: Output file name without extension; ".jpg" is appended.

    Prints a message and returns without plotting when there is no data or
    when matplotlib is not installed.
    """
    global plot_names, plot_values, plot_x_axis
    if len(plot_names) == 0:
        print("\n> No info to plot, skip")
        return

    # matplotlib is an optional dependency: only the import is guarded so that
    # an ImportError raised later (e.g. by a backend) is not misreported as a
    # missing package.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("\n> Could not import matplotlib, please add if you want to plot.")
        return

    # Create the horizontal bar graph. Entries are flipped so that the first
    # reported op ends up as the last bar drawn (presumably the top one, since
    # barh stacks bars from the bottom up).
    fig, ax = plt.subplots()
    bars = ax.barh(np.flip(plot_names), np.flip(plot_values), color="blue")

    # Adding the data values next to the bars.
    for bar in bars:
        width = bar.get_width()
        text = " {:.1f}".format(width)
        ax.text(width, bar.get_y() + bar.get_height() / 2, text, va="center")

    # Setting the axes limits to make room for annotations.
    ax.set_xlim(0, max(plot_values) * 1.2)

    plt.xlabel(plot_x_axis)
    plt.title("execution time summary for " + runtime_file_name)
    output_file_name = "{}.jpg".format(plot_file_name)
    plt.savefig(output_file_name, bbox_inches="tight")
    # Release the figure now that it has been written to disk.
    plt.close(fig)
    print('\n> output plot printed in "{}"'.format(output_file_name))
616+
617+
546618
################################################################################
547619
# Main.
548620

@@ -554,19 +626,21 @@ def main(argv):
554626

555627
compile_file_name = ""
556628
runtime_file_name = ""
629+
plot_file_name = ""
557630
make_stats = ""
558631
make_legend = ""
559632
warmup_num = 0
560633
try:
561634
opts, args = getopt.getopt(
562635
argv,
563-
"c:f:hl:m:r:s:u:vw:",
636+
"c:f:hl:m:p:r:s:u:vw:",
564637
[
565638
"compile=",
566639
"focus=",
567640
"help",
568641
"level=",
569642
"min=",
643+
"plot=",
570644
"runtime=",
571645
"stats=",
572646
"sort=",
@@ -593,6 +667,8 @@ def main(argv):
593667
print_usage("detail levels are 0, 1, 2, or 3")
594668
elif opt in ("-m", "--min"):
595669
min_percent_reporting = float(arg)
670+
elif opt in ("-p", "--plot"):
671+
plot_file_name = arg
596672
elif opt in ("-r", "--runtime"):
597673
runtime_file_name = arg
598674
elif opt in ("-s", "--stats"):
@@ -675,6 +751,9 @@ def main(argv):
675751
else:
676752
print_usage("Command requires an input file name (compile/runtime or both).\n")
677753

754+
if plot_file_name:
755+
output_plot(runtime_file_name, plot_file_name)
756+
678757

679758
if __name__ == "__main__":
680759
main(sys.argv[1:])

0 commit comments

Comments
 (0)