Skip to content

Commit a439c22

Browse files
committed
Add testing for Pytorch instance group kind MODEL
1 parent aea0ae6 commit a439c22

File tree

5 files changed

+360
-0
lines changed

5 files changed

+360
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#!/usr/bin/env python
2+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
28+
import os
29+
import sys
30+
31+
sys.path.append("../common")
32+
33+
import torch
34+
import unittest
35+
import numpy as np
36+
import test_util as tu
37+
38+
import tritonclient.http as httpclient
39+
from tritonclient.utils import InferenceServerException
40+
41+
# By default, find tritonserver on "localhost", but can be overridden
42+
# with TRITONSERVER_IPADDR envvar
43+
_tritonserver_ipaddr = os.environ.get('TRITONSERVER_IPADDR', 'localhost')
44+
45+
46+
class InferTest(tu.TestResultCollector):
    """Inference check for the libtorch models exercised by this test."""

    def test_infer(self):
        """Send one request to the model named by ``MODEL_NAME`` and verify it.

        The model name is read from the ``MODEL_NAME`` environment variable.
        For ``libtorch_instance_kind_err`` the request must be rejected with a
        device-mismatch error. For any other model the request must succeed
        and only the output shapes are checked.

        Raises:
            KeyError: if ``MODEL_NAME`` is not set in the environment.
        """
        try:
            triton_client = httpclient.InferenceServerClient(
                url=f"{_tritonserver_ipaddr}:8000")
        except Exception as e:
            print("channel creation failed: " + str(e))
            sys.exit(1)
        # Close the client deterministically, even when an assertion below
        # fails, instead of relying on garbage collection.
        self.addCleanup(triton_client.close)

        model_name = os.environ['MODEL_NAME']

        inputs = []
        outputs = []
        inputs.append(httpclient.InferInput('INPUT0', [1, 16], "FP32"))
        inputs.append(httpclient.InferInput('INPUT1', [1, 16], "FP32"))

        # Create the data for the two input tensors: two FP32 ramps, each
        # expanded to shape (1, 16) to match the declared input shape.
        input0_data = np.arange(start=0, stop=16, dtype=np.float32)
        input0_data = np.expand_dims(input0_data, axis=0)
        input1_data = np.arange(start=32, stop=48, dtype=np.float32)
        input1_data = np.expand_dims(input1_data, axis=0)

        # Initialize the data
        inputs[0].set_data_from_numpy(input0_data, binary_data=True)
        inputs[1].set_data_from_numpy(input1_data, binary_data=True)

        outputs.append(
            httpclient.InferRequestedOutput('OUTPUT__0', binary_data=True))
        outputs.append(
            httpclient.InferRequestedOutput('OUTPUT__1', binary_data=True))

        if model_name == "libtorch_instance_kind_err":
            # The response is irrelevant here; only the raised error matters.
            with self.assertRaises(InferenceServerException) as ex:
                triton_client.infer(model_name, inputs, outputs=outputs)
            # Must stay outside the ``with`` block: statements placed after
            # the raising call inside it would never execute.
            self.assertIn(
                "Expected all tensors to be on the same device, but found at least two devices",
                str(ex.exception))
            return

        results = triton_client.infer(model_name, inputs, outputs=outputs)

        output0_data = results.as_numpy('OUTPUT__0')
        output1_data = results.as_numpy('OUTPUT__1')

        # Only validate the shape, as the output will differ every time the
        # model is compiled and used on different devices.
        self.assertEqual(output0_data.shape, (1, 4))
        self.assertEqual(output1_data.shape, (1, 4))
