Skip to content

Commit 173fe6a

Browse files
krishung5 and mc-nv
authored and committed
Add testing for Pytorch instance group kind MODEL (#5810)
* Add testing for Pytorch instance group kind MODEL
* Remove unused item
* Update testing to verify the infer result
* Add copyright
* Remove unused import
* Update pip install
* Update the model to use the same add sub logic
* Add torch multi-gpu and multi-device models to L0_io
* Fix up model version
1 parent 429f914 commit 173fe6a

File tree

5 files changed

+425
-17
lines changed

5 files changed

+425
-17
lines changed

qa/L0_io/test.sh

+41 -17
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -47,13 +47,11 @@ MODELSDIR=`pwd`/models
4747
DATADIR=/data/inferenceserver/${REPO_VERSION}/qa_model_repository
4848
ENSEMBLEDIR=/data/inferenceserver/${REPO_VERSION}/qa_ensemble_model_repository/qa_model_repository
4949

50-
export CUDA_VISIBLE_DEVICES=0,1
51-
5250
# Must explicitly set LD_LIBRARY_PATH so that IO_TEST_UTIL can find
5351
# libtritonserver.so.
5452
LD_LIBRARY_PATH=/opt/tritonserver/lib:$LD_LIBRARY_PATH
5553

56-
rm -f $CLIENT_LOG.*
54+
rm -f $CLIENT_LOG*
5755

5856
# PyTorch is required for the Python backend dlpack add sub models
5957
pip3 install torch==1.13.0+cu117 -f https://download.pytorch.org/whl/torch_stable.html
@@ -148,23 +146,47 @@ cp -r $MODELSDIR/fan_graphdef_float32_float32_float32 $MODELSDIR/fan_${full} &&
148146
cp -r $ENSEMBLEDIR/nop_TYPE_FP32_-1 $MODELSDIR/. && \
149147
mkdir -p $MODELSDIR/nop_TYPE_FP32_-1/1
150148

149+
# prepare libtorch multi-device and multi-gpu models
150+
cp -r ../L0_libtorch_instance_group_kind_model/models/libtorch_multi_device $MODELSDIR/.
151+
cp ../L0_libtorch_instance_group_kind_model/gen_models.py ./gen_libtorch_model.py
152+
mkdir -p $MODELSDIR/libtorch_multi_device/1
153+
mkdir -p $MODELSDIR/libtorch_multi_gpu/1
154+
cp $MODELSDIR/libtorch_multi_device/config.pbtxt $MODELSDIR/libtorch_multi_gpu/.
155+
(cd $MODELSDIR/libtorch_multi_gpu && \
156+
sed -i "s/name: \"libtorch_multi_device\"/name: \"libtorch_multi_gpu\"/" config.pbtxt)
157+
158+
set +e
159+
python3 gen_libtorch_model.py >> $CLIENT_LOG 2>&1
160+
if [ $? -ne 0 ]; then
161+
echo -e "\n***\n*** Error when generating libtorch models. \n***"
162+
cat $CLIENT_LOG
163+
RET=1
164+
fi
165+
set -e
166+
167+
TRIALS="graphdef savedmodel onnx libtorch plan python python_dlpack libtorch_multi_gpu libtorch_multi_device"
151168
for input_device in -1 0 1; do
152169
for output_device in -1 0 1; do
153-
for trial in graphdef savedmodel onnx libtorch plan python python_dlpack; do
170+
for trial in ${TRIALS}; do
154171
# TensorRT Plan should only be deployed on GPU device
155172
model_devices="-1 0 1" && [[ "$trial" == "plan" ]] && model_devices="0 1"
173+
full=${trial}_float32_float32_float32 && [[ "$trial" == "libtorch_multi"* ]] && full=${trial}
174+
156175
for model_device in $model_devices; do
157-
full=${trial}_float32_float32_float32
158176
full_log=$CLIENT_LOG.$full.$input_device.$output_device.$model_device
159177

160178
host_policy=cpu
161179
if [ "$model_device" == "-1" ]; then
162-
(cd $MODELSDIR/${full} && \
163-
sed -i "s/instance_group.*/instance_group [{ kind: KIND_CPU }]/" config.pbtxt)
180+
if [[ "$trial" != "libtorch_multi"* ]]; then
181+
(cd $MODELSDIR/${full} && \
182+
sed -i "s/instance_group.*/instance_group [{ kind: KIND_CPU }]/" config.pbtxt)
183+
fi
164184
else
165185
host_policy=gpu_${model_device}
166-
(cd $MODELSDIR/${full} && \
167-
sed -i "s/instance_group.*/instance_group [{ kind: KIND_GPU, gpus: [${model_device}] }]/" config.pbtxt)
186+
if [[ "$trial" != "libtorch_multi"* ]]; then
187+
(cd $MODELSDIR/${full} && \
188+
sed -i "s/instance_group.*/instance_group [{ kind: KIND_GPU, gpus: [${model_device}] }]/" config.pbtxt)
189+
fi
168190
fi
169191

170192
set +e
@@ -196,14 +218,16 @@ for input_device in -1 0 1; do
196218
set -e
197219

198220
# ensemble
199-
set +e
200-
$IO_TEST_UTIL -i $input_device -o $output_device -r $MODELSDIR -m fan_$full >>$full_log.ensemble 2>&1
201-
if [ $? -ne 0 ]; then
202-
cat $full_log.ensemble
203-
echo -e "\n***\n*** Test Failed\n***"
204-
RET=1
221+
if [[ "$trial" != "libtorch_multi"* ]]; then
222+
set +e
223+
$IO_TEST_UTIL -i $input_device -o $output_device -r $MODELSDIR -m fan_$full >>$full_log.ensemble 2>&1
224+
if [ $? -ne 0 ]; then
225+
cat $full_log.ensemble
226+
echo -e "\n***\n*** Test Failed\n***"
227+
RET=1
228+
fi
229+
set -e
205230
fi
206-
set -e
207231
done
208232
done
209233

Original file line number | Diff line number | Diff line change
@@ -0,0 +1,92 @@
1+
#!/usr/bin/env python
2+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
28+
import os
29+
import sys
30+
31+
sys.path.append("../common")
32+
33+
import unittest
34+
import numpy as np
35+
import test_util as tu
36+
37+
import tritonclient.http as httpclient
38+
39+
# By default, find tritonserver on "localhost", but can be overridden
40+
# with TRITONSERVER_IPADDR envvar
41+
_tritonserver_ipaddr = os.environ.get('TRITONSERVER_IPADDR', 'localhost')
42+
43+
44+
class InferTest(tu.TestResultCollector):
45+
46+
def test_infer(self):
47+
try:
48+
triton_client = httpclient.InferenceServerClient(
49+
url=f"{_tritonserver_ipaddr}:8000")
50+
except Exception as e:
51+
print("channel creation failed: " + str(e))
52+
sys.exit(1)
53+
54+
model_name = os.environ['MODEL_NAME']
55+
56+
inputs = []
57+
outputs = []
58+
inputs.append(httpclient.InferInput('INPUT0', [1, 16], "FP32"))
59+
inputs.append(httpclient.InferInput('INPUT1', [1, 16], "FP32"))
60+
61+
# Create the data for the two input tensors.
62+
input0_data = np.arange(start=0, stop=16, dtype=np.float32)
63+
input0_data = np.expand_dims(input0_data, axis=0)
64+
input1_data = np.arange(start=32, stop=48, dtype=np.float32)
65+
input1_data = np.expand_dims(input1_data, axis=0)
66+
67+
# Initialize the data
68+
inputs[0].set_data_from_numpy(input0_data, binary_data=True)
69+
inputs[1].set_data_from_numpy(input1_data, binary_data=True)
70+
71+
outputs.append(
72+
httpclient.InferRequestedOutput('OUTPUT__0', binary_data=True))
73+
outputs.append(
74+
httpclient.InferRequestedOutput('OUTPUT__1', binary_data=True))
75+
76+
results = triton_client.infer(model_name, inputs, outputs=outputs)
77+
78+
output0_data = results.as_numpy('OUTPUT__0')
79+
output1_data = results.as_numpy('OUTPUT__1')
80+
81+
expected_output_0 = input0_data + input1_data
82+
expected_output_1 = input0_data - input1_data
83+
84+
self.assertEqual(output0_data.shape, (1, 16))
85+
self.assertEqual(output1_data.shape, (1, 16))
86+
87+
self.assertTrue(np.all(expected_output_0 == output0_data))
88+
self.assertTrue(np.all(expected_output_1 == output1_data))
89+
90+
91+
if __name__ == '__main__':
92+
unittest.main()
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,83 @@
1+
#!/usr/bin/python
2+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
28+
import torch
29+
import torch.nn as nn
30+
31+
32+
class SumModule(nn.Module):
33+
34+
def __init__(self, device):
35+
super(SumModule, self).__init__()
36+
self.device = device
37+
38+
def forward(self, INPUT0, INPUT1):
39+
INPUT0 = INPUT0.to(self.device)
40+
INPUT1 = INPUT1.to(self.device)
41+
print('SumModule - INPUT0 device: {}, INPUT1 device: {}\n'.format(
42+
INPUT0.device, INPUT1.device))
43+
return INPUT0 + INPUT1
44+
45+
46+
class DiffModule(nn.Module):
47+
48+
def __init__(self, device):
49+
super(DiffModule, self).__init__()
50+
self.device = device
51+
52+
def forward(self, INPUT0, INPUT1):
53+
INPUT0 = INPUT0.to(self.device)
54+
INPUT1 = INPUT1.to(self.device)
55+
print('DiffModule - INPUT0 device: {}, INPUT1 device: {}\n'.format(
56+
INPUT0.device, INPUT1.device))
57+
return INPUT0 - INPUT1
58+
59+
60+
class TestModel(nn.Module):
61+
62+
def __init__(self, device0, device1):
63+
super(TestModel, self).__init__()
64+
self.device0 = device0
65+
self.device1 = device1
66+
67+
self.layer1 = SumModule(self.device0)
68+
self.layer2 = DiffModule(self.device1)
69+
70+
def forward(self, INPUT0, INPUT1):
71+
op0 = self.layer1(INPUT0, INPUT1)
72+
op1 = self.layer2(INPUT0, INPUT1)
73+
return op0, op1
74+
75+
76+
devices = [("cuda:2", "cuda:0"), ("cpu", "cuda:3")]
77+
model_names = ["libtorch_multi_gpu", "libtorch_multi_device"]
78+
79+
for device_pair, model_name in zip(devices, model_names):
80+
model = TestModel(device_pair[0], device_pair[1])
81+
model_path = "models/" + model_name + "/1/model.pt"
82+
scripted_model = torch.jit.script(model)
83+
scripted_model.save(model_path)
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
name: "libtorch_multi_device"
28+
platform: "pytorch_libtorch"
29+
max_batch_size: 8
30+
31+
input [
32+
{
33+
name: "INPUT0"
34+
data_type: TYPE_FP32
35+
dims: [ 16 ]
36+
},
37+
{
38+
name: "INPUT1"
39+
data_type: TYPE_FP32
40+
dims: [ 16 ]
41+
}
42+
]
43+
output [
44+
{
45+
name: "OUTPUT__0"
46+
data_type: TYPE_FP32
47+
dims: [ 4 ]
48+
},
49+
{
50+
name: "OUTPUT__1"
51+
data_type: TYPE_FP32
52+
dims: [ 4 ]
53+
}
54+
]
55+
56+
instance_group [
57+
{
58+
kind: KIND_MODEL
59+
}
60+
]

0 commit comments

Comments
 (0)