Skip to content

Commit 627dc8c

Browse files
committed
Add basic testing. Cherry pick from #6369
1 parent 91c8b84 commit 627dc8c

File tree

4 files changed

+289
-0
lines changed

4 files changed

+289
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
import json
28+
29+
import triton_python_backend_utils as pb_utils
30+
import numpy as np
31+
32+
33+
class TritonPythonModel:
    """Mock LLM backend used to exercise Triton's generate / generate_stream
    HTTP endpoints.

    Echoes the PROMPT input back as the TEXT output, optionally repeated
    REPETITION times: streamed as separate responses in decoupled mode,
    concatenated along axis 1 otherwise.
    """

    def initialize(self, args):
        """Parse the model config and cache whether the model is decoupled.

        Args:
            args: dict supplied by Triton; args["model_config"] is the
                JSON-serialized model configuration.
        """
        self.model_config = json.loads(args["model_config"])
        # None when no transaction policy is configured; that is falsy, so
        # execute() treats it the same as decoupled=False.
        self.decoupled = self.model_config.get("model_transaction_policy", {}).get(
            "decoupled"
        )
        print(f"{self.decoupled=}")

    def execute(self, requests):
        """Dispatch to the decoupled or non-decoupled execution path."""
        if self.decoupled:
            return self.exec_decoupled(requests)
        else:
            return self.exec(requests)

    def exec(self, requests):
        """Non-decoupled path: return exactly one response per request."""
        responses = []
        for request in requests:
            params = json.loads(request.parameters())
            # dict.get is the idiomatic form of the lookup-with-default.
            rep_count = params.get("REPETITION", 1)

            input_np = pb_utils.get_input_tensor_by_name(request, "PROMPT").as_numpy()
            stream_np = pb_utils.get_input_tensor_by_name(request, "STREAM").as_numpy()
            stream = stream_np.flatten()[0]
            if stream:
                # Streaming needs a response sender, which exists only in
                # decoupled mode; surface that as an error response instead.
                responses.append(
                    pb_utils.InferenceResponse(
                        error=pb_utils.TritonError(
                            "STREAM only supported in decoupled mode"
                        )
                    )
                )
            else:
                # Repeat the prompt along axis 1 to honor REPETITION.
                out_tensor = pb_utils.Tensor(
                    "TEXT", np.repeat(input_np, rep_count, axis=1)
                )
                responses.append(pb_utils.InferenceResponse([out_tensor]))
        return responses

    def exec_decoupled(self, requests):
        """Decoupled path: push responses through each request's sender."""
        for request in requests:
            params = json.loads(request.parameters())
            rep_count = params.get("REPETITION", 1)

            sender = request.get_response_sender()
            input_np = pb_utils.get_input_tensor_by_name(request, "PROMPT").as_numpy()
            stream_np = pb_utils.get_input_tensor_by_name(request, "STREAM").as_numpy()
            out_tensor = pb_utils.Tensor("TEXT", input_np)
            response = pb_utils.InferenceResponse([out_tensor])
            # If stream enabled, just send multiple copies of response
            # FIXME: Could split up response string into tokens, but this is
            # simpler for now.
            stream = stream_np.flatten()[0]
            if stream:
                for _ in range(rep_count):
                    sender.send(response)
                sender.send(None, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            # If stream disabled, just send one response
            else:
                sender.send(
                    response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                )
        return None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
backend: "python"

# Disabling batching in Triton, let vLLM handle the batching on its own.
max_batch_size: 0

# Decoupled transaction policy is required for the streaming
# (generate_stream) path, where one request can yield many responses.
model_transaction_policy {
  decoupled: True
}

input [
  {
    # Text prompt that the mock model echoes back.
    name: "PROMPT"
    data_type: TYPE_STRING
    dims: [ 1, 1 ]
  },
  {
    # Whether to stream responses; only honored in decoupled mode.
    name: "STREAM"
    data_type: TYPE_BOOL
    dims: [ 1, 1 ]
  }
]

output [
  {
    # Echoed prompt; second dim is variable to allow REPETITION copies.
    name: "TEXT"
    data_type: TYPE_STRING
    dims: [ 1, -1 ]
  }
]

# The usage of device is deferred to the vLLM engine
instance_group [
  {
    count: 1
    kind: KIND_MODEL
  }
]

qa/L0_http/llm_test.py

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#!/usr/bin/python3
2+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
28+
import sys
29+
30+
sys.path.append("../common")
31+
32+
import json
33+
import unittest
34+
35+
import requests
36+
import test_util as tu
37+
38+
39+
class HttpTest(tu.TestResultCollector):
    """Tests for the LLM-style generate / generate_stream HTTP endpoints."""

    def _get_infer_url(self, model_name, route):
        """Build the KServe v2 inference URL for `model_name` and `route`."""
        return f"http://localhost:8000/v2/models/{model_name}/{route}"

    def test_generate(self):
        model_name = "vllm_proxy"
        # Setup text-based input
        text = "hello world"
        inputs = {"PROMPT": [text], "STREAM": False}

        url = self._get_infer_url(model_name, "generate")
        r = requests.post(url, data=json.dumps(inputs))

        r.raise_for_status()

        self.assertIn("Content-Type", r.headers)
        self.assertIn("application/json", r.headers["Content-Type"])

        data = r.json()
        # assertIn gives a clearer failure message than assertTrue(... in ...)
        self.assertIn("TEXT", data)
        self.assertEqual(text, data["TEXT"])

    def test_generate_stream(self):
        model_name = "vllm_proxy"
        # Setup text-based input
        text = "hello world"
        rep_count = 3
        inputs = {"PROMPT": [text], "STREAM": True, "REPETITION": rep_count}

        import sseclient

        headers = {"Accept": "text/event-stream"}
        url = self._get_infer_url(model_name, "generate_stream")
        # stream=True used to indicate response can be iterated over
        r = requests.post(url, data=json.dumps(inputs), headers=headers, stream=True)

        r.raise_for_status()

        # Validate SSE format
        self.assertIn("Content-Type", r.headers)
        self.assertIn("text/event-stream", r.headers["Content-Type"])

        # SSE format (data: []) is hard to parse, use helper library for simplicity
        client = sseclient.SSEClient(r)
        res_count = 0
        for i, event in enumerate(client.events()):
            # Parse event data, join events into a single response
            data = json.loads(event.data)
            print(f"Event {i}:", data)
            self.assertIn("TEXT", data)
            self.assertEqual(text, data["TEXT"])
            res_count += 1
        # BUG FIX: assertTrue(rep_count, res_count) always passed because
        # rep_count (3) is truthy and res_count was only used as the failure
        # message. assertEqual actually compares the two counts.
        self.assertEqual(rep_count, res_count)
93+
94+
95+
# Run the test suite when this file is invoked directly (see qa/L0_http/test.sh).
if __name__ == "__main__":
    unittest.main()

qa/L0_http/test.sh

+40
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,46 @@ set -e
629629
kill $SERVER_PID
630630
wait $SERVER_PID
631631

632+
### LLM REST API Endpoint Tests ###
633+
634+
# Helper library to parse SSE events
635+
# https://github.com/mpetazzoni/sseclient
636+
pip install sseclient-py
637+
638+
SERVER_ARGS="--model-repository=`pwd`/llm_models"
639+
SERVER_LOG="./inference_server_llm_test.log"
640+
CLIENT_LOG="./llm_test.log"
641+
run_server
642+
if [ "$SERVER_PID" == "0" ]; then
643+
echo -e "\n***\n*** Failed to start $SERVER\n***"
644+
cat $SERVER_LOG
645+
exit 1
646+
fi
647+
648+
## Python Unit Tests
649+
TEST_RESULT_FILE='test_results.txt'
650+
PYTHON_TEST=llm_test.py
651+
EXPECTED_NUM_TESTS=2
652+
set +e
653+
python3 $PYTHON_TEST >$CLIENT_LOG 2>&1
654+
if [ $? -ne 0 ]; then
655+
cat $CLIENT_LOG
656+
RET=1
657+
else
658+
check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
659+
if [ $? -ne 0 ]; then
660+
cat $CLIENT_LOG
661+
echo -e "\n***\n*** Test Result Verification Failed\n***"
662+
RET=1
663+
fi
664+
fi
665+
set -e
666+
667+
kill $SERVER_PID
668+
wait $SERVER_PID
669+
670+
###
671+
632672
if [ $RET -eq 0 ]; then
633673
echo -e "\n***\n*** Test Passed\n***"
634674
else

0 commit comments

Comments
 (0)