97+
98+
99+
# Run the test case(s) in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#!/usr/bin/python
2+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
28+
import os

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
31+
32+
33+
class TestModel(nn.Module):
    """Two independent 16->4 linear branches, each placed on its own device."""

    def __init__(self, device1, device2):
        """Create one linear branch per device.

        Args:
            device1: device for ``layers1`` and for ``INPUT0`` in ``forward``.
            device2: device for ``layers2`` and for ``INPUT1`` in ``forward``.
        """
        super().__init__()
        self.device1 = device1
        self.device2 = device2
        first_branch = nn.Sequential(nn.Linear(16, 4))
        second_branch = nn.Sequential(nn.Linear(16, 4))
        self.layers1 = first_branch.to(self.device1)
        self.layers2 = second_branch.to(self.device2)

    def forward(self, INPUT0, INPUT1):
        """Move each input to its branch's device, double it, and apply the branch.

        Returns:
            A pair ``(layers1(2 * INPUT0), layers2(2 * INPUT1))``.
        """
        first = INPUT0.to(self.device1)
        second = INPUT1.to(self.device2)
        # Report where each input ended up after the move.
        print('INPUT0 device: {}, INPUT1 device: {}\n'.format(
            first.device, second.device))

        doubled_first = first + first
        doubled_second = second + second
        return self.layers1(doubled_first), self.layers2(doubled_second)
51+
52+
53+
# Each entry pairs with the model name at the same index below:
# (device for layers1 / INPUT0, device for layers2 / INPUT1).
devices = [("cuda:2", "cuda:0"), ("cpu", "cuda:3")]
model_names = ["libtorch_multi_gpu", "libtorch_multi_devices"]

# Script each device-partitioned model and save it as version 1 of the
# correspondingly named model directory.
for device_pair, model_name in zip(devices, model_names):
    model = TestModel(device_pair[0], device_pair[1])
    model_path = "models/" + model_name + "/1/model.pt"
    # save() does not create missing parent directories, so make sure the
    # version directory exists before writing into it.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    scripted_model = torch.jit.script(model)
    scripted_model.save(model_path)
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Model configuration for the libtorch model generated by gen_models.py.
# test.sh copies this file for the other test models and rewrites the
# model-name line and the instance-group kind line with sed, so keep the
# exact text of those two lines intact.
name: "libtorch_multi_devices"
platform: "pytorch_libtorch"
max_batch_size: 8

# Two FP32 inputs of 16 elements each (batch dimension excluded);
# client.py sends them with shape [1, 16].
input [
  {
    name: "INPUT0"
    data_type: TYPE_FP32
    dims: [ 16 ]
  },
  {
    name: "INPUT1"
    data_type: TYPE_FP32
    dims: [ 16 ]
  }
]
# Two FP32 outputs of 4 elements each; client.py checks for shape (1, 4).
output [
  {
    name: "OUTPUT__0"
    data_type: TYPE_FP32
    dims: [ 4 ]
  },
  {
    name: "OUTPUT__1"
    data_type: TYPE_FP32
    dims: [ 4 ]
  }
]

