diff --git a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py index ce4f72aec7..07f9c05a88 100755 --- a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py +++ b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py @@ -283,7 +283,7 @@ def test_too_big_shm(self): ) if len(error_msg) > 0: self.assertIn( - "unexpected total byte size 128 for input 'INPUT1', expecting 64", + "input byte size mismatch for input 'INPUT1' for model 'simple'. Expected 64, got 128", error_msg[-1], ) shm_handles.append(shm_ip2_handle) diff --git a/qa/L0_input_validation/input_validation_test.py b/qa/L0_input_validation/input_validation_test.py index afd791b527..e683723711 100755 --- a/qa/L0_input_validation/input_validation_test.py +++ b/qa/L0_input_validation/input_validation_test.py @@ -31,9 +31,10 @@ import unittest +import infer_util as iu import numpy as np import tritonclient.grpc as tritongrpcclient -from tritonclient.utils import InferenceServerException +from tritonclient.utils import InferenceServerException, np_to_triton_dtype class InputValTest(unittest.TestCase): @@ -113,5 +114,104 @@ def test_input_validation_all_optional(self): self.assertIn(str(response.outputs[0].name), "OUTPUT0") +class InputShapeTest(unittest.TestCase): + def test_input_shape_validation(self): + input_size = 8 + model_name = "pt_identity" + triton_client = tritongrpcclient.InferenceServerClient("localhost:8001") + + # Pass + input_data = np.arange(input_size)[None].astype(np.float32) + inputs = [ + tritongrpcclient.InferInput( + "INPUT0", input_data.shape, np_to_triton_dtype(input_data.dtype) + ) + ] + inputs[0].set_data_from_numpy(input_data) + triton_client.infer(model_name=model_name, inputs=inputs) + + # Larger input byte size than expected + input_data = np.arange(input_size + 2)[None].astype(np.float32) + inputs = [ + tritongrpcclient.InferInput( + "INPUT0", input_data.shape, np_to_triton_dtype(input_data.dtype) + ) + ] + 
inputs[0].set_data_from_numpy(input_data) + # Compromised input shape + inputs[0].set_shape((1, input_size)) + with self.assertRaises(InferenceServerException) as e: + triton_client.infer( + model_name=model_name, + inputs=inputs, + ) + err_str = str(e.exception) + self.assertIn( + "input byte size mismatch for input 'INPUT0' for model 'pt_identity'. Expected 32, got 40", + err_str, + ) + + def test_input_string_shape_validation(self): + input_size = 16 + model_name = "graphdef_object_int32_int32" + np_dtype_string = np.dtype(object) + triton_client = tritongrpcclient.InferenceServerClient("localhost:8001") + + def get_input_array(input_size, np_dtype): + rinput_dtype = iu._range_repr_dtype(np_dtype) + input_array = np.random.randint( + low=0, high=127, size=(1, input_size), dtype=rinput_dtype + ) + + # Convert to string type + inn = np.array( + [str(x) for x in input_array.reshape(input_array.size)], dtype=object + ) + input_array = inn.reshape(input_array.shape) + + inputs = [] + inputs.append( + tritongrpcclient.InferInput( + "INPUT0", input_array.shape, np_to_triton_dtype(np_dtype) + ) + ) + inputs.append( + tritongrpcclient.InferInput( + "INPUT1", input_array.shape, np_to_triton_dtype(np_dtype) + ) + ) + + inputs[0].set_data_from_numpy(input_array) + inputs[1].set_data_from_numpy(input_array) + return inputs + + # Input size is less than expected + inputs = get_input_array(input_size - 2, np_dtype_string) + # Compromised input shape + inputs[0].set_shape((1, input_size)) + inputs[1].set_shape((1, input_size)) + with self.assertRaises(InferenceServerException) as e: + triton_client.infer(model_name=model_name, inputs=inputs) + err_str = str(e.exception) + self.assertIn( + f"expected {input_size} strings for inference input 'INPUT1', got {input_size-2}", + err_str, + ) + + # Input size is greater than expected + inputs = get_input_array(input_size + 2, np_dtype_string) + # Compromised input shape + inputs[0].set_shape((1, input_size)) + inputs[1].set_shape((1, 
input_size)) + with self.assertRaises(InferenceServerException) as e: + triton_client.infer(model_name=model_name, inputs=inputs) + err_str = str(e.exception) + self.assertIn( + # Core throws as soon as it reads the "input_size+1"th string element. + f"unexpected number of string elements {input_size+1} for inference input 'INPUT1', expecting {input_size}", + err_str, + ) + + if __name__ == "__main__": unittest.main() diff --git a/qa/L0_input_validation/test.sh b/qa/L0_input_validation/test.sh index 1c66c2bbaa..ef4a1a6d65 100755 --- a/qa/L0_input_validation/test.sh +++ b/qa/L0_input_validation/test.sh @@ -42,17 +42,20 @@ source ../common/util.sh RET=0 +DATADIR=/data/inferenceserver/${REPO_VERSION} +SERVER=/opt/tritonserver/bin/tritonserver CLIENT_LOG="./input_validation_client.log" TEST_PY=./input_validation_test.py +SHAPE_TEST_PY=./input_shape_validation_test.py TEST_RESULT_FILE='./test_results.txt' +SERVER_LOG="./inference_server.log" export CUDA_VISIBLE_DEVICES=0 rm -fr *.log -SERVER=/opt/tritonserver/bin/tritonserver +# input_validation_test SERVER_ARGS="--model-repository=`pwd`/models" -SERVER_LOG="./inference_server.log" run_server if [ "$SERVER_PID" == "0" ]; then echo -e "\n***\n*** Failed to start $SERVER\n***" @@ -64,7 +67,50 @@ set +e python3 -m pytest --junitxml="input_validation.report.xml" $TEST_PY >> $CLIENT_LOG 2>&1 if [ $? -ne 0 ]; then - echo -e "\n***\n*** python_unittest.py FAILED. \n***" + echo -e "\n***\n*** input_validation_test.py FAILED. \n***" + RET=1 +fi +set -e + +kill $SERVER_PID +wait $SERVER_PID + +# input_shape_validation_test +pip install torch +pip install pytest-asyncio + +mkdir -p models/pt_identity/1 +PYTHON_CODE=$(cat <> $CLIENT_LOG 2>&1 + +if [ $? -ne 0 ]; then + echo -e "\n***\n*** input_validation_test.py FAILED. 
\n***" RET=1 fi set -e diff --git a/qa/L0_shared_memory/shared_memory_test.py b/qa/L0_shared_memory/shared_memory_test.py index e162f6b296..c38ecb4814 100755 --- a/qa/L0_shared_memory/shared_memory_test.py +++ b/qa/L0_shared_memory/shared_memory_test.py @@ -118,8 +118,8 @@ def test_reregister_after_register(self): "dummy_data", "/dummy_data", 8 ) except Exception as ex: - self.assertTrue( - "shared memory region 'dummy_data' already in manager" in str(ex) + self.assertIn( + "shared memory region 'dummy_data' already in manager", str(ex) ) shm_status = self.triton_client.get_system_shared_memory_status() if self.protocol == "http": @@ -271,9 +271,9 @@ def test_too_big_shm(self): use_system_shared_memory=True, ) if len(error_msg) > 0: - self.assertTrue( - "unexpected total byte size 128 for input 'INPUT1', expecting 64" - in error_msg[-1] + self.assertIn( + "input byte size mismatch for input 'INPUT1' for model 'simple'. Expected 64, got 128", + error_msg[-1], ) shm_handles.append(shm_ip2_handle) self._cleanup_server(shm_handles)