27
27
from google .protobuf .timestamp_pb2 import Timestamp
28
28
29
29
MIN_SEQ_LEN = 4
30
- CLIENT_TIMEOUT_SEC = 3 * 60 * 60
31
30
NEW_TEXT_KEY = "\n Output:\n "
32
31
PROMETHEUS_PORT = 9090
33
32
@@ -148,6 +147,7 @@ async def send_stream_request(
148
147
tokenizer : PreTrainedTokenizerBase ,
149
148
sax_model : str ,
150
149
model : str ,
150
+ timeout : float ,
151
151
) -> Tuple [Tuple [int , int , float ], float , Dict [str , int ]]:
152
152
"""Sends stream request to server"""
153
153
request_start_time = time .time ()
@@ -179,7 +179,7 @@ async def send_stream_request(
179
179
ttft = 0.0
180
180
st = time .perf_counter ()
181
181
output = ""
182
- timeout = aiohttp .ClientTimeout (total = CLIENT_TIMEOUT_SEC )
182
+ timeout = aiohttp .ClientTimeout (total = timeout )
183
183
async with aiohttp .ClientSession (timeout = timeout ,trust_env = True ) as session :
184
184
try :
185
185
async with session .post (api_url , headers = headers , json = pload , ssl = False ) as response :
@@ -249,6 +249,7 @@ async def send_request(
249
249
tokenizer : PreTrainedTokenizerBase ,
250
250
sax_model : str ,
251
251
model : str ,
252
+ timeout : float ,
252
253
) -> Tuple [Tuple [int , int , float ], float , Dict [str , int ]]:
253
254
"""Sends request to server."""
254
255
request_start_time = time .time ()
@@ -322,7 +323,7 @@ async def send_request(
322
323
raise ValueError (f"Unknown backend: { backend } " )
323
324
324
325
# Set client timeout to be 3 hrs.
325
- timeout = aiohttp .ClientTimeout (total = CLIENT_TIMEOUT_SEC )
326
+ timeout = aiohttp .ClientTimeout (total = timeout )
326
327
async with aiohttp .ClientSession (timeout = timeout ,trust_env = True ,trace_configs = [trace_config ]) as session :
327
328
while True :
328
329
try :
@@ -426,6 +427,7 @@ async def benchmark(
426
427
tokenizer ,
427
428
args .sax_model ,
428
429
model ,
430
+ args .request_timeout ,
429
431
)
430
432
)
431
433
else :
@@ -442,6 +444,7 @@ async def benchmark(
442
444
tokenizer ,
443
445
args .sax_model ,
444
446
model ,
447
+ args .request_timeout ,
445
448
)
446
449
)
447
450
tasks .append (task )
@@ -834,6 +837,12 @@ async def main(args: argparse.Namespace):
834
837
action = "store_true" ,
835
838
help = "Whether to stream the request. Needed for TTFT metric" ,
836
839
)
840
+ parser .add_argument (
841
+ "--request-timeout" ,
842
+ type = float ,
843
+ default = (3.0 * 60.0 * 60.0 ),
844
+ help = "Individual request timeout" ,
845
+ )
837
846
parser .add_argument (
838
847
"--tokenizer" ,
839
848
type = str ,
0 commit comments