Skip to content

Commit bb228ce

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add notification_channels field to model monitoring alert config.
PiperOrigin-RevId: 593812234
1 parent 9a8e1ca commit bb228ce

File tree

4 files changed

+114
-33
lines changed

4 files changed

+114
-33
lines changed

google/cloud/aiplatform/model_monitoring/__init__.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,23 @@
1515
# limitations under the License.
1616
#
1717

18-
from google.cloud.aiplatform.model_monitoring.alert import EmailAlertConfig
18+
from google.cloud.aiplatform.model_monitoring.alert import (
19+
AlertConfig,
20+
EmailAlertConfig,
21+
)
1922
from google.cloud.aiplatform.model_monitoring.objective import (
2023
SkewDetectionConfig,
2124
DriftDetectionConfig,
2225
ExplanationConfig,
2326
ObjectiveConfig,
2427
)
25-
from google.cloud.aiplatform.model_monitoring.sampling import RandomSampleConfig
28+
from google.cloud.aiplatform.model_monitoring.sampling import (
29+
RandomSampleConfig,
30+
)
2631
from google.cloud.aiplatform.model_monitoring.schedule import ScheduleConfig
2732

2833
__all__ = (
34+
"AlertConfig",
2935
"EmailAlertConfig",
3036
"SkewDetectionConfig",
3137
"DriftDetectionConfig",

google/cloud/aiplatform/model_monitoring/alert.py

+47-23
Original file line numberDiff line numberDiff line change
@@ -15,56 +15,80 @@
1515
# limitations under the License.
1616
#
1717

18-
from typing import Optional, List
18+
from typing import List, Optional
1919
from google.cloud.aiplatform_v1.types import (
2020
model_monitoring as gca_model_monitoring_v1,
2121
)
2222

23-
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
23+
# TODO(b/242108750): remove temporary logic once model monitoring for
24+
# batch prediction is GA.
2425
from google.cloud.aiplatform_v1beta1.types import (
2526
model_monitoring as gca_model_monitoring_v1beta1,
2627
)
2728

2829
gca_model_monitoring = gca_model_monitoring_v1
2930

3031

31-
class EmailAlertConfig:
32+
class AlertConfig:
3233
def __init__(
33-
self, user_emails: List[str] = [], enable_logging: Optional[bool] = False
34+
self,
35+
user_emails: List[str] = [],
36+
enable_logging: Optional[bool] = False,
37+
notification_channels: List[str] = [],
3438
):
35-
"""Initializer for EmailAlertConfig.
39+
"""Initializer for AlertConfig.
3640
3741
Args:
38-
user_emails (List[str]):
39-
The email addresses to send the alert to.
40-
enable_logging (bool):
41-
Optional. Defaults to False. Streams detected anomalies to Cloud Logging. The anomalies will be
42-
put into json payload encoded from proto
43-
[google.cloud.aiplatform.logging.ModelMonitoringAnomaliesLogEntry][].
44-
This can be further sync'd to Pub/Sub or any other services
45-
supported by Cloud Logging.
42+
user_emails (List[str]): The email addresses to send the alert to.
43+
enable_logging (bool): Optional. Defaults to False. Streams detected
44+
anomalies to Cloud Logging. The anomalies will be put into json
45+
payload encoded from proto
46+
[google.cloud.aiplatform.logging.ModelMonitoringAnomaliesLogEntry][].
47+
This can be further sync'd to Pub/Sub or any other services supported
48+
by Cloud Logging.
49+
notification_channels (List[str]): The Cloud notification channels to
50+
send the alert to.
4651
"""
47-
self.enable_logging = enable_logging
4852
self.user_emails = user_emails
53+
self.enable_logging = enable_logging
54+
self.notification_channels = notification_channels
4955
self._config_for_bp = False
5056

51-
# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
5257
def as_proto(self) -> gca_model_monitoring.ModelMonitoringAlertConfig:
53-
"""Converts EmailAlertConfig to a proto message.
58+
"""Converts AlertConfig to a proto message.
5459
5560
Returns:
56-
The GAPIC representation of the email alert config.
61+
The GAPIC representation of the alert config.
5762
"""
63+
# TODO(b/242108750): remove temporary logic once model monitoring for
64+
# batch prediction is GA.
5865
if self._config_for_bp:
5966
gca_model_monitoring = gca_model_monitoring_v1beta1
6067
else:
6168
gca_model_monitoring = gca_model_monitoring_v1
62-
user_email_alert_config = (
63-
gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
64-
user_emails=self.user_emails
65-
)
66-
)
69+
6770
return gca_model_monitoring.ModelMonitoringAlertConfig(
68-
email_alert_config=user_email_alert_config,
71+
email_alert_config=gca_model_monitoring.ModelMonitoringAlertConfig.EmailAlertConfig(
72+
user_emails=self.user_emails
73+
),
6974
enable_logging=self.enable_logging,
75+
notification_channels=self.notification_channels,
7076
)
77+
78+
79+
class EmailAlertConfig(AlertConfig):
80+
def __init__(
81+
self, user_emails: List[str] = [], enable_logging: Optional[bool] = False
82+
):
83+
"""Initializer for EmailAlertConfig.
84+
85+
Args:
86+
user_emails (List[str]): The email addresses to send the alert to.
87+
enable_logging (bool): Optional. Defaults to False. Streams detected
88+
anomalies to Cloud Logging. The anomalies will be put into json
89+
payload encoded from proto
90+
[google.cloud.aiplatform.logging.ModelMonitoringAnomaliesLogEntry][].
91+
This can be further sync'd to Pub/Sub or any other services supported
92+
by Cloud Logging.
93+
"""
94+
super().__init__(user_emails=user_emails, enable_logging=enable_logging)

tests/system/aiplatform/test_model_monitoring.py

+42-7
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
# constants used for testing
3333
USER_EMAIL = "[email protected]"
34+
NOTIFICATION_CHANNEL = "projects/123/notificationChannels/456"
3435
PERMANENT_CHURN_MODEL_ID = "5295507484113371136"
3536
CHURN_MODEL_PATH = "gs://mco-mm/churn"
3637
DEFAULT_INPUT = {
@@ -90,10 +91,16 @@
9091
# global test constants
9192
sampling_strategy = model_monitoring.RandomSampleConfig(sample_rate=LOG_SAMPLE_RATE)
9293

93-
alert_config = model_monitoring.EmailAlertConfig(
94+
email_alert_config = model_monitoring.EmailAlertConfig(
9495
user_emails=[USER_EMAIL], enable_logging=True
9596
)
9697

98+
alert_config = model_monitoring.AlertConfig(
99+
user_emails=[USER_EMAIL],
100+
enable_logging=True,
101+
notification_channels=[NOTIFICATION_CHANNEL],
102+
)
103+
97104
schedule_config = model_monitoring.ScheduleConfig(monitor_interval=MONITOR_INTERVAL)
98105

99106
skew_config = model_monitoring.SkewDetectionConfig(
@@ -149,7 +156,7 @@ def test_mdm_two_models_one_valid_config(self, shared_state):
149156
display_name=self._make_display_name(key=JOB_NAME),
150157
logging_sampling_strategy=sampling_strategy,
151158
schedule_config=schedule_config,
152-
alert_config=alert_config,
159+
alert_config=email_alert_config,
153160
objective_configs=objective_config,
154161
create_request_timeout=3600,
155162
project=e2e_base._PROJECT,
@@ -211,7 +218,7 @@ def test_mdm_pause_and_update_config(self, shared_state):
211218
display_name=self._make_display_name(key=JOB_NAME),
212219
logging_sampling_strategy=sampling_strategy,
213220
schedule_config=schedule_config,
214-
alert_config=alert_config,
221+
alert_config=email_alert_config,
215222
objective_configs=model_monitoring.ObjectiveConfig(
216223
drift_detection_config=drift_config
217224
),
@@ -284,7 +291,7 @@ def test_mdm_two_models_two_valid_configs(self, shared_state):
284291
display_name=self._make_display_name(key=JOB_NAME),
285292
logging_sampling_strategy=sampling_strategy,
286293
schedule_config=schedule_config,
287-
alert_config=alert_config,
294+
alert_config=email_alert_config,
288295
objective_configs=all_configs,
289296
create_request_timeout=3600,
290297
project=e2e_base._PROJECT,
@@ -338,7 +345,7 @@ def test_mdm_invalid_config_incorrect_model_id(self, shared_state):
338345
display_name=self._make_display_name(key=JOB_NAME),
339346
logging_sampling_strategy=sampling_strategy,
340347
schedule_config=schedule_config,
341-
alert_config=alert_config,
348+
alert_config=email_alert_config,
342349
objective_configs=objective_config,
343350
create_request_timeout=3600,
344351
project=e2e_base._PROJECT,
@@ -358,7 +365,7 @@ def test_mdm_invalid_config_xai(self, shared_state):
358365
display_name=self._make_display_name(key=JOB_NAME),
359366
logging_sampling_strategy=sampling_strategy,
360367
schedule_config=schedule_config,
361-
alert_config=alert_config,
368+
alert_config=email_alert_config,
362369
objective_configs=objective_config,
363370
create_request_timeout=3600,
364371
project=e2e_base._PROJECT,
@@ -388,7 +395,7 @@ def test_mdm_two_models_invalid_configs_xai(self, shared_state):
388395
display_name=self._make_display_name(key=JOB_NAME),
389396
logging_sampling_strategy=sampling_strategy,
390397
schedule_config=schedule_config,
391-
alert_config=alert_config,
398+
alert_config=email_alert_config,
392399
objective_configs=all_configs,
393400
create_request_timeout=3600,
394401
project=e2e_base._PROJECT,
@@ -399,3 +406,31 @@ def test_mdm_two_models_invalid_configs_xai(self, shared_state):
399406
"`explanation_config` should only be enabled if the model has `explanation_spec populated"
400407
in str(e.value)
401408
)
409+
410+
def test_mdm_notification_channel_alert_config(self, shared_state):
411+
self.endpoint = shared_state["resources"][0]
412+
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
413+
# test model monitoring configurations
414+
job = aiplatform.ModelDeploymentMonitoringJob.create(
415+
display_name=self._make_display_name(key=JOB_NAME),
416+
logging_sampling_strategy=sampling_strategy,
417+
schedule_config=schedule_config,
418+
alert_config=alert_config,
419+
objective_configs=objective_config,
420+
create_request_timeout=3600,
421+
project=e2e_base._PROJECT,
422+
location=e2e_base._LOCATION,
423+
endpoint=self.endpoint,
424+
)
425+
426+
gapic_job = job._gca_resource
427+
assert (
428+
gapic_job.model_monitoring_alert_config.email_alert_config.user_emails
429+
== [USER_EMAIL]
430+
)
431+
assert gapic_job.model_monitoring_alert_config.enable_logging
432+
assert gapic_job.model_monitoring_alert_config.notification_channels == [
433+
NOTIFICATION_CHANNEL
434+
]
435+
436+
job.delete()

tests/unit/aiplatform/test_model_monitoring.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
_TEST_DRIFT_TRESHOLD = {"key": 0.2}
3232
_TEST_EMAIL1 = "test1"
3333
_TEST_EMAIL2 = "test2"
34+
_TEST_NOTIFICATION_CHANNEL = "projects/123/notificationChannels/456"
3435
_TEST_VALID_DATA_FORMATS = ["tf-record", "csv", "jsonl"]
3536
_TEST_SAMPLING_RATE = 0.8
3637
_TEST_MONITORING_INTERVAL = 1
@@ -105,10 +106,16 @@ def test_valid_configs(
105106
monitor_interval=_TEST_MONITORING_INTERVAL
106107
)
107108

108-
alert_config = model_monitoring.EmailAlertConfig(
109+
email_alert_config = model_monitoring.EmailAlertConfig(
109110
user_emails=[_TEST_EMAIL1, _TEST_EMAIL2]
110111
)
111112

113+
alert_config = model_monitoring.AlertConfig(
114+
user_emails=[_TEST_EMAIL1, _TEST_EMAIL2],
115+
enable_logging=True,
116+
notification_channels=[_TEST_NOTIFICATION_CHANNEL],
117+
)
118+
112119
prediction_drift_config = model_monitoring.DriftDetectionConfig(
113120
drift_thresholds=_TEST_DRIFT_TRESHOLD
114121
)
@@ -149,8 +156,17 @@ def test_valid_configs(
149156
== prediction_drift_config.as_proto()
150157
)
151158
assert objective_config.as_proto().explanation_config == xai_config.as_proto()
159+
assert (
160+
_TEST_EMAIL1 in email_alert_config.as_proto().email_alert_config.user_emails
161+
)
162+
assert (
163+
_TEST_EMAIL2 in email_alert_config.as_proto().email_alert_config.user_emails
164+
)
152165
assert _TEST_EMAIL1 in alert_config.as_proto().email_alert_config.user_emails
153166
assert _TEST_EMAIL2 in alert_config.as_proto().email_alert_config.user_emails
167+
assert (
168+
_TEST_NOTIFICATION_CHANNEL in alert_config.as_proto().notification_channels
169+
)
154170
assert (
155171
random_sample_config.as_proto().random_sample_config.sample_rate
156172
== _TEST_SAMPLING_RATE

0 commit comments

Comments
 (0)