Skip to content

Commit 372ab8d

Browse files
vertex-sdk-bot and copybara-github
authored and committed
feat: add support for Predict Request Response Logging in Endpoint SDK
PiperOrigin-RevId: 497049904
1 parent a915668 commit 372ab8d

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

google/cloud/aiplatform/models.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,9 @@ def create(
279279
sync=True,
280280
create_request_timeout: Optional[float] = None,
281281
endpoint_id: Optional[str] = None,
282+
enable_request_response_logging=False,
283+
request_response_logging_sampling_rate: Optional[float] = None,
284+
request_response_logging_bq_destination_table: Optional[str] = None,
282285
) -> "Endpoint":
283286
"""Creates a new endpoint.
284287
@@ -339,12 +342,18 @@ def create(
339342
is populated based on a query string argument, such as
340343
``?endpoint_id=12345``. This is the fallback for fields
341344
that are not included in either the URI or the body.
345+
enable_request_response_logging (bool):
346+
Optional. Whether to enable request & response logging for this endpoint.
347+
request_response_logging_sampling_rate (float):
348+
Optional. The request response logging sampling rate. If not set, default is 0.0.
349+
request_response_logging_bq_destination_table (str):
350+
Optional. The BigQuery destination table for request-response logging. If not set, a table is created with the name:
351+
``bq://{project_id}.logging_{endpoint_display_name}_{endpoint_id}.request_response_logging``.
342352
343353
Returns:
344354
endpoint (aiplatform.Endpoint):
345355
Created endpoint.
346356
"""
347-
348357
api_client = cls._instantiate_client(location=location, credentials=credentials)
349358

350359
if not display_name:
@@ -357,6 +366,17 @@ def create(
357366
project = project or initializer.global_config.project
358367
location = location or initializer.global_config.location
359368

369+
predict_request_response_logging_config = None
370+
if enable_request_response_logging:
371+
predict_request_response_logging_config = (
372+
gca_endpoint_compat.PredictRequestResponseLoggingConfig(
373+
enabled=True,
374+
sampling_rate=request_response_logging_sampling_rate,
375+
bigquery_destination=gca_io_compat.BigQueryDestination(
376+
output_uri=request_response_logging_bq_destination_table
377+
),
378+
)
379+
)
360380
return cls._create(
361381
api_client=api_client,
362382
display_name=display_name,
@@ -372,6 +392,7 @@ def create(
372392
sync=sync,
373393
create_request_timeout=create_request_timeout,
374394
endpoint_id=endpoint_id,
395+
predict_request_response_logging_config=predict_request_response_logging_config,
375396
)
376397

377398
@classmethod
@@ -391,6 +412,9 @@ def _create(
391412
sync=True,
392413
create_request_timeout: Optional[float] = None,
393414
endpoint_id: Optional[str] = None,
415+
predict_request_response_logging_config: Optional[
416+
gca_endpoint_compat.PredictRequestResponseLoggingConfig
417+
] = None,
394418
) -> "Endpoint":
395419
"""Creates a new endpoint by calling the API client.
396420
@@ -453,6 +477,8 @@ def _create(
453477
is populated based on a query string argument, such as
454478
``?endpoint_id=12345``. This is the fallback for fields
455479
that are not included in either the URI or the body.
480+
predict_request_response_logging_config (aiplatform.endpoint.PredictRequestResponseLoggingConfig):
481+
Optional. The request response logging configuration for online prediction.
456482
457483
Returns:
458484
endpoint (aiplatform.Endpoint):
@@ -469,6 +495,7 @@ def _create(
469495
labels=labels,
470496
encryption_spec=encryption_spec,
471497
network=network,
498+
predict_request_response_logging_config=predict_request_response_logging_config,
472499
)
473500

474501
operation_future = api_client.create_endpoint(

tests/unit/aiplatform/test_endpoints.py

+40
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
prediction_service as gca_prediction_service,
5050
endpoint_service as gca_endpoint_service,
5151
encryption_spec as gca_encryption_spec,
52+
io as gca_io,
5253
)
5354

5455

@@ -200,6 +201,19 @@
200201

201202
_TEST_LABELS = {"my_key": "my_value"}
202203

204+
_TEST_REQUEST_RESPONSE_LOGGING_SAMPLING_RATE = 0.1
205+
_TEST_REQUEST_RESPONSE_LOGGING_BQ_DEST = (
206+
output_uri
207+
) = f"bq://{_TEST_PROJECT}/test_dataset/test_table"
208+
_TEST_REQUEST_RESPONSE_LOGGING_CONFIG = (
209+
gca_endpoint.PredictRequestResponseLoggingConfig(
210+
enabled=True,
211+
sampling_rate=_TEST_REQUEST_RESPONSE_LOGGING_SAMPLING_RATE,
212+
bigquery_destination=gca_io.BigQueryDestination(
213+
output_uri=_TEST_REQUEST_RESPONSE_LOGGING_BQ_DEST
214+
),
215+
)
216+
)
203217

204218
"""
205219
----------------------------------------------------------------------------
@@ -853,6 +867,32 @@ def test_create_with_labels(self, create_endpoint_mock, sync):
853867
timeout=None,
854868
)
855869

870+
@pytest.mark.usefixtures("get_endpoint_mock")
871+
@pytest.mark.parametrize("sync", [True, False])
872+
def test_create_with_request_response_logging(self, create_endpoint_mock, sync):
873+
my_endpoint = models.Endpoint.create(
874+
display_name=_TEST_DISPLAY_NAME,
875+
enable_request_response_logging=True,
876+
request_response_logging_sampling_rate=_TEST_REQUEST_RESPONSE_LOGGING_SAMPLING_RATE,
877+
request_response_logging_bq_destination_table=_TEST_REQUEST_RESPONSE_LOGGING_BQ_DEST,
878+
sync=sync,
879+
create_request_timeout=None,
880+
)
881+
if not sync:
882+
my_endpoint.wait()
883+
884+
expected_endpoint = gca_endpoint.Endpoint(
885+
display_name=_TEST_DISPLAY_NAME,
886+
predict_request_response_logging_config=_TEST_REQUEST_RESPONSE_LOGGING_CONFIG,
887+
)
888+
create_endpoint_mock.assert_called_once_with(
889+
parent=_TEST_PARENT,
890+
endpoint=expected_endpoint,
891+
endpoint_id=None,
892+
metadata=(),
893+
timeout=None,
894+
)
895+
856896
@pytest.mark.usefixtures("get_endpoint_mock")
857897
def test_update_endpoint(self, update_endpoint_mock):
858898
endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)

0 commit comments

Comments
 (0)