Skip to content

Commit 0049763

Browse files
authored
Add test to check the output memory type for onnx models (#6033)
* Add test to check the output memory type for onnx models * Remove unused import * Address comment
1 parent fd96f23 commit 0049763

File tree

3 files changed

+207
-3
lines changed

3 files changed

+207
-3
lines changed

qa/L0_warmup/test.sh

+82-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -42,6 +42,9 @@ export CUDA_VISIBLE_DEVICES=0
4242

4343
CLIENT=../clients/image_client
4444
CLIENT_LOG="./client.log"
45+
CLIENT_PY=./python_unittest.py
46+
EXPECTED_NUM_TESTS="1"
47+
TEST_RESULT_FILE='test_results.txt'
4548

4649
IMAGE="../images/vulture.jpeg"
4750

@@ -56,6 +59,7 @@ SERVER_LOG="./inference_server.log"
5659
source ../common/util.sh
5760

5861
RET=0
62+
rm -fr *.txt
5963

6064
for BACKEND in ${BACKENDS}; do
6165
rm -f $SERVER_LOG $CLIENT_LOG
@@ -408,8 +412,83 @@ set -e
408412
kill $SERVER_PID
409413
wait $SERVER_PID
410414

411-
if [ $RET -eq 0 ]; then
412-
echo -e "\n***\n*** Test Passed\n***"
415+
# Test the onnx model to verify that the memory type of the output tensor
416+
# remains unchanged with the warmup setting
417+
pip3 uninstall -y torch
418+
pip3 install torch==1.13.0+cu117 -f https://download.pytorch.org/whl/torch_stable.html
419+
420+
rm -fr models && mkdir models
421+
cp -r /data/inferenceserver/${REPO_VERSION}/qa_model_repository/onnx_nobatch_float32_float32_float32 models/.
422+
(cd models/onnx_nobatch_float32_float32_float32 && \
423+
echo "" >> config.pbtxt && \
424+
echo 'instance_group [{' >> config.pbtxt && \
425+
echo ' kind : KIND_GPU' >> config.pbtxt && \
426+
echo '}]' >> config.pbtxt && \
427+
echo 'model_warmup [{' >> config.pbtxt && \
428+
echo ' name : "sample"' >> config.pbtxt && \
429+
echo ' batch_size: 1' >> config.pbtxt && \
430+
echo ' inputs {' >> config.pbtxt && \
431+
echo ' key: "INPUT0"' >> config.pbtxt && \
432+
echo ' value: {' >> config.pbtxt && \
433+
echo ' data_type: TYPE_FP32' >> config.pbtxt && \
434+
echo " dims: 16" >> config.pbtxt && \
435+
echo " zero_data: false" >> config.pbtxt && \
436+
echo ' }' >> config.pbtxt && \
437+
echo ' }' >> config.pbtxt && \
438+
echo ' inputs {' >> config.pbtxt && \
439+
echo ' key: "INPUT1"' >> config.pbtxt && \
440+
echo ' value: {' >> config.pbtxt && \
441+
echo ' data_type: TYPE_FP32' >> config.pbtxt && \
442+
echo " dims: 16" >> config.pbtxt && \
443+
echo " zero_data: false" >> config.pbtxt && \
444+
echo ' }' >> config.pbtxt && \
445+
echo ' }' >> config.pbtxt && \
446+
echo '}]' >> config.pbtxt )
447+
448+
mkdir -p models/bls_onnx_warmup/1/
449+
cp ../python_models/bls_onnx_warmup/model.py models/bls_onnx_warmup/1/
450+
cp ../python_models/bls_onnx_warmup/config.pbtxt models/bls_onnx_warmup/.
451+
452+
cp ../L0_backend_python/python_unittest.py .
453+
sed -i 's#sys.path.append("../../common")#sys.path.append("../common")#g' python_unittest.py
454+
455+
run_server
456+
if [ "$SERVER_PID" == "0" ]; then
457+
echo -e "\n***\n*** Failed to start $SERVER\n***"
458+
cat $SERVER_LOG
459+
exit 1
460+
fi
461+
462+
set +e
463+
464+
export MODEL_NAME='bls_onnx_warmup'
465+
python3 $CLIENT_PY >> $CLIENT_LOG 2>&1
466+
if [ $? -ne 0 ]; then
467+
echo -e "\n***\n*** 'bls_onnx_warmup' test FAILED. \n***"
468+
cat $CLIENT_LOG
469+
RET=1
470+
else
471+
check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS
472+
if [ $? -ne 0 ]; then
473+
cat $CLIENT_LOG
474+
echo -e "\n***\n*** Test Result Verification Failed\n***"
475+
RET=1
476+
fi
477+
fi
478+
479+
set -e
480+
481+
482+
kill $SERVER_PID
483+
wait $SERVER_PID
484+
485+
486+
if [ $RET -eq 1 ]; then
487+
cat $CLIENT_LOG
488+
cat $SERVER_LOG
489+
echo -e "\n***\n*** Test Failed \n***"
490+
else
491+
echo -e "\n***\n*** Test Passed \n***"
413492
fi
414493

415494
exit $RET
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
name: "bls_onnx_warmup"
28+
backend: "python"
29+
30+
output [
31+
{
32+
name: "OUTPUT0"
33+
data_type: TYPE_FP32
34+
dims: [ 16 ]
35+
}
36+
]
37+
38+
instance_group [{ kind: KIND_CPU }]
+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
import numpy as np
28+
import unittest
29+
import triton_python_backend_utils as pb_utils
30+
from torch.utils.dlpack import from_dlpack
31+
32+
33+
class PBBLSONNXWarmupTest(unittest.TestCase):
34+
35+
def test_onnx_output_mem_type(self):
36+
input0_np = np.random.randn(*[16])
37+
input0_np = input0_np.astype(np.float32)
38+
input1_np = np.random.randn(*[16])
39+
input1_np = input1_np.astype(np.float32)
40+
input0 = pb_utils.Tensor('INPUT0', input0_np)
41+
input1 = pb_utils.Tensor('INPUT1', input1_np)
42+
infer_request = pb_utils.InferenceRequest(
43+
model_name='onnx_nobatch_float32_float32_float32',
44+
inputs=[input0, input1],
45+
requested_output_names=['OUTPUT0', 'OUTPUT1'])
46+
47+
infer_response = infer_request.exec()
48+
49+
self.assertFalse(infer_response.has_error())
50+
51+
output0 = pb_utils.get_output_tensor_by_name(infer_response, 'OUTPUT0')
52+
output1 = pb_utils.get_output_tensor_by_name(infer_response, 'OUTPUT1')
53+
54+
self.assertIsNotNone(output0)
55+
self.assertIsNotNone(output1)
56+
57+
# The memory type of output tensor should be GPU
58+
self.assertFalse(output0.is_cpu())
59+
self.assertFalse(output1.is_cpu())
60+
61+
expected_output_0 = input0.as_numpy() - input1.as_numpy()
62+
expected_output_1 = input0.as_numpy() + input1.as_numpy()
63+
64+
output0 = from_dlpack(
65+
output0.to_dlpack()).to('cpu').cpu().detach().numpy()
66+
output1 = from_dlpack(
67+
output1.to_dlpack()).to('cpu').cpu().detach().numpy()
68+
69+
self.assertTrue(np.all(output0 == expected_output_0))
70+
self.assertTrue(np.all(output1 == expected_output_1))
71+
72+
73+
class TritonPythonModel:
74+
75+
def execute(self, requests):
76+
responses = []
77+
for _ in requests:
78+
# Run the unittest and store the results in InferenceResponse.
79+
test = unittest.main('model', exit=False)
80+
responses.append(
81+
pb_utils.InferenceResponse([
82+
pb_utils.Tensor(
83+
'OUTPUT0',
84+
np.array([test.result.wasSuccessful()],
85+
dtype=np.float16))
86+
]))
87+
return responses

0 commit comments

Comments
 (0)