Input data/shape validation #7171
@GuanLuo @tanmayv25 @Tabrizian I believe there are some relaxed checks on the hot path for performance purposes. Do you have any comments on problem areas or risks?
Hi @HennerM, thanks for raising this. Do you mind adding the following:
Thanks for looking at this. Here is a full example that you can run with pytest:

```python
import asyncio
from pathlib import Path
from subprocess import Popen
from tempfile import TemporaryDirectory
from typing import Optional

import numpy as np
import pytest
import torch
from tritonclient.grpc.aio import InferenceServerClient, InferInput
from tritonclient.utils import np_to_triton_dtype

GRPC_PORT = 9653
FIXED_LAST_DIM = 8


@pytest.fixture
def repo_dir():
    with TemporaryDirectory() as model_repo:
        (Path(model_repo) / "pt_identity" / "1").mkdir(parents=True, exist_ok=True)
        torch.jit.save(torch.jit.script(torch.nn.Identity()), model_repo + "/pt_identity/1/model.pt")
        pbtxt = f"""
name: "pt_identity"
backend: "pytorch"
max_batch_size: 8
input [
  {{
    name: "INPUT0"
    data_type: TYPE_FP32
    dims: [ {FIXED_LAST_DIM} ]
  }}
]
output [
  {{
    name: "OUTPUT0"
    data_type: TYPE_FP32
    dims: [ {FIXED_LAST_DIM} ]
  }}
]
# ensure we batch requests together
dynamic_batching {{
  max_queue_delay_microseconds: {int(5e6)}
}}
"""
        with open(model_repo + "/pt_identity/config.pbtxt", "w") as f:
            f.write(pbtxt)
        yield model_repo


async def poll_readiness(client: InferenceServerClient, server_proc: Optional[Popen]):
    while True:
        # Fail fast if the server process has already exited.
        if server_proc is not None and (ret_code := server_proc.poll()) is not None:
            _, stderr = server_proc.communicate()
            print(stderr)
            raise Exception(f"Tritonserver died with return code {ret_code}")
        try:
            if await client.is_server_ready():
                break
        except:  # noqa: E722
            pass
        await asyncio.sleep(0.5)


@pytest.mark.asyncio
async def test_shape_overlapped(repo_dir: str):
    with Popen(["tritonserver", "--model-repository", repo_dir, "--grpc-port", str(GRPC_PORT)]) as server:
        await poll_readiness(InferenceServerClient("localhost:" + str(GRPC_PORT)), server)

        alice = InferenceServerClient("localhost:" + str(GRPC_PORT))
        bob = InferenceServerClient("localhost:" + str(GRPC_PORT))

        input_data_1 = np.arange(FIXED_LAST_DIM + 2)[None].astype(np.float32)
        print(f"{input_data_1=}")
        inputs_1 = [
            InferInput("INPUT0", input_data_1.shape, np_to_triton_dtype(input_data_1.dtype)),
        ]
        inputs_1[0].set_data_from_numpy(input_data_1)
        # Compromised input shape: the buffer holds FIXED_LAST_DIM + 2 elements,
        # but the declared shape only accounts for FIXED_LAST_DIM of them.
        inputs_1[0].set_shape((1, FIXED_LAST_DIM))

        input_data_2 = 100 + np.arange(FIXED_LAST_DIM)[None].astype(np.float32)
        print(f"{input_data_2=}")
        inputs_2 = [InferInput("INPUT0", shape=input_data_2.shape, datatype=np_to_triton_dtype(input_data_2.dtype))]
        inputs_2[0].set_data_from_numpy(input_data_2)

        t1 = asyncio.create_task(alice.infer("pt_identity", inputs_1))
        t2 = asyncio.create_task(bob.infer("pt_identity", inputs_2))
        _, bob_result = await asyncio.gather(t1, t2)
        server.terminate()

        assert np.allclose(bob_result.as_numpy("OUTPUT0"), input_data_2), "Bob's result should be the same as input"
```
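To run this, the test assumes tritonserver is on the PATH, and it needs the tritonclient[grpc] package plus the pytest-asyncio plugin (for the @pytest.mark.asyncio marker); the file name is arbitrary.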
Thanks for filing this issue. I think we have checks for other input mismatches with the model configuration (https://github.com/triton-inference-server/core/blob/35555724612df29923007f1da45d2a58f928206c/src/infer_request.cc#L1066-L1175), but it looks like we need an additional check to make sure that the total byte size of the elements matches the specified dimensions. Filed (DLIS-6634).
I am happy to contribute to this; I just wanted to check whether this is already handled and whether the team considers it a good idea.
@jbkyang-nvi has started taking a look at this bug. Looks like there are a few locations where we can update the checks to make sure the request has the correct size. |
@HennerM Is this supposed to pass with current Triton? Currently this test passes.
@yinggeh nice, thanks for fixing this!
Is your feature request related to a problem? Please describe.
Triton server doesn't do any validation of the data sent by clients, specifically validation that the given shape matches the size of the input buffer. For example, if a client sends a vector with 15 elements and specifies the shape as [1, 10], Triton blindly accepts this and passes it on to the backend; a similar issue arises if the client only sends a vector with 5 elements.
This could potentially lead to data leaking from one request into another; an example that can trigger this behaviour is given here:
With an example model config:
This has been observed with the PyTorch backend; I am not sure whether other backends have provisions in place.
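To make the size mismatch concrete, here is a small standalone illustration of the [1, 10] / 15-element example above (this snippet is only explanatory and is not part of Triton or the reproduction test):

```python
import numpy as np

# The client declares shape [1, 10] but ships 15 float32 values.
declared_shape = (1, 10)
payload = np.arange(15, dtype=np.float32)

expected_bytes = int(np.prod(declared_shape)) * payload.dtype.itemsize  # 10 * 4 = 40
actual_bytes = payload.nbytes                                            # 15 * 4 = 60

print(expected_bytes, actual_bytes)  # 40 60 -> sizes disagree, so the request should be rejected
```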
Describe the solution you'd like
In cases where the sent request's buffer size doesn't match the declared shape, we should fail fast and reject the request as early as possible, ideally before the request is enqueued into a batcher. A trivial check would be whether the total_num_elements (i.e. the product of the whole shape vector) multiplied by the datatype size in bytes adds up to the actual size of the input buffer.
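A minimal sketch of that check, written in Python for illustration (the actual implementation would live in Triton core's C++ request handling; the function name and signature here are hypothetical):

```python
import numpy as np

def validate_input_byte_size(shape, dtype_itemsize, buffer_size) -> None:
    """Reject a request whose raw input buffer doesn't match its declared shape.

    Applies to fixed-size datatypes only; TYPE_BYTES/string inputs have
    variable-length elements and would need a different check.
    """
    total_num_elements = int(np.prod(shape))        # product of the whole shape vector
    expected_size = total_num_elements * dtype_itemsize
    if expected_size != buffer_size:
        raise ValueError(
            f"input byte size mismatch: expected {expected_size} bytes for shape "
            f"{list(shape)}, got {buffer_size} bytes"
        )

# Example: shape [1, 10] with FP32 (4-byte) elements but a 60-byte buffer -> rejected.
# validate_input_byte_size((1, 10), 4, 60)
```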
Describe alternatives you've considered
I have a draft for adding a validation like this to the libtorch backend: https://github.com/triton-inference-server/pytorch_backend/compare/main...speechmatics:pytorch_backend:check-shapes?expand=1 The problem with this is that it happens very late in the Triton pipeline: we only validate once a request has been batched. I might have overlooked something, but at this point I am not sure whether a single request can be rejected; at least I couldn't find an example of that.
Additional context
Ideally, Triton core should also check each backend's output with the same check. This could be another feature request, though.