
Commit ecaf698

polish memory opt benchmark (#198)
1 parent 674f08d commit ecaf698

3 files changed: 6 additions, 4 deletions

benchmark/benchmark.py

Lines changed: 4 additions & 3 deletions
@@ -14,7 +14,7 @@
 turbo-transformers Benchmark Utils

 Usage:
-    benchmark <model_name> [--seq_len=<int>] [--framework=<f>] [--batch_size=<int>] [-n <int>] [--enable-random] [--min_seq_len=<int>] [--max_seq_len=<int>] [--use_gpu] [--num_threads=<int>] [--enable_mem_opt=<bool>]
+    benchmark <model_name> [--seq_len=<int>] [--framework=<f>] [--batch_size=<int>] [-n <int>] [--enable-random] [--min_seq_len=<int>] [--max_seq_len=<int>] [--use_gpu] [--num_threads=<int>] [--enable_mem_opt]

 Options:
     --framework=<f> The framework to test in (torch, torch_jit, turbo-transformers,
@@ -27,7 +27,7 @@
     --max_seq_len=<int> Maximal sequence length generated when enable random [default: 50]
     --use_gpu Enable GPU.
     --num_threads=<int> The number of CPU threads. [default: 4]
-    --enable_mem_opt=<bool> Use memory optimization for BERT. [default: False]
+    --enable_mem_opt Use model aware memory optimization for BERT.
 """

 import json
@@ -54,7 +54,8 @@ def main():
         'use_gpu': True if args['--use_gpu'] else False,
         'enable_mem_opt': True if args['--enable_mem_opt'] else False,
     }
-    if (kwargs['model_name'] != 'bert'):
+    if (kwargs['model_name'] != 'bert'
+            or args['--framework'] != 'turbo-transformers'):
         kwargs['enable_mem_opt'] = False
     if args['--framework'] == 'turbo-transformers':
        benchmark_turbo_transformers(**kwargs)
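The switch from a value option to a bare flag matters because docopt (which the Usage/Options docstring and the args['--...'] lookups suggest this script uses) returns option values as strings, and the [default: False] value is the non-empty string 'False', which is truthy, so the `True if args['--enable_mem_opt'] else False` coercion in main() could never yield False. A minimal sketch of that behaviour, assuming docopt is installed; the reduced usage strings below are stand-ins, not the real docstring:

# Sketch only: contrast docopt's handling of a value option vs. a bare flag.
from docopt import docopt

value_doc = """Usage: bench [--enable_mem_opt=<bool>]

Options:
    --enable_mem_opt=<bool>  Use memory optimization. [default: False]
"""
args = docopt(value_doc, argv=[])
print(repr(args['--enable_mem_opt']))               # 'False' -- a non-empty string
print(True if args['--enable_mem_opt'] else False)  # True, even with no argument passed

flag_doc = """Usage: bench [--enable_mem_opt]

Options:
    --enable_mem_opt  Use model aware memory optimization.
"""
args = docopt(flag_doc, argv=[])
print(True if args['--enable_mem_opt'] else False)  # False unless --enable_mem_opt is given

With the bare flag, the existing coercion in main() behaves as intended, and the new condition additionally disables memory optimization unless the model is bert and the framework is turbo-transformers.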

benchmark/onnx_benchmark_helper.py

Lines changed: 1 addition & 0 deletions
@@ -153,6 +153,7 @@ def _impl_(model_name: str,

     if enable_latency_plot:
         import time
+        import torch
         print(f"dump results to onnxrt_latency_{num_threads}.txt")
         result_list = []
         with open(f"onnxrt_latency_{num_threads}.txt", "w") as of:
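Only the import is added here; presumably the latency-plot branch further down uses torch (for example torch.cuda.synchronize() when timing GPU inference) and had previously relied on torch being imported elsewhere. A rough sketch of that style of GPU-aware timing; time_once and run_inference are hypothetical names, not functions from this file:

# Hypothetical helper, not repository code: make wall-clock timing honest on GPU
# by draining asynchronous CUDA work before reading the clock.
import time

import torch

def time_once(run_inference, use_gpu: bool) -> float:
    if use_gpu:
        torch.cuda.synchronize()   # finish pending kernels before starting the timer
    start = time.time()
    run_inference()
    if use_gpu:
        torch.cuda.synchronize()   # wait for the measured run to complete on the device
    return time.time() - start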

benchmark/run_gpu_variable_benchmark.sh

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ set -e
 # FRAMEWORKS=("turbo-transformers" "torch" "onnxruntime")
 FRAMEWORKS=("turbo-transformers" "torch")
 # Note Onnx doese not supports Albert
-# FRAMEWORKS=("onnxruntime")
+# FRAMEWORKS=("onnxruntime-gpu")

 MAX_SEQ_LEN=(500)
