diff --git a/docs/protocol/extension_generate.md b/docs/protocol/extension_generate.md index da110972ea..92c7e8f52b 100644 --- a/docs/protocol/extension_generate.md +++ b/docs/protocol/extension_generate.md @@ -53,19 +53,29 @@ POST v2/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]/generate POST v2/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]/generate_stream ``` -### generate v.s. generate_stream - -Both URLs expect the same request JSON object, and generate the same response -JSON object. However, `generate` returns exactly 1 response JSON object, while -`generate_stream` may return multiple responses based on the inference -results. `generate_stream` returns the responses as +### generate vs. generate_stream + +Both URLs expect the same request JSON object, and generate the same JSON +response object. However, there are some differences in the format used to +return each: +* `/generate` returns exactly 1 response JSON object with a +`Content-Type` of `application/json` +* `/generate_stream` may return multiple responses based on the inference +results, with a `Content-Type` of `text/event-stream; charset=utf-8`. +These responses will be sent as [Server-Sent Events](https://html.spec.whatwg.org/multipage/server-sent-events.html#server-sent-events) -(SSE), where each response will be a "data" chunk in the HTTP response body. -Also, note that an error may be returned during inference, whereas the HTTP -response code has been set in the first response of the SSE, which can result in -receiving an [error object](#generate-response-json-error-object) while status -code shows success (200). Therefore the user must always check whether an error -object is received when generating responses through `generate_stream`. +(SSE), where each response will be a "data" chunk in the HTTP +response body. In the case of inference errors, responses will have +an [error JSON object](#generate-response-json-error-object). + * Note that the HTTP response code is set in the first response of the SSE, + so if the first response succeeds but an error occurs in a subsequent + response for the request, it can result in receiving an error object + while the status code shows success (200). Therefore, the user must + always check whether an error object is received when generating + responses through `/generate_stream`. + * If the request fails before inference begins, then a JSON error will + be returned with `Content-Type` of `application/json`, similar to errors + from other endpoints with the status code set to an error. ### Generate Request JSON Object @@ -175,4 +185,4 @@ A failed generate request must be indicated by an HTTP error status { "error" : "error message" } -``` \ No newline at end of file +``` diff --git a/qa/L0_http/generate_endpoint_test.py b/qa/L0_http/generate_endpoint_test.py index 0b0a2eb81d..29d2e20d96 100755 --- a/qa/L0_http/generate_endpoint_test.py +++ b/qa/L0_http/generate_endpoint_test.py @@ -71,6 +71,9 @@ def generate_expect_failure(self, model_name, inputs, msg): r = requests.post( url, data=inputs if isinstance(inputs, str) else json.dumps(inputs) ) + # Content-Type header should always be JSON for errors + self.assertEqual(r.headers["Content-Type"], "application/json") + try: r.raise_for_status() self.assertTrue(False, f"Expected failure, success for {inputs}") @@ -79,6 +82,9 @@ def generate_expect_failure(self, model_name, inputs, msg): def generate_stream_expect_failure(self, model_name, inputs, msg): r = self.generate_stream(model_name, inputs) + # Content-Type header should always be JSON for errors + self.assertEqual(r.headers["Content-Type"], "application/json") + try: r.raise_for_status() self.assertTrue(False, f"Expected failure, success for {inputs}") @@ -95,7 +101,9 @@ def generate_stream_expect_success( def check_sse_responses(self, res, expected_res): # Validate SSE format self.assertIn("Content-Type", res.headers) - self.assertIn("text/event-stream", res.headers["Content-Type"]) + self.assertEqual( + "text/event-stream; charset=utf-8", res.headers["Content-Type"] + ) # SSE format (data: []) is hard to parse, use helper library for simplicity client = sseclient.SSEClient(res) @@ -128,7 +136,7 @@ def test_generate(self): r.raise_for_status() self.assertIn("Content-Type", r.headers) - self.assertIn("application/json", r.headers["Content-Type"]) + self.assertEqual(r.headers["Content-Type"], "application/json") data = r.json() self.assertIn("TEXT", data) diff --git a/src/http_server.cc b/src/http_server.cc index 71ae5ca13c..32cf6956cd 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -116,6 +116,13 @@ EVBufferAddErrorJson(evbuffer* buffer, TRITONSERVER_Error* err) void AddContentTypeHeader(evhtp_request_t* req, const char* type) { + // Remove existing header if found + auto content_header = + evhtp_headers_find_header(req->headers_out, kContentTypeHeader); + if (content_header) { + evhtp_header_rm_and_free(req->headers_out, content_header); + } + evhtp_headers_add_header( req->headers_out, evhtp_header_new(kContentTypeHeader, type, 1, 1)); }