# The instance kind under test. test.sh swaps it to KIND_GPU for the
# negative test case, which is expected to fail with a device-mismatch error.
instance_group [
  {
    kind: KIND_MODEL
  }
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
#!/bin/bash
2+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
28+
# The repository version comes from the environment and can be overridden
# by the first command-line argument.
REPO_VERSION=${NVIDIA_TRITON_SERVER_VERSION}
if [ "$#" -ge 1 ]; then
    REPO_VERSION=$1
fi
if [ -z "$REPO_VERSION" ]; then
    echo -e "Repository version must be specified"
    echo -e "\n***\n*** Test Failed\n***"
    exit 1
fi
if [ ! -z "$TEST_REPO_ARCH" ]; then
    REPO_VERSION=${REPO_VERSION}_${TEST_REPO_ARCH}
fi

# Reinstall torch from pip; gen_models.py imports it to build the models.
pip3 uninstall -y torch
pip3 install torch

DATADIR=/data/inferenceserver/${REPO_VERSION}/qa_model_repository
SERVER=/opt/tritonserver/bin/tritonserver
SERVER_ARGS="--model-repository=models --log-verbose=1"
SERVER_LOG="./inference_server.log"

CLIENT_PY=./client.py
CLIENT_LOG="./client.log"
EXPECTED_NUM_TESTS="1"
TEST_RESULT_FILE='test_results.txt'

source ../common/util.sh

RET=0

rm -f *.log *.txt

# Run client.py against one model and record any failure in RET.
#   $1 - model name, exported as MODEL_NAME for client.py
#   $2 - optional string that must appear in the server log afterwards;
#        the log check is skipped when it is omitted or empty
# Must be called with 'set +e' in effect so failures can be inspected.
function run_client_test() {
    export MODEL_NAME="$1"
    local expected_message="$2"

    python3 "$CLIENT_PY" >> "$CLIENT_LOG" 2>&1
    if [ $? -ne 0 ]; then
        echo -e "\n***\n*** Model $MODEL_NAME FAILED. \n***"
        cat "$CLIENT_LOG"
        RET=1
    else
        check_test_results "$TEST_RESULT_FILE" "$EXPECTED_NUM_TESTS"
        if [ $? -ne 0 ]; then
            cat "$CLIENT_LOG"
            echo -e "\n***\n*** Test Result Verification Failed\n***"
            RET=1
        fi
    fi

    if [ -n "$expected_message" ]; then
        # The model prints the devices its inputs were moved to; the test
        # looks for that line in the server log.
        if grep "$expected_message" "$SERVER_LOG"; then
            echo -e "Found \"$expected_message\"" >> "$CLIENT_LOG"
        else
            echo -e "Not found \"$expected_message\"" >> "$CLIENT_LOG"
            RET=1
        fi
    fi
}

# The multi-GPU model reuses the multi-devices config under a new name.
mkdir -p models/libtorch_multi_gpu/1
cp models/libtorch_multi_devices/config.pbtxt models/libtorch_multi_gpu/.
(cd models/libtorch_multi_gpu && \
    sed -i "s/name: \"libtorch_multi_devices\"/name: \"libtorch_multi_gpu\"/" config.pbtxt)

# Generate the models which are partitioned across multiple devices
set +e
python3 gen_models.py >> "$CLIENT_LOG" 2>&1
if [ $? -ne 0 ]; then
    echo -e "\n***\n*** Error when generating models. \n***"
    cat "$CLIENT_LOG"
    RET=1
fi
set -e

# Create the model that does not set instance_group_kind to 'KIND_MODEL'
mkdir -p models/libtorch_instance_kind_err/1
cp models/libtorch_multi_devices/config.pbtxt models/libtorch_instance_kind_err/.
cp models/libtorch_multi_devices/1/model.pt models/libtorch_instance_kind_err/1/.
(cd models/libtorch_instance_kind_err && \
    sed -i "s/name: \"libtorch_multi_devices\"/name: \"libtorch_instance_kind_err\"/" config.pbtxt && \
    sed -i "s/kind: KIND_MODEL/kind: KIND_GPU/" config.pbtxt)

run_server
if [ "$SERVER_PID" == "0" ]; then
    echo -e "\n***\n*** Failed to start $SERVER\n***"
    cat "$SERVER_LOG"
    exit 1
fi

set +e

run_client_test 'libtorch_multi_devices' "INPUT0 device: cpu, INPUT1 device: cuda:3"
run_client_test 'libtorch_multi_gpu' "INPUT0 device: cuda:2, INPUT1 device: cuda:0"
# Negative case: client.py itself asserts on the expected inference error,
# so there is no server-log message to look for.
run_client_test 'libtorch_instance_kind_err'

set -e

kill $SERVER_PID
wait $SERVER_PID

if [ $RET -eq 0 ]; then
    echo -e "\n***\n*** Test Passed\n***"
else
    cat "$CLIENT_LOG"
    echo -e "\n***\n*** Test FAILED\n***"
fi

exit $RET

0 commit comments

Comments
 (0)