
Commit 28bb9c9

rmccorm4, dyastremsky, and GuanLuo authored
Add more comprehensive llm endpoint tests (#6377)
* Add security policy (#6376)
* Start adding some more comprehensive tests
* Fix test case
* Add response error testing
* Complete test placeholder
* Address comment
* Address comments
* Fix code check

---------

Co-authored-by: dyastremsky <[email protected]>
Co-authored-by: GuanLuo <[email protected]>
1 parent 5b969a0 commit 28bb9c9

File tree

6 files changed: +394 -108 lines changed


SECURITY.md

+16
# Report a Security Vulnerability

To report a potential security vulnerability in any NVIDIA product, please use either:
* This web form: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html), or
* Send email to: [NVIDIA PSIRT](mailto:[email protected])

**OEM Partners should contact their NVIDIA Customer Program Manager**

If reporting a potential vulnerability via email, please encrypt it using NVIDIA’s public PGP key ([see PGP Key page](https://www.nvidia.com/en-us/security/pgp-key/)) and include the following information:
1. Product/Driver name and version/branch that contains the vulnerability
2. Type of vulnerability (code execution, denial of service, buffer overflow, etc.)
3. Instructions to reproduce the vulnerability
4. Proof-of-concept or exploit code
5. Potential impact of the vulnerability, including how an attacker could exploit the vulnerability

See https://www.nvidia.com/en-us/security/ for past NVIDIA Security Bulletins and Notices.
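
For illustration only, a reporter could encrypt a file with GnuPG before emailing it. The sketch below calls the gpg CLI from Python; the recipient address and file name are placeholders rather than values from this policy, and it assumes NVIDIA's public PGP key (from the page linked above) has already been imported with gpg --import:

import subprocess

# Hypothetical report file and recipient address; substitute the real PSIRT
# address and use NVIDIA's imported public PGP key in practice.
subprocess.run(
    ["gpg", "--encrypt", "--armor", "--recipient", "psirt@example.invalid", "report.txt"],
    check=True,  # raise CalledProcessError if gpg fails (e.g., key not imported)
)

This produces report.txt.asc, an ASCII-armored encrypted copy suitable for attaching to an email.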

qa/L0_http/generate_endpoint_test.py

+354
#!/usr/bin/python3
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys

sys.path.append("../common")

import json
import threading
import time
import unittest

import requests
import sseclient
import test_util as tu


class GenerateEndpointTest(tu.TestResultCollector):
    def setUp(self):
        self._model_name = "mock_llm"

    def _get_infer_url(self, model_name, route):
        return f"http://localhost:8000/v2/models/{model_name}/{route}"

    def generate_stream(self, model_name, inputs, stream=False):
        headers = {"Accept": "text/event-stream"}
        url = self._get_infer_url(model_name, "generate_stream")
        # stream=True indicates that the response can be iterated over as it
        # arrives, which should be the common setting for generate_stream.
        # The correctness test cases use stream=False so that the response
        # content can be re-examined afterward.
        return requests.post(
            url,
            data=inputs if isinstance(inputs, str) else json.dumps(inputs),
            headers=headers,
            stream=stream,
        )

    def generate(self, model_name, inputs):
        url = self._get_infer_url(model_name, "generate")
        return requests.post(url, data=json.dumps(inputs))

    def generate_expect_failure(self, model_name, inputs, msg):
        url = self._get_infer_url(model_name, "generate")
        r = requests.post(
            url, data=inputs if isinstance(inputs, str) else json.dumps(inputs)
        )
        try:
            r.raise_for_status()
            self.fail(f"Expected failure, got success for {inputs}")
        except requests.exceptions.HTTPError:
            self.assertIn(msg, r.json()["error"])

    def generate_stream_expect_failure(self, model_name, inputs, msg):
        r = self.generate_stream(model_name, inputs)
        try:
            r.raise_for_status()
            self.fail(f"Expected failure, got success for {inputs}")
        except requests.exceptions.HTTPError:
            self.assertIn(msg, r.json()["error"])

    def generate_stream_expect_success(
        self, model_name, inputs, expected_output, rep_count
    ):
        r = self.generate_stream(model_name, inputs)
        r.raise_for_status()
        self.check_sse_responses(r, [{"TEXT": expected_output}] * rep_count)

    def check_sse_responses(self, res, expected_res):
        # Validate SSE format
        self.assertIn("Content-Type", res.headers)
        self.assertIn("text/event-stream", res.headers["Content-Type"])

        # The SSE format (data: [...]) is hard to parse manually, so use a
        # helper library for simplicity
        client = sseclient.SSEClient(res)
        res_count = 0
        for event in client.events():
            # Parse each event's data and compare it against the expected
            # response
            data = json.loads(event.data)
            for key, value in expected_res[res_count].items():
                self.assertIn(key, data)
                self.assertEqual(value, data[key])
            res_count += 1
        self.assertEqual(len(expected_res), res_count)
        # Make sure there is no message in the wrong form; _read() is a
        # private sseclient helper, used here to inspect the raw messages
        for remaining in client._read():
            self.assertTrue(
                remaining.startswith(b"data:"),
                f"SSE response not formed properly, got: {remaining}",
            )
            self.assertTrue(
                remaining.endswith(b"\n\n"),
                f"SSE response not formed properly, got: {remaining}",
            )
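
    # For reference, a well-formed message in the stream that
    # check_sse_responses accepts is a "data:" line carrying one JSON object,
    # terminated by a blank line, e.g.:
    #
    #   data: {"TEXT": "hello world"}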

    def test_generate(self):
        # Set up text-based input
        text = "hello world"
        inputs = {"PROMPT": text, "STREAM": False}

        r = self.generate(self._model_name, inputs)
        r.raise_for_status()

        self.assertIn("Content-Type", r.headers)
        self.assertIn("application/json", r.headers["Content-Type"])

        data = r.json()
        self.assertIn("TEXT", data)
        self.assertEqual(text, data["TEXT"])

    def test_generate_stream(self):
        # Set up text-based input
        text = "hello world"
        rep_count = 3
        inputs = {"PROMPT": [text], "STREAM": True, "REPETITION": rep_count}
        self.generate_stream_expect_success(self._model_name, inputs, text, rep_count)

    def test_streaming(self):
        # Verify that responses are streamed as soon as they are generated
        text = "hello world"
        rep_count = 3
        inputs = {"PROMPT": [text], "STREAM": True, "REPETITION": rep_count, "DELAY": 2}
        past = time.time()
        res = self.generate_stream(self._model_name, inputs, stream=True)
        client = sseclient.SSEClient(res)
        # This test does not focus on event content
        for _ in client.events():
            now = time.time()
            self.assertTrue(1 < (now - past) < 3)
            past = now

    def test_missing_inputs(self):
        missing_all_inputs = [
            # Missing all inputs
            {},
            {"abc": 123},
        ]
        missing_one_input = [
            # Missing 1 input
            {"PROMPT": "hello"},
            {"STREAM": False},
            {"STREAM": False, "other": "param"},
        ]
        for inputs in missing_all_inputs:
            self.generate_expect_failure(
                self._model_name, inputs, "expected 2 inputs but got 0"
            )
            self.generate_stream_expect_failure(
                self._model_name, inputs, "expected 2 inputs but got 0"
            )

        for inputs in missing_one_input:
            self.generate_expect_failure(
                self._model_name, inputs, "expected 2 inputs but got 1"
            )
            self.generate_stream_expect_failure(
                self._model_name, inputs, "expected 2 inputs but got 1"
            )

    def test_invalid_input_types(self):
        invalid_bool = "attempt to access JSON non-boolean as boolean"
        invalid_string = "attempt to access JSON non-string as string"
        invalid_type_inputs = [
            # Prompt bad type
            ({"PROMPT": 123, "STREAM": False}, invalid_string),
            # Stream bad type
            ({"PROMPT": "hello", "STREAM": "false"}, invalid_bool),
            # Both bad type, parsed in order
            ({"PROMPT": True, "STREAM": 123}, invalid_string),
            ({"STREAM": 123, "PROMPT": True}, invalid_bool),
        ]

        for inputs, error_msg in invalid_type_inputs:
            self.generate_expect_failure(self._model_name, inputs, error_msg)
            self.generate_stream_expect_failure(self._model_name, inputs, error_msg)

    def test_duplicate_inputs(self):
        dupe_prompt = "input 'PROMPT' already exists in request"
        dupe_stream = "input 'STREAM' already exists in request"
        # Use JSON strings directly, as a Python dict doesn't support
        # duplicate keys
        invalid_type_inputs = [
            # One duplicate
            (
                '{"PROMPT": "hello", "STREAM": false, "PROMPT": "duplicate"}',
                dupe_prompt,
            ),
            ('{"PROMPT": "hello", "STREAM": false, "STREAM": false}', dupe_stream),
            # Multiple duplicates, parsed in order
            (
                '{"PROMPT": "hello", "STREAM": false, "PROMPT": "duplicate", "STREAM": true}',
                dupe_prompt,
            ),
            (
                '{"PROMPT": "hello", "STREAM": false, "STREAM": true, "PROMPT": "duplicate"}',
                dupe_stream,
            ),
        ]
        for inputs, error_msg in invalid_type_inputs:
            self.generate_expect_failure(self._model_name, inputs, error_msg)
            self.generate_stream_expect_failure(self._model_name, inputs, error_msg)

    def test_generate_stream_response_error(self):
        # Set up text-based input
        text = "hello world"
        inputs = {"PROMPT": [text], "STREAM": True, "REPETITION": 0, "FAIL_LAST": True}
        r = self.generate_stream(self._model_name, inputs)

        # With "REPETITION": 0, the error will be the first response, and the
        # HTTP status code will be set accordingly
        try:
            r.raise_for_status()
        except requests.exceptions.HTTPError:
            self.check_sse_responses(r, [{"error": "An Error Occurred"}])

        # With "REPETITION" > 0, the first response is a valid response that
        # sets the HTTP status code to success, so the user must validate each
        # response individually
        inputs["REPETITION"] = 1
        r = self.generate_stream(self._model_name, inputs)
        r.raise_for_status()

        self.check_sse_responses(r, [{"TEXT": text}, {"error": "An Error Occurred"}])

    def test_race_condition(self):
        # The HTTP response send logic and the Triton response complete logic
        # run in different threads with shared access to the same generate
        # request object, so send sufficient load at the endpoint to expose
        # any race condition.
        input1 = {"PROMPT": "hello", "STREAM": False, "param": "segfault"}
        input2 = {
            "PROMPT": "hello",
            "STREAM": True,
            "REPETITION": 3,
            "param": "segfault",
        }
        threads = []

        def thread_func(model_name, inputs):
            self.generate_stream(model_name, inputs).raise_for_status()

        for _ in range(50):
            threads.append(
                threading.Thread(target=thread_func, args=(self._model_name, input1))
            )
            threads.append(
                threading.Thread(target=thread_func, args=(self._model_name, input2))
            )
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    def test_one_response(self):
        # "STREAM" controls the response behavior.
        # True sends two responses, one carrying the infer response and one
        # carrying the final flag only, which the generate endpoint is able to
        # handle as in non-decoupled mode.
        # False sends one response carrying both the infer response and the
        # final flag, which is the same way a non-decoupled model returns its
        # response.
        inputs = {"PROMPT": "hello world", "STREAM": True}
        r = self.generate_stream(self._model_name, inputs)
        r.raise_for_status()
        r = self.generate(self._model_name, inputs)
        r.raise_for_status()

        inputs["STREAM"] = False
        r = self.generate_stream(self._model_name, inputs)
        r.raise_for_status()
        r = self.generate(self._model_name, inputs)
        r.raise_for_status()

    def test_zero_response(self):
        inputs = {"PROMPT": "hello world", "STREAM": True, "REPETITION": 0}
        r = self.generate_stream(self._model_name, inputs)
        r.raise_for_status()
        # Expect generate to fail the inference
        r = self.generate(self._model_name, inputs)
        try:
            r.raise_for_status()
        except requests.exceptions.HTTPError:
            self.assertIn(
                "generate expects model to produce exactly 1 response",
                r.json()["error"],
            )

    def test_many_response(self):
        inputs = {"PROMPT": "hello world", "STREAM": True, "REPETITION": 2}
        r = self.generate_stream(self._model_name, inputs)
        r.raise_for_status()
        # Expect generate to fail the inference
        r = self.generate(self._model_name, inputs)
        try:
            r.raise_for_status()
        except requests.exceptions.HTTPError:
            self.assertIn(
                "generate expects model to produce exactly 1 response",
                r.json()["error"],
            )

    def test_complex_schema(self):
        # Currently only the fundamental conversion is supported; a nested
        # object in the request will result in a parsing error

        # Complex object passed to parameters (specifying a non-model input)
        inputs = {
            "PROMPT": "hello world",
            "STREAM": True,
            "PARAMS": {"PARAM_0": 0, "PARAM_1": True},
        }
        r = self.generate(self._model_name, inputs)
        try:
            r.raise_for_status()
        except requests.exceptions.HTTPError:
            self.assertIn("parameter 'PARAMS' has invalid type", r.json()["error"])

        # Complex object passed to a model input
        inputs = {
            "PROMPT": {"USER": "hello world", "BOT": "world hello"},
            "STREAM": True,
        }
        r = self.generate(self._model_name, inputs)
        try:
            r.raise_for_status()
        except requests.exceptions.HTTPError:
            self.assertIn(
                "attempt to access JSON non-string as string", r.json()["error"]
            )


if __name__ == "__main__":
    unittest.main()
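
For context, the suite assumes a Triton server is already running on localhost:8000 and serving the mock_llm model (presumably started by the surrounding L0_http test harness, which is not shown in this diff). A minimal sketch of exercising the streaming endpoint the same way the tests do, using the requests and sseclient packages imported above:

import json

import requests
import sseclient

# Ask the mock model to repeat the prompt three times over SSE
resp = requests.post(
    "http://localhost:8000/v2/models/mock_llm/generate_stream",
    data=json.dumps({"PROMPT": ["hello world"], "STREAM": True, "REPETITION": 3}),
    headers={"Accept": "text/event-stream"},
    stream=True,
)
resp.raise_for_status()
for event in sseclient.SSEClient(resp).events():
    print(json.loads(event.data))  # expect {"TEXT": "hello world"} per event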
