From 61ba7b538f44569199e5573e69810f08905ed275 Mon Sep 17 00:00:00 2001 From: tanmayv25 Date: Thu, 7 Sep 2023 16:44:48 -0700 Subject: [PATCH 01/11] Add tests for gRPC client-side cancellation --- .../client_cancellation_test.py | 237 ++++++++++++++++++ .../models/custom_identity_int32/config.pbtxt | 54 ++++ qa/L0_client_cancellation/test.sh | 103 ++++++++ 3 files changed, 394 insertions(+) create mode 100755 qa/L0_client_cancellation/client_cancellation_test.py create mode 100644 qa/L0_client_cancellation/models/custom_identity_int32/config.pbtxt create mode 100755 qa/L0_client_cancellation/test.sh diff --git a/qa/L0_client_cancellation/client_cancellation_test.py b/qa/L0_client_cancellation/client_cancellation_test.py new file mode 100755 index 0000000000..0119658e3e --- /dev/null +++ b/qa/L0_client_cancellation/client_cancellation_test.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 + +# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
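# These tests exercise client-side cancellation of in-flight gRPC
# inference requests: async_infer() with future.cancel(), streaming with
# stop_stream(cancel_requests=True), and the grpc.aio equivalents.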
+ +import sys + +sys.path.append("../common") + +import asyncio +import queue +import socket +import unittest +from functools import partial +import time + +import numpy as np +import test_util as tu +import tritonclient.grpc as grpcclient +import tritonclient.grpc.aio as aiogrpcclient +from tritonclient.utils import InferenceServerException + + +class UserData: + def __init__(self): + self._completed_requests = queue.Queue() + + +def callback(user_data, result, error): + if error: + user_data._completed_requests.put(error) + else: + user_data._completed_requests.put(result) + +class ClientCancellationTest(tu.TestResultCollector): + def setUp(self): + self.model_name_ = "custom_identity_int32" + self.input0_data_ = np.array([[10]], dtype=np.int32) + self._start_time_ms = 0 + self._end_time_ms = 0 + + def _record_start_time_ms(self): + self._start_time_ms = int(round(time.time() * 1000)) + + def _record_end_time_ms(self): + self._end_time_ms = int(round(time.time() * 1000)) + + def _test_runtime_duration(self, upper_limit): + self.assertTrue( + (self._end_time_ms - self._start_time_ms) < upper_limit, + "test runtime expected less than " + + str(upper_limit) + + "ms response time, got " + + str(self._end_time_ms - self._start_time_ms) + + " ms", + ) + + def _prepare_request(self): + self.inputs_ = [] + self.inputs_.append(grpcclient.InferInput("INPUT0", [1, 1], "INT32")) + self.outputs_ = [] + self.outputs_.append(grpcclient.InferRequestedOutput("OUTPUT0")) + + self.inputs_[0].set_data_from_numpy(self.input0_data_) + + + def test_grpc_async_infer(self): + # Sends a request using async_infer to a + # model that takes 10s to execute. Issues + # a cancellation request after 2s. The client + # should return with appropriate exception within + # 5s. + triton_client = grpcclient.InferenceServerClient( + url="localhost:8001", verbose=True + ) + self._prepare_request() + + user_data = UserData() + + self._record_start_time_ms() + + # Expect inference to pass successfully for a large timeout + # value + future = triton_client.async_infer( + model_name=self.model_name_, + inputs=self.inputs_, + callback=partial(callback, user_data), + outputs=self.outputs_, + ) + time.sleep(2) + future.cancel() + + # Wait until the results is captured via callback + data_item = user_data._completed_requests.get() + self.assertTrue(type(data_item) == grpcclient.CancelledError) + + self._record_end_time_ms() + self._test_runtime_duration(5000) + + def test_grpc_stream_infer(self): + # Sends a request using async_stream_infer to a + # model that takes 10s to execute. Issues stream + # closure with cancel_requests=True. The client + # should return with appropriate exception within + # 5s. + triton_client = grpcclient.InferenceServerClient( + url="localhost:8001", verbose=True + ) + + self._prepare_request() + user_data = UserData() + + # The model is configured to take three seconds to send the + # response. Expect an exception for small timeout values. 
+ triton_client.start_stream( + callback=partial(callback, user_data) + ) + self._record_start_time_ms() + for i in range(1): + triton_client.async_stream_infer( + model_name=self.model_name_, inputs=self.inputs_, outputs=self.outputs_ + ) + + time.sleep(2) + triton_client.stop_stream(cancel_requests=True) + + data_item = user_data._completed_requests.get() + self.assertTrue(type(data_item) == grpcclient.CancelledError) + + self._record_end_time_ms() + self._test_runtime_duration(5000) + + + def test_aio_grpc_async_infer(self): + # Sends a request using infer of grpc.aio to a + # model that takes 10s to execute. Issues + # a cancellation request after 2s. The client + # should return with appropriate exception within + # 5s. + async def cancel_request(call): + await asyncio.sleep(2) + self.assertTrue(call.cancel()) + + async def handle_response(call): + with self.assertRaises(asyncio.exceptions.CancelledError) as cm: + response = await call + + async def test_aio_infer(self): + triton_client = aiogrpcclient.InferenceServerClient( + url="localhost:8001", verbose=True + ) + self._prepare_request() + self._record_start_time_ms() + # Expect inference to pass successfully for a large timeout + # value + call = await triton_client.infer( + model_name=self.model_name_, + inputs=self.inputs_, + outputs=self.outputs_, + get_call_obj=True, + ) + task1 = asyncio.create_task(handle_response(call)) + task2 = asyncio.create_task(cancel_request(call)) + await task1 + await task2 + + self._record_end_time_ms() + self._test_runtime_duration(5000) + + + asyncio.run(test_aio_infer(self)) + + def test_aio_grpc_stream_infer(self): + # Sends a request using stream_infer of grpc.aio + # library model that takes 10s to execute. Issues + # stream closure with cancel_requests=True. The client + # should return with appropriate exception within + # 5s. + async def test_aio_streaming_infer(self): + async with aiogrpcclient.InferenceServerClient( + url="localhost:8001", verbose=True) as triton_client: + async def async_request_iterator(): + for i in range(1): + await asyncio.sleep(1) + yield {"model_name": self.model_name_, + "inputs": self.inputs_, + "outputs": self.outputs_} + + self._prepare_request() + self._record_start_time_ms() + response_iterator = triton_client.stream_infer(inputs_iterator=async_request_iterator(), get_call_obj=True) + streaming_call = await response_iterator.__anext__() + + async def cancel_streaming(streaming_call): + await asyncio.sleep(2) + streaming_call.cancel() + + async def handle_response(response_iterator): + with self.assertRaises(asyncio.exceptions.CancelledError) as cm: + async for response in response_iterator: + self.assertTrue(False, "Received an unexpected response!") + + task1 = asyncio.create_task(handle_response(response_iterator)) + task2 = asyncio.create_task(cancel_streaming(streaming_call)) + await task1 + await task2 + + self._record_end_time_ms() + self._test_runtime_duration(5000) + + asyncio.run(test_aio_streaming_infer(self)) + +if __name__ == "__main__": + unittest.main() diff --git a/qa/L0_client_cancellation/models/custom_identity_int32/config.pbtxt b/qa/L0_client_cancellation/models/custom_identity_int32/config.pbtxt new file mode 100644 index 0000000000..f04acb18ea --- /dev/null +++ b/qa/L0_client_cancellation/models/custom_identity_int32/config.pbtxt @@ -0,0 +1,54 @@ +# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "custom_identity_int32" +backend: "identity" +max_batch_size: 1024 +version_policy: { latest { num_versions: 1 }} +instance_group [ { kind: KIND_CPU } ] + +input [ + { + name: "INPUT0" + data_type: TYPE_INT32 + dims: [ -1 ] + + } +] +output [ + { + name: "OUTPUT0" + data_type: TYPE_INT32 + dims: [ -1 ] + } +] + +parameters [ + { + key: "execute_delay_ms" + value: { string_value: "10000" } + } +] diff --git a/qa/L0_client_cancellation/test.sh b/qa/L0_client_cancellation/test.sh new file mode 100755 index 0000000000..da6b64706a --- /dev/null +++ b/qa/L0_client_cancellation/test.sh @@ -0,0 +1,103 @@ +#!/bin/bash +# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +REPO_VERSION=${NVIDIA_TRITON_SERVER_VERSION} +if [ "$#" -ge 1 ]; then + REPO_VERSION=$1 +fi +if [ -z "$REPO_VERSION" ]; then + echo -e "Repository version must be specified" + echo -e "\n***\n*** Test Failed\n***" + exit 1 +fi +if [ ! -z "$TEST_REPO_ARCH" ]; then + REPO_VERSION=${REPO_VERSION}_${TEST_REPO_ARCH} +fi + +export CUDA_VISIBLE_DEVICES=0 + +RET=0 + +CLIENT_CANCELLATION_TEST=client_cancellation_test.py +TEST_RESULT_FILE='test_results.txt' + +rm -f *.log +rm -f *.log.* + +CLIENT_LOG=`pwd`/client.log +DATADIR=`pwd`/models +SERVER=/opt/tritonserver/bin/tritonserver +SERVER_ARGS="--model-repository=$DATADIR" +source ../common/util.sh + +mkdir -p $DATADIR/custom_identity_int32/1 + + + +set +e + +for i in test_grpc_async_infer \ + test_grpc_stream_infer \ + test_aio_grpc_async_infer \ + test_aio_grpc_stream_infer \ + ; do + + run_server + if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 + fi + + set +e + python $CLIENT_CANCELLATION_TEST ClientCancellationTest.$i >>$CLIENT_LOG 2>&1 + if [ $? -ne 0 ]; then + echo -e "\n***\n*** Test $i Failed\n***" >>$CLIENT_LOG + echo -e "\n***\n*** Test $i Failed\n***" + RET=1 + else + check_test_results $TEST_RESULT_FILE 1 + if [ $? 
-ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Test Result Verification Failed\n***" + RET=1 + fi + fi + + set -e + kill $SERVER_PID + wait $SERVER_PID +done + +if [ $RET -eq 0 ]; then + echo -e "\n***\n*** Test Passed\n***" +else + cat ${CLIENT_LOG} + echo -e "\n***\n*** Test FAILED\n***" +fi + +exit $RET From 8ff681f3934b608acab981b650da230d67864c0d Mon Sep 17 00:00:00 2001 From: tanmayv25 Date: Thu, 7 Sep 2023 17:18:09 -0700 Subject: [PATCH 02/11] Fix CodeQL issues --- .../client_cancellation_test.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/qa/L0_client_cancellation/client_cancellation_test.py b/qa/L0_client_cancellation/client_cancellation_test.py index 0119658e3e..d14ee230dd 100755 --- a/qa/L0_client_cancellation/client_cancellation_test.py +++ b/qa/L0_client_cancellation/client_cancellation_test.py @@ -32,7 +32,6 @@ import asyncio import queue -import socket import unittest from functools import partial import time @@ -115,7 +114,7 @@ def test_grpc_async_infer(self): # Wait until the results is captured via callback data_item = user_data._completed_requests.get() - self.assertTrue(type(data_item) == grpcclient.CancelledError) + self.assertEqual(type(data_item), grpcclient.CancelledError) self._record_end_time_ms() self._test_runtime_duration(5000) @@ -148,7 +147,7 @@ def test_grpc_stream_infer(self): triton_client.stop_stream(cancel_requests=True) data_item = user_data._completed_requests.get() - self.assertTrue(type(data_item) == grpcclient.CancelledError) + self.assertEqual(type(data_item), grpcclient.CancelledError) self._record_end_time_ms() self._test_runtime_duration(5000) @@ -166,7 +165,7 @@ async def cancel_request(call): async def handle_response(call): with self.assertRaises(asyncio.exceptions.CancelledError) as cm: - response = await call + await call async def test_aio_infer(self): triton_client = aiogrpcclient.InferenceServerClient( @@ -182,10 +181,8 @@ async def test_aio_infer(self): outputs=self.outputs_, get_call_obj=True, ) - task1 = asyncio.create_task(handle_response(call)) - task2 = asyncio.create_task(cancel_request(call)) - await task1 - await task2 + asyncio.create_task(handle_response(call)) + asyncio.create_task(cancel_request(call)) self._record_end_time_ms() self._test_runtime_duration(5000) @@ -223,10 +220,8 @@ async def handle_response(response_iterator): async for response in response_iterator: self.assertTrue(False, "Received an unexpected response!") - task1 = asyncio.create_task(handle_response(response_iterator)) - task2 = asyncio.create_task(cancel_streaming(streaming_call)) - await task1 - await task2 + asyncio.create_task(handle_response(response_iterator)) + asyncio.create_task(cancel_streaming(streaming_call)) self._record_end_time_ms() self._test_runtime_duration(5000) From 1e519789424f155c63c42e7b49780abb53694dcd Mon Sep 17 00:00:00 2001 From: tanmayv25 Date: Thu, 7 Sep 2023 17:56:24 -0700 Subject: [PATCH 03/11] Formatting --- .../client_cancellation_test.py | 45 ++++++++++--------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/qa/L0_client_cancellation/client_cancellation_test.py b/qa/L0_client_cancellation/client_cancellation_test.py index d14ee230dd..c009b7916a 100755 --- a/qa/L0_client_cancellation/client_cancellation_test.py +++ b/qa/L0_client_cancellation/client_cancellation_test.py @@ -32,9 +32,9 @@ import asyncio import queue +import time import unittest from functools import partial -import time import numpy as np import test_util as tu @@ -54,6 +54,7 @@ def 
callback(user_data, result, error): else: user_data._completed_requests.put(result) + class ClientCancellationTest(tu.TestResultCollector): def setUp(self): self.model_name_ = "custom_identity_int32" @@ -69,13 +70,13 @@ def _record_end_time_ms(self): def _test_runtime_duration(self, upper_limit): self.assertTrue( - (self._end_time_ms - self._start_time_ms) < upper_limit, - "test runtime expected less than " - + str(upper_limit) - + "ms response time, got " - + str(self._end_time_ms - self._start_time_ms) - + " ms", - ) + (self._end_time_ms - self._start_time_ms) < upper_limit, + "test runtime expected less than " + + str(upper_limit) + + "ms response time, got " + + str(self._end_time_ms - self._start_time_ms) + + " ms", + ) def _prepare_request(self): self.inputs_ = [] @@ -85,7 +86,6 @@ def _prepare_request(self): self.inputs_[0].set_data_from_numpy(self.input0_data_) - def test_grpc_async_infer(self): # Sends a request using async_infer to a # model that takes 10s to execute. Issues @@ -115,13 +115,13 @@ def test_grpc_async_infer(self): # Wait until the results is captured via callback data_item = user_data._completed_requests.get() self.assertEqual(type(data_item), grpcclient.CancelledError) - + self._record_end_time_ms() self._test_runtime_duration(5000) def test_grpc_stream_infer(self): # Sends a request using async_stream_infer to a - # model that takes 10s to execute. Issues stream + # model that takes 10s to execute. Issues stream # closure with cancel_requests=True. The client # should return with appropriate exception within # 5s. @@ -134,9 +134,7 @@ def test_grpc_stream_infer(self): # The model is configured to take three seconds to send the # response. Expect an exception for small timeout values. - triton_client.start_stream( - callback=partial(callback, user_data) - ) + triton_client.start_stream(callback=partial(callback, user_data)) self._record_start_time_ms() for i in range(1): triton_client.async_stream_infer( @@ -148,11 +146,10 @@ def test_grpc_stream_infer(self): data_item = user_data._completed_requests.get() self.assertEqual(type(data_item), grpcclient.CancelledError) - + self._record_end_time_ms() self._test_runtime_duration(5000) - def test_aio_grpc_async_infer(self): # Sends a request using infer of grpc.aio to a # model that takes 10s to execute. Issues @@ -187,7 +184,6 @@ async def test_aio_infer(self): self._record_end_time_ms() self._test_runtime_duration(5000) - asyncio.run(test_aio_infer(self)) def test_aio_grpc_stream_infer(self): @@ -198,17 +194,23 @@ def test_aio_grpc_stream_infer(self): # 5s. 
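        # The nested coroutine below opens the stream; a separate task
        # then cancels the streaming call once it has been obtained.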
async def test_aio_streaming_infer(self): async with aiogrpcclient.InferenceServerClient( - url="localhost:8001", verbose=True) as triton_client: + url="localhost:8001", verbose=True + ) as triton_client: + async def async_request_iterator(): for i in range(1): await asyncio.sleep(1) - yield {"model_name": self.model_name_, + yield { + "model_name": self.model_name_, "inputs": self.inputs_, - "outputs": self.outputs_} + "outputs": self.outputs_, + } self._prepare_request() self._record_start_time_ms() - response_iterator = triton_client.stream_infer(inputs_iterator=async_request_iterator(), get_call_obj=True) + response_iterator = triton_client.stream_infer( + inputs_iterator=async_request_iterator(), get_call_obj=True + ) streaming_call = await response_iterator.__anext__() async def cancel_streaming(streaming_call): @@ -228,5 +230,6 @@ async def handle_response(response_iterator): asyncio.run(test_aio_streaming_infer(self)) + if __name__ == "__main__": unittest.main() From 17b347cff0df53b97332615c475ceb2be78ba230 Mon Sep 17 00:00:00 2001 From: Tanmay Verma Date: Thu, 7 Sep 2023 17:36:30 -0700 Subject: [PATCH 04/11] Update qa/L0_client_cancellation/client_cancellation_test.py Co-authored-by: Ryan McCormick --- qa/L0_client_cancellation/client_cancellation_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qa/L0_client_cancellation/client_cancellation_test.py b/qa/L0_client_cancellation/client_cancellation_test.py index c009b7916a..d8385eef83 100755 --- a/qa/L0_client_cancellation/client_cancellation_test.py +++ b/qa/L0_client_cancellation/client_cancellation_test.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions From ada27e8efc3f37d612bae2f7357beea8d108c536 Mon Sep 17 00:00:00 2001 From: tanmayv25 Date: Tue, 12 Sep 2023 16:00:53 -0700 Subject: [PATCH 05/11] Move to L0_request_cancellation --- qa/L0_client_cancellation/test.sh | 103 ------------------ .../client_cancellation_test.py | 0 .../models/custom_identity_int32/config.pbtxt | 0 qa/L0_request_cancellation/test.sh | 52 +++++++++ 4 files changed, 52 insertions(+), 103 deletions(-) delete mode 100755 qa/L0_client_cancellation/test.sh rename qa/{L0_client_cancellation => L0_request_cancellation}/client_cancellation_test.py (100%) rename qa/{L0_client_cancellation => L0_request_cancellation}/models/custom_identity_int32/config.pbtxt (100%) diff --git a/qa/L0_client_cancellation/test.sh b/qa/L0_client_cancellation/test.sh deleted file mode 100755 index da6b64706a..0000000000 --- a/qa/L0_client_cancellation/test.sh +++ /dev/null @@ -1,103 +0,0 @@ -#!/bin/bash -# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. 
-# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -REPO_VERSION=${NVIDIA_TRITON_SERVER_VERSION} -if [ "$#" -ge 1 ]; then - REPO_VERSION=$1 -fi -if [ -z "$REPO_VERSION" ]; then - echo -e "Repository version must be specified" - echo -e "\n***\n*** Test Failed\n***" - exit 1 -fi -if [ ! -z "$TEST_REPO_ARCH" ]; then - REPO_VERSION=${REPO_VERSION}_${TEST_REPO_ARCH} -fi - -export CUDA_VISIBLE_DEVICES=0 - -RET=0 - -CLIENT_CANCELLATION_TEST=client_cancellation_test.py -TEST_RESULT_FILE='test_results.txt' - -rm -f *.log -rm -f *.log.* - -CLIENT_LOG=`pwd`/client.log -DATADIR=`pwd`/models -SERVER=/opt/tritonserver/bin/tritonserver -SERVER_ARGS="--model-repository=$DATADIR" -source ../common/util.sh - -mkdir -p $DATADIR/custom_identity_int32/1 - - - -set +e - -for i in test_grpc_async_infer \ - test_grpc_stream_infer \ - test_aio_grpc_async_infer \ - test_aio_grpc_stream_infer \ - ; do - - run_server - if [ "$SERVER_PID" == "0" ]; then - echo -e "\n***\n*** Failed to start $SERVER\n***" - cat $SERVER_LOG - exit 1 - fi - - set +e - python $CLIENT_CANCELLATION_TEST ClientCancellationTest.$i >>$CLIENT_LOG 2>&1 - if [ $? -ne 0 ]; then - echo -e "\n***\n*** Test $i Failed\n***" >>$CLIENT_LOG - echo -e "\n***\n*** Test $i Failed\n***" - RET=1 - else - check_test_results $TEST_RESULT_FILE 1 - if [ $? 
-ne 0 ]; then - cat $CLIENT_LOG - echo -e "\n***\n*** Test Result Verification Failed\n***" - RET=1 - fi - fi - - set -e - kill $SERVER_PID - wait $SERVER_PID -done - -if [ $RET -eq 0 ]; then - echo -e "\n***\n*** Test Passed\n***" -else - cat ${CLIENT_LOG} - echo -e "\n***\n*** Test FAILED\n***" -fi - -exit $RET diff --git a/qa/L0_client_cancellation/client_cancellation_test.py b/qa/L0_request_cancellation/client_cancellation_test.py similarity index 100% rename from qa/L0_client_cancellation/client_cancellation_test.py rename to qa/L0_request_cancellation/client_cancellation_test.py diff --git a/qa/L0_client_cancellation/models/custom_identity_int32/config.pbtxt b/qa/L0_request_cancellation/models/custom_identity_int32/config.pbtxt similarity index 100% rename from qa/L0_client_cancellation/models/custom_identity_int32/config.pbtxt rename to qa/L0_request_cancellation/models/custom_identity_int32/config.pbtxt diff --git a/qa/L0_request_cancellation/test.sh b/qa/L0_request_cancellation/test.sh index 7a359ebaec..1c70eaa015 100755 --- a/qa/L0_request_cancellation/test.sh +++ b/qa/L0_request_cancellation/test.sh @@ -42,6 +42,23 @@ DATADIR=${DATADIR:="/data/inferenceserver/${REPO_VERSION}"} RET=0 mkdir -p models/model/1 +mkdir -p $DATADIR/custom_identity_int32/1 + +export CUDA_VISIBLE_DEVICES=0 + +RET=0 + +CLIENT_CANCELLATION_TEST=client_cancellation_test.py +TEST_RESULT_FILE='test_results.txt' + +rm -f *.log +rm -f *.log.* + +CLIENT_LOG=`pwd`/client.log +DATADIR=`pwd`/models +SERVER=/opt/tritonserver/bin/tritonserver +SERVER_ARGS="--model-repository=$DATADIR --log-verbose=1" +source ../common/util.sh SERVER_LOG=server.log LD_LIBRARY_PATH=/opt/tritonserver/lib:$LD_LIBRARY_PATH ./request_cancellation_test > $SERVER_LOG @@ -50,6 +67,41 @@ if [ $? -ne 0 ]; then RET=1 fi +# gRPC client-side cancellation tests... +for i in test_grpc_async_infer \ + test_grpc_stream_infer \ + test_aio_grpc_async_infer \ + test_aio_grpc_stream_infer \ + ; do + + SERVER_LOG=${i}.server.log + run_server + if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 + fi + + set +e + python $CLIENT_CANCELLATION_TEST ClientCancellationTest.$i >>$CLIENT_LOG 2>&1 + if [ $? -ne 0 ]; then + echo -e "\n***\n*** Test $i Failed\n***" >>$CLIENT_LOG + echo -e "\n***\n*** Test $i Failed\n***" + RET=1 + else + check_test_results $TEST_RESULT_FILE 1 + if [ $? 
-ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Test Result Verification Failed\n***" + RET=1 + fi + fi + + set -e + kill $SERVER_PID + wait $SERVER_PID +done + if [ $RET -eq 0 ]; then echo -e "\n***\n*** Test Passed\n***" else From aad265d99e0febde885357316bf0549646c322f0 Mon Sep 17 00:00:00 2001 From: tanmayv25 Date: Tue, 12 Sep 2023 17:42:57 -0700 Subject: [PATCH 06/11] Address review comments --- .../client_cancellation_test.py | 79 +++++++++++-------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/qa/L0_request_cancellation/client_cancellation_test.py b/qa/L0_request_cancellation/client_cancellation_test.py index d8385eef83..d9e7c10c85 100755 --- a/qa/L0_request_cancellation/client_cancellation_test.py +++ b/qa/L0_request_cancellation/client_cancellation_test.py @@ -101,20 +101,20 @@ def test_grpc_async_infer(self): self._record_start_time_ms() - # Expect inference to pass successfully for a large timeout - # value - future = triton_client.async_infer( - model_name=self.model_name_, - inputs=self.inputs_, - callback=partial(callback, user_data), - outputs=self.outputs_, - ) - time.sleep(2) - future.cancel() + with self.assertRaises(InferenceServerException) as cm: + future = triton_client.async_infer( + model_name=self.model_name_, + inputs=self.inputs_, + callback=partial(callback, user_data), + outputs=self.outputs_, + ) + time.sleep(2) + future.cancel() - # Wait until the results is captured via callback - data_item = user_data._completed_requests.get() - self.assertEqual(type(data_item), grpcclient.CancelledError) + data_item = user_data._completed_requests.get() + if type(data_item) == InferenceServerException: + raise data_item + self.assertIn("Locally cancelled by application!", str(cm.exception)) self._record_end_time_ms() self._test_runtime_duration(5000) @@ -132,20 +132,22 @@ def test_grpc_stream_infer(self): self._prepare_request() user_data = UserData() - # The model is configured to take three seconds to send the - # response. Expect an exception for small timeout values. 
triton_client.start_stream(callback=partial(callback, user_data)) self._record_start_time_ms() - for i in range(1): - triton_client.async_stream_infer( - model_name=self.model_name_, inputs=self.inputs_, outputs=self.outputs_ - ) - time.sleep(2) - triton_client.stop_stream(cancel_requests=True) - - data_item = user_data._completed_requests.get() - self.assertEqual(type(data_item), grpcclient.CancelledError) + with self.assertRaises(InferenceServerException) as cm: + for i in range(1): + triton_client.async_stream_infer( + model_name=self.model_name_, + inputs=self.inputs_, + outputs=self.outputs_, + ) + time.sleep(2) + triton_client.stop_stream(cancel_requests=True) + data_item = user_data._completed_requests.get() + if type(data_item) == InferenceServerException: + raise data_item + self.assertIn("Locally cancelled by application!", str(cm.exception)) self._record_end_time_ms() self._test_runtime_duration(5000) @@ -160,9 +162,9 @@ async def cancel_request(call): await asyncio.sleep(2) self.assertTrue(call.cancel()) - async def handle_response(call): + async def handle_response(generator): with self.assertRaises(asyncio.exceptions.CancelledError) as cm: - await call + _ = await anext(generator) async def test_aio_infer(self): triton_client = aiogrpcclient.InferenceServerClient( @@ -170,16 +172,21 @@ async def test_aio_infer(self): ) self._prepare_request() self._record_start_time_ms() - # Expect inference to pass successfully for a large timeout - # value - call = await triton_client.infer( + + generator = triton_client.infer( model_name=self.model_name_, inputs=self.inputs_, outputs=self.outputs_, get_call_obj=True, ) - asyncio.create_task(handle_response(call)) - asyncio.create_task(cancel_request(call)) + grpc_call = await anext(generator) + + tasks = [] + tasks.append(asyncio.create_task(handle_response(generator))) + tasks.append(asyncio.create_task(cancel_request(grpc_call))) + + for task in tasks: + await task self._record_end_time_ms() self._test_runtime_duration(5000) @@ -211,7 +218,7 @@ async def async_request_iterator(): response_iterator = triton_client.stream_infer( inputs_iterator=async_request_iterator(), get_call_obj=True ) - streaming_call = await response_iterator.__anext__() + streaming_call = await anext(response_iterator) async def cancel_streaming(streaming_call): await asyncio.sleep(2) @@ -222,8 +229,12 @@ async def handle_response(response_iterator): async for response in response_iterator: self.assertTrue(False, "Received an unexpected response!") - asyncio.create_task(handle_response(response_iterator)) - asyncio.create_task(cancel_streaming(streaming_call)) + tasks = [] + tasks.append(asyncio.create_task(handle_response(response_iterator))) + tasks.append(asyncio.create_task(cancel_streaming(streaming_call))) + + for task in tasks: + await task self._record_end_time_ms() self._test_runtime_duration(5000) From 35b643f4d2324e01e131a1233c85c9f7e58bb46c Mon Sep 17 00:00:00 2001 From: tanmayv25 Date: Mon, 18 Sep 2023 16:12:58 -0700 Subject: [PATCH 07/11] Removing request cancellation support from asyncio version --- .../client_cancellation_test.py | 178 +++++++++--------- 1 file changed, 90 insertions(+), 88 deletions(-) diff --git a/qa/L0_request_cancellation/client_cancellation_test.py b/qa/L0_request_cancellation/client_cancellation_test.py index d9e7c10c85..5808529eca 100755 --- a/qa/L0_request_cancellation/client_cancellation_test.py +++ b/qa/L0_request_cancellation/client_cancellation_test.py @@ -152,94 +152,96 @@ def test_grpc_stream_infer(self): 
self._record_end_time_ms() self._test_runtime_duration(5000) - def test_aio_grpc_async_infer(self): - # Sends a request using infer of grpc.aio to a - # model that takes 10s to execute. Issues - # a cancellation request after 2s. The client - # should return with appropriate exception within - # 5s. - async def cancel_request(call): - await asyncio.sleep(2) - self.assertTrue(call.cancel()) - - async def handle_response(generator): - with self.assertRaises(asyncio.exceptions.CancelledError) as cm: - _ = await anext(generator) - - async def test_aio_infer(self): - triton_client = aiogrpcclient.InferenceServerClient( - url="localhost:8001", verbose=True - ) - self._prepare_request() - self._record_start_time_ms() - - generator = triton_client.infer( - model_name=self.model_name_, - inputs=self.inputs_, - outputs=self.outputs_, - get_call_obj=True, - ) - grpc_call = await anext(generator) - - tasks = [] - tasks.append(asyncio.create_task(handle_response(generator))) - tasks.append(asyncio.create_task(cancel_request(grpc_call))) - - for task in tasks: - await task - - self._record_end_time_ms() - self._test_runtime_duration(5000) - - asyncio.run(test_aio_infer(self)) - - def test_aio_grpc_stream_infer(self): - # Sends a request using stream_infer of grpc.aio - # library model that takes 10s to execute. Issues - # stream closure with cancel_requests=True. The client - # should return with appropriate exception within - # 5s. - async def test_aio_streaming_infer(self): - async with aiogrpcclient.InferenceServerClient( - url="localhost:8001", verbose=True - ) as triton_client: - - async def async_request_iterator(): - for i in range(1): - await asyncio.sleep(1) - yield { - "model_name": self.model_name_, - "inputs": self.inputs_, - "outputs": self.outputs_, - } - - self._prepare_request() - self._record_start_time_ms() - response_iterator = triton_client.stream_infer( - inputs_iterator=async_request_iterator(), get_call_obj=True - ) - streaming_call = await anext(response_iterator) - - async def cancel_streaming(streaming_call): - await asyncio.sleep(2) - streaming_call.cancel() - - async def handle_response(response_iterator): - with self.assertRaises(asyncio.exceptions.CancelledError) as cm: - async for response in response_iterator: - self.assertTrue(False, "Received an unexpected response!") - - tasks = [] - tasks.append(asyncio.create_task(handle_response(response_iterator))) - tasks.append(asyncio.create_task(cancel_streaming(streaming_call))) - - for task in tasks: - await task - - self._record_end_time_ms() - self._test_runtime_duration(5000) - - asyncio.run(test_aio_streaming_infer(self)) +# Disabling AsyncIO cancellation testing. Enable once +# DLIS-5476 is implemented. +# def test_aio_grpc_async_infer(self): +# # Sends a request using infer of grpc.aio to a +# # model that takes 10s to execute. Issues +# # a cancellation request after 2s. The client +# # should return with appropriate exception within +# # 5s. 
+# async def cancel_request(call): +# await asyncio.sleep(2) +# self.assertTrue(call.cancel()) +# +# async def handle_response(generator): +# with self.assertRaises(asyncio.exceptions.CancelledError) as cm: +# _ = await anext(generator) +# +# async def test_aio_infer(self): +# triton_client = aiogrpcclient.InferenceServerClient( +# url="localhost:8001", verbose=True +# ) +# self._prepare_request() +# self._record_start_time_ms() +# +# generator = triton_client.infer( +# model_name=self.model_name_, +# inputs=self.inputs_, +# outputs=self.outputs_, +# get_call_obj=True, +# ) +# grpc_call = await anext(generator) +# +# tasks = [] +# tasks.append(asyncio.create_task(handle_response(generator))) +# tasks.append(asyncio.create_task(cancel_request(grpc_call))) +# +# for task in tasks: +# await task +# +# self._record_end_time_ms() +# self._test_runtime_duration(5000) +# +# asyncio.run(test_aio_infer(self)) +# +# def test_aio_grpc_stream_infer(self): +# # Sends a request using stream_infer of grpc.aio +# # library model that takes 10s to execute. Issues +# # stream closure with cancel_requests=True. The client +# # should return with appropriate exception within +# # 5s. +# async def test_aio_streaming_infer(self): +# async with aiogrpcclient.InferenceServerClient( +# url="localhost:8001", verbose=True +# ) as triton_client: +# +# async def async_request_iterator(): +# for i in range(1): +# await asyncio.sleep(1) +# yield { +# "model_name": self.model_name_, +# "inputs": self.inputs_, +# "outputs": self.outputs_, +# } +# +# self._prepare_request() +# self._record_start_time_ms() +# response_iterator = triton_client.stream_infer( +# inputs_iterator=async_request_iterator(), get_call_obj=True +# ) +# streaming_call = await anext(response_iterator) +# +# async def cancel_streaming(streaming_call): +# await asyncio.sleep(2) +# streaming_call.cancel() +# +# async def handle_response(response_iterator): +# with self.assertRaises(asyncio.exceptions.CancelledError) as cm: +# async for response in response_iterator: +# self.assertTrue(False, "Received an unexpected response!") +# +# tasks = [] +# tasks.append(asyncio.create_task(handle_response(response_iterator))) +# tasks.append(asyncio.create_task(cancel_streaming(streaming_call))) +# +# for task in tasks: +# await task +# +# self._record_end_time_ms() +# self._test_runtime_duration(5000) +# +# asyncio.run(test_aio_streaming_infer(self)) if __name__ == "__main__": From b8ffb1e243bf8db0165f70ffd3586e96a32bf231 Mon Sep 17 00:00:00 2001 From: tanmayv25 Date: Mon, 18 Sep 2023 16:14:26 -0700 Subject: [PATCH 08/11] Format --- qa/L0_request_cancellation/client_cancellation_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/qa/L0_request_cancellation/client_cancellation_test.py b/qa/L0_request_cancellation/client_cancellation_test.py index 5808529eca..c2bc0a0bf9 100755 --- a/qa/L0_request_cancellation/client_cancellation_test.py +++ b/qa/L0_request_cancellation/client_cancellation_test.py @@ -152,6 +152,7 @@ def test_grpc_stream_infer(self): self._record_end_time_ms() self._test_runtime_duration(5000) + # Disabling AsyncIO cancellation testing. Enable once # DLIS-5476 is implemented. 
# def test_aio_grpc_async_infer(self):

From 0f75a3adf4ba73238ef095e9e36140e47b7493d8 Mon Sep 17 00:00:00 2001
From: tanmayv25
Date: Mon, 18 Sep 2023 16:16:06 -0700
Subject: [PATCH 09/11] Update copyright

---
 .../models/custom_identity_int32/config.pbtxt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/qa/L0_request_cancellation/models/custom_identity_int32/config.pbtxt b/qa/L0_request_cancellation/models/custom_identity_int32/config.pbtxt
index f04acb18ea..4d9eda743b 100644
--- a/qa/L0_request_cancellation/models/custom_identity_int32/config.pbtxt
+++ b/qa/L0_request_cancellation/models/custom_identity_int32/config.pbtxt
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions

From 832b67f763d05981f022514e09a3f557912c2819 Mon Sep 17 00:00:00 2001
From: tanmayv25
Date: Mon, 18 Sep 2023 16:24:19 -0700
Subject: [PATCH 10/11] Remove tests

---
 qa/L0_request_cancellation/test.sh | 2 --
 1 file changed, 2 deletions(-)

diff --git a/qa/L0_request_cancellation/test.sh b/qa/L0_request_cancellation/test.sh
index 1c70eaa015..d50c288f8f 100755
--- a/qa/L0_request_cancellation/test.sh
+++ b/qa/L0_request_cancellation/test.sh
@@ -70,8 +70,6 @@ fi
 # gRPC client-side cancellation tests...
 for i in test_grpc_async_infer \
          test_grpc_stream_infer \
-         test_aio_grpc_async_infer \
-         test_aio_grpc_stream_infer \
    ; do

   SERVER_LOG=${i}.server.log
   run_server

From 5fc88fc1d21f32cf5952f61b59a559f0101b9449 Mon Sep 17 00:00:00 2001
From: Tanmay Verma
Date: Mon, 18 Sep 2023 18:23:06 -0700
Subject: [PATCH 11/11] Handle cancellation notification in gRPC server (#6298)

* Handle cancellation notification in gRPC server

* Fix the request ptr initialization

* Update src/grpc/infer_handler.h

Co-authored-by: Ryan McCormick

* Address review comment

* Fix logs

* Fix request complete callback by removing reference to state

* Improve documentation

---------

Co-authored-by: Ryan McCormick
---
 src/grpc/infer_handler.cc        |  58 ++++++++++++-
 src/grpc/infer_handler.h         | 135 ++++++++++++++++++++++++++++++-
 src/grpc/stream_infer_handler.cc |  94 ++++++++++++++++++---
 3 files changed, 270 insertions(+), 17 deletions(-)

diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc
index 37a921fa75..e40cdf6165 100644
--- a/src/grpc/infer_handler.cc
+++ b/src/grpc/infer_handler.cc
@@ -62,6 +62,12 @@ operator<<(std::ostream& out, const Steps& step)
     case WRITTEN:
       out << "WRITTEN";
       break;
+    case CANCELLATION_ISSUED:
+      out << "CANCELLATION_ISSUED";
+      break;
+    case CANCELLED:
+      out << "CANCELLED";
+      break;
   }

   return out;
@@ -707,6 +713,21 @@ ModelInferHandler::StartNewRequest()
 bool
 ModelInferHandler::Process(InferHandler::State* state, bool rpc_ok)
 {
+  // There are multiple handlers registered in the gRPC service.
+  // Hence, we can have a case where a handler thread is
+  // making progress in the state machine for a request and the
+  // other thread is issuing cancellation on the same request.
+  // Need to protect the state transitions for these cases.
+  std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
+
+  // Handle notification for cancellation which can be raised
+  // asynchronously if detected on the network.
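+  // Note that HandleCancellation() is re-entrant: the same state may
+  // pass through here several times as it cycles through the
+  // completion queue until its cancellation fully resolves.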
+  if (state->IsGrpcContextCancelled()) {
+    bool resume = state->context_->HandleCancellation(state, rpc_ok, Name());
+    return resume;
+  }
+
   LOG_VERBOSE(1) << "Process for " << Name() << ", rpc_ok=" << rpc_ok << ", "
                  << state->unique_id_ << " step " << state->step_;
@@ -933,8 +954,12 @@ ModelInferHandler::Execute(InferHandler::State* state)
   // If not error then state->step_ == ISSUED and inference request
   // has initiated... completion callback will transition to
-  // COMPLETE. If error go immediately to COMPLETE.
-  if (err != nullptr) {
+  // COMPLETE or CANCELLED. Recording the state and the irequest
+  // to handle gRPC stream cancellation.
+  if (err == nullptr) {
+    state->context_->InsertInflightState(state, irequest);
+  } else {
+    // If error go immediately to COMPLETE.
     LOG_VERBOSE(1) << "[request id: " << request_id << "] "
                    << "Infer failed: " << TRITONSERVER_ErrorMessage(err);
@@ -965,6 +990,12 @@ ModelInferHandler::InferResponseComplete(
 {
   State* state = reinterpret_cast<State*>(userp);

+  // There are multiple handlers registered in the gRPC service.
+  // Hence, we need to properly synchronize this thread
+  // and the handler thread handling the async cancellation
+  // notification.
+  std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
+
   // Increment the callback index
   state->cb_count_++;

@@ -982,6 +1013,29 @@ ModelInferHandler::InferResponseComplete(
         "INFER_RESPONSE_COMPLETE", TraceManager::CaptureTimestamp()));
 #endif  // TRITON_ENABLE_TRACING

+  // If the gRPC stream is cancelled then there is no need to form
+  // and return a response.
+  if (state->IsGrpcContextCancelled()) {
+    // Clean up the received response object.
+    LOG_TRITONSERVER_ERROR(
+        TRITONSERVER_InferenceResponseDelete(iresponse),
+        "deleting GRPC inference response");
+
+    state->step_ = Steps::CANCELLED;
+    state->context_->EraseInflightState(state);
+
+    LOG_VERBOSE(1) << "ModelInferHandler::InferResponseComplete, "
+                   << state->unique_id_
+                   << ", skipping response generation as grpc transaction was "
+                      "cancelled... ";
+
+    // Send state back to the queue so that state can be released
+    // in the next cycle.
+    state->context_->PutTaskBackToQueue(state);
+
+    return;
+  }
+
   TRITONSERVER_Error* err = nullptr;
   // This callback is expected to be called exactly once for each request.
   // Will use the single response object in the response list to hold the

diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h
index b2ce3f13e2..f0adb29fc3 100644
--- a/src/grpc/infer_handler.h
+++ b/src/grpc/infer_handler.h
@@ -61,7 +61,9 @@ typedef enum {
   ISSUED,
   READ,
   WRITEREADY,
-  WRITTEN
+  WRITTEN,
+  CANCELLATION_ISSUED,
+  CANCELLED
 } Steps;

 // Debugging helper
@@ -652,12 +654,127 @@ class InferHandlerState {
     ctx_->set_compression_level(compression_level);
   }

+  void GrpcContextAsyncNotifyWhenDone(InferHandlerStateType* state)
+  {
+    ctx_->AsyncNotifyWhenDone(state);
+  }
+
+  bool IsCancelled() { return ctx_->IsCancelled(); }
+
   // Increments the ongoing request counter
   void IncrementRequestCounter() { ongoing_requests_++; }

   // Decrements the ongoing request counter
   void DecrementRequestCounter() { ongoing_requests_--; }
+
+  // Inserts the state into the set tracking active requests
+  // within the server core. Should only be called when
+  // the request was successfully enqueued on Triton.
+  void InsertInflightState(
+      InferHandlerStateType* state, TRITONSERVER_InferenceRequest* irequest)
+  {
+    std::lock_guard<std::mutex> lock(mu_);
+    // The irequest_ptr_ will get populated when it is
+    // marked as active which means the request has been
+    // successfully enqueued to Triton core using
+    // TRITONSERVER_ServerInferAsync.
+    state->irequest_ptr_ = irequest;
+    inflight_states_.insert(state);
+  }
+
+  // Erases the state from the set tracking active requests
+  // within the server core.
+  void EraseInflightState(InferHandlerStateType* state)
+  {
+    std::lock_guard<std::mutex> lock(mu_);
+    inflight_states_.erase(state);
+  }
+
+  // Issues the cancellation for all inflight requests
+  // being tracked by this context.
+  void IssueRequestCancellation()
+  {
+    {
+      std::lock_guard<std::mutex> lock(mu_);
+
+      // Issues the request cancellation to the core.
+      for (auto state : inflight_states_) {
+        std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
+        if (state->step_ != Steps::CANCELLED) {
+          LOG_VERBOSE(1) << "Issuing cancellation for " << state->unique_id_;
+          if (state->irequest_ptr_ == nullptr) {
+            // The context might be holding some states that have
+            // not been issued to Triton core. Need to skip issuing
+            // cancellation for such requests.
+            continue;
+          }
+          // Note that the request may or may not be valid at this point;
+          // the RequestComplete callback may have already run
+          // asynchronously before this point.
+          TRITONSERVER_Error* err = nullptr;
+          err = TRITONSERVER_InferenceRequestCancel(state->irequest_ptr_);
+          // TODO: Add request id to the message
+          if (err != nullptr) {
+            LOG_INFO << "Failed to cancel the request: "
+                     << TRITONSERVER_ErrorMessage(err);
+          }
+          state->step_ = Steps::CANCELLATION_ISSUED;
+        }
+      }
+    }
+  }
+
+
+  // Handles the gRPC context cancellation. This function can be called
+  // multiple times and is supposed to be re-entrant.
+  // Returns whether or not to continue cycling through the gRPC
+  // completion queue.
+  bool HandleCancellation(
+      InferHandlerStateType* state, bool rpc_ok, const std::string& name)
+  {
+    if (!IsCancelled()) {
+      LOG_ERROR
+          << "[INTERNAL] HandleCancellation called even when the context was "
+             "not cancelled for "
+          << name << ", rpc_ok=" << rpc_ok << ", context "
+          << state->context_->unique_id_ << ", " << state->unique_id_
+          << " step " << state->step_;
+      return true;
+    }
+    if ((state->step_ != Steps::CANCELLATION_ISSUED) &&
+        (state->step_ != Steps::CANCELLED)) {
+      LOG_VERBOSE(1) << "Cancellation notification received for " << name
+                     << ", rpc_ok=" << rpc_ok << ", context "
+                     << state->context_->unique_id_ << ", "
+                     << state->unique_id_ << " step " << state->step_;
+
+      // If the context has not been cancelled then issue a
+      // cancellation request for all the inflight states
+      // belonging to the context.
+      if (state->context_->step_ != Steps::CANCELLED) {
+        IssueRequestCancellation();
+        // Mark the context as cancelled
+        state->context_->step_ = Steps::CANCELLED;
+
+        // This call returns true because the IssueRequestCancellation
+        // call above would have notified all the pending inflight
+        // state objects. This state will be taken up along with all
+        // the other states in the next iteration from the completion
+        // queue which would release the state.
+        return true;
+      }
+    }
+
+    LOG_VERBOSE(1) << "Completing cancellation for " << name
+                   << ", rpc_ok=" << rpc_ok << ", context "
+                   << state->context_->unique_id_ << ", " << state->unique_id_
+                   << " step " << state->step_;
+
+    return false;
+  }
+
   // Enqueue 'state' so that its response is delivered in the
   // correct order.
  void EnqueueForResponse(InferHandlerStateType* state)
@@ -781,6 +898,11 @@ class InferHandlerState {
   std::queue<InferHandlerStateType*> states_;
   std::atomic<uint32_t> ongoing_requests_;

+  // Tracks the inflight requests sent to Triton core via this
+  // context. We will use this structure to issue cancellations
+  // on these requests.
+  std::set<InferHandlerStateType*> inflight_states_;
+
   // The step of the entire context.
   Steps step_;

@@ -809,12 +931,15 @@ class InferHandlerState {

   ~InferHandlerState() { ClearTraceTimestamps(); }

+  bool IsGrpcContextCancelled() { return context_->IsCancelled(); }
+
   void Reset(
       const std::shared_ptr<Context>& context, Steps start_step = Steps::START)
   {
     unique_id_ = NEXT_UNIQUE_ID;
     context_ = context;
     step_ = start_step;
+    irequest_ptr_ = nullptr;
     cb_count_ = 0;
     is_decoupled_ = false;
     complete_ = false;
@@ -859,7 +984,9 @@ class InferHandlerState {
   std::shared_ptr<Context> context_;

   Steps step_;
-  std::mutex step_mtx_;
+  std::recursive_mutex step_mtx_;
+
+  TRITONSERVER_InferenceRequest* irequest_ptr_;

 #ifdef TRITON_ENABLE_TRACING
   std::shared_ptr<TraceManager::Trace> trace_;

@@ -939,6 +1066,10 @@ class InferHandler : public HandlerBase {
     state = new State(tritonserver, context, start_step);
   }

+  // Needs to be called to receive an asynchronous notification
+  // when the transaction is cancelled.
+  context->GrpcContextAsyncNotifyWhenDone(state);
+
   return state;
 }

diff --git a/src/grpc/stream_infer_handler.cc b/src/grpc/stream_infer_handler.cc
index 8877694284..ed6ee2296d 100644
--- a/src/grpc/stream_infer_handler.cc
+++ b/src/grpc/stream_infer_handler.cc
@@ -132,6 +132,22 @@ ModelStreamInferHandler::StartNewRequest()
 bool
 ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
 {
+  // Because gRPC doesn't allow concurrent writes on the
+  // stream, we only have a single handler thread that
+  // reads from the completion queue. Hence, the cancellation
+  // notification will be received on the same handler
+  // thread.
+  // This means that we only need to take care of
+  // synchronizing this thread with the ResponseComplete
+  // threads.
+  {
+    std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
+    if (state->IsGrpcContextCancelled()) {
+      bool resume = state->context_->HandleCancellation(state, rpc_ok, Name());
+      return resume;
+    }
+  }
+
   LOG_VERBOSE(1) << "Process for " << Name() << ", rpc_ok=" << rpc_ok
                  << ", context " << state->context_->unique_id_ << ", "
                  << state->unique_id_ << " step " << state->step_;
@@ -292,9 +308,13 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
   // If there was not an error in issuing the 'state' request then
   // state->step_ == ISSUED and inference request has
   // initiated... the completion callback will transition to
-  // WRITEREADY or WRITTEN. If there was an error then enqueue the
-  // error response and show it to be ready for writing.
+  // WRITEREADY or WRITTEN or CANCELLED. Recording the state and the
+  // irequest to handle gRPC stream cancellation.
+  if (err == nullptr) {
+    state->context_->InsertInflightState(state, irequest);
+  } else {
+    // If there was an error then enqueue the error response and show
+    // it to be ready for writing.
    inference::ModelStreamInferResponse* response;
     if (state->is_decoupled_) {
       state->response_queue_->AllocateResponse();
@@ -439,7 +459,7 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok)
       state->context_->DecrementRequestCounter();
       finished = Finish(state);
     } else {
-      std::lock_guard<std::mutex> lock(state->step_mtx_);
+      std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);

       // If there is an available response to be written
       // to the stream, then transition directly to WRITEREADY
@@ -543,6 +563,31 @@ ModelStreamInferHandler::StreamInferResponseComplete(
     }
   }

+  if (state->IsGrpcContextCancelled()) {
+    std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
+    // Clean up the received response object.
+    LOG_TRITONSERVER_ERROR(
+        TRITONSERVER_InferenceResponseDelete(iresponse),
+        "deleting GRPC inference response");
+
+    LOG_VERBOSE(1) << "ModelStreamInferHandler::StreamInferResponseComplete, "
+                   << state->unique_id_
+                   << ", skipping response generation as grpc transaction was "
+                      "cancelled... ";
+
+    // If this was the final callback for the state
+    // then cycle through the completion queue so
+    // that the state object can be released.
+    if (state->complete_) {
+      state->step_ = Steps::CANCELLED;
+      state->context_->EraseInflightState(state);
+
+      state->context_->PutTaskBackToQueue(state);
+    }
+
+    return;
+  }
+
   auto& response_queue = state->response_queue_;
   std::string log_request_id = state->request_.id();
   if (log_request_id.empty()) {
@@ -619,18 +664,41 @@ ModelStreamInferHandler::StreamInferResponseComplete(
     }
   }

   // Update states to signal that response/error is ready to write to stream
-  if (state->is_decoupled_) {
-    std::lock_guard<std::mutex> lock(state->step_mtx_);
-    if (response) {
-      state->response_queue_->MarkNextResponseComplete();
+  {
+    // Need to hold the lock because the handler thread processing context
+    // cancellation might have cancelled or marked the state for cancellation.
+    std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
+
+    if (state->IsGrpcContextCancelled()) {
+      LOG_VERBOSE(1)
+          << "ModelStreamInferHandler::StreamInferResponseComplete, "
+          << state->unique_id_
+          << ", skipping writing response because the transaction was "
+             "cancelled";
+
+      // If this was the final callback for the state
+      // then cycle through the completion queue so
+      // that the state object can be released.
+      if (state->complete_) {
+        state->step_ = Steps::CANCELLED;
+        state->context_->EraseInflightState(state);
+        state->context_->PutTaskBackToQueue(state);
+      }
+
+      return;
     }
-    if (state->step_ == Steps::ISSUED) {
+
+    if (state->is_decoupled_) {
+      if (response) {
+        state->response_queue_->MarkNextResponseComplete();
+      }
+      if (state->step_ == Steps::ISSUED) {
+        state->step_ = Steps::WRITEREADY;
+        state->context_->PutTaskBackToQueue(state);
+      }
+    } else {
       state->step_ = Steps::WRITEREADY;
-      state->context_->PutTaskBackToQueue(state);
+      state->context_->WriteResponseIfReady(state);
     }
-  } else {
-    state->step_ = Steps::WRITEREADY;
-    state->context_->WriteResponseIfReady(state);
-  }
 }
 }
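For reference, a minimal standalone sketch of the client-side pattern these patches exercise. It assumes a running Triton server on localhost:8001 serving the custom_identity_int32 model defined above (its execute_delay_ms parameter supplies the 10s delay) and mirrors the names used in the tests:

    import queue
    import time
    from functools import partial

    import numpy as np
    import tritonclient.grpc as grpcclient

    responses = queue.Queue()

    def callback(q, result, error):
        # After future.cancel(), the callback receives grpcclient.CancelledError.
        q.put(error if error is not None else result)

    client = grpcclient.InferenceServerClient(url="localhost:8001")
    inputs = [grpcclient.InferInput("INPUT0", [1, 1], "INT32")]
    inputs[0].set_data_from_numpy(np.array([[10]], dtype=np.int32))

    future = client.async_infer(
        model_name="custom_identity_int32",
        inputs=inputs,
        callback=partial(callback, responses),
    )
    time.sleep(2)    # model delays execution via execute_delay_ms=10000
    future.cancel()  # issue client-side cancellation

    assert type(responses.get()) == grpcclient.CancelledError

The same cancellation surfaces on a stream via stop_stream(cancel_requests=True), as in test_grpc_stream_infer above.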