Skip to content

Commit 7da4164

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add default skew threshold to be an optional input at _SkewDetectionConfig and also mark the target_field and data_source of skew config to optional.
PiperOrigin-RevId: 496543878
1 parent c23a8bd commit 7da4164

File tree

2 files changed

+106
-35
lines changed

2 files changed

+106
-35
lines changed

google/cloud/aiplatform/model_monitoring/objective.py

+38-25
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616
#
1717

18-
from typing import Optional, Dict
18+
from typing import Optional, Dict, Union
1919

2020
from google.cloud.aiplatform_v1.types import (
2121
io as gca_io_v1,
@@ -39,27 +39,30 @@
3939
class _SkewDetectionConfig:
4040
def __init__(
4141
self,
42-
data_source: str,
43-
skew_thresholds: Dict[str, float],
44-
target_field: str,
45-
attribute_skew_thresholds: Dict[str, float],
42+
data_source: Optional[str] = None,
43+
skew_thresholds: Union[Dict[str, float], float, None] = None,
44+
target_field: Optional[str] = None,
45+
attribute_skew_thresholds: Optional[Dict[str, float]] = None,
4646
data_format: Optional[str] = None,
4747
):
4848
"""Base class for training-serving skew detection.
4949
Args:
5050
data_source (str):
51-
Required. Path to training dataset.
51+
Optional. Path to training dataset.
5252
53-
skew_thresholds (Dict[str, float]):
53+
skew_thresholds: Union[Dict[str, float], float, None]:
5454
Optional. Key is the feature name and value is the
5555
threshold. If a feature needs to be monitored
5656
for skew, a value threshold must be configured
5757
for that feature. The threshold here is against
5858
feature distribution distance between the
59-
training and prediction feature.
59+
training and prediction feature. If a float is passed,
60+
then all features will be monitored using the same
61+
threshold. If None is passed, all feature will be monitored
62+
using alert threshold 0.3 (Backend default).
6063
6164
target_field (str):
62-
Required. The target field name the model is to
65+
Optional. The target field name the model is to
6366
predict. This field will be excluded when doing
6467
Predict and (or) Explain for the training data.
6568
@@ -93,12 +96,18 @@ def as_proto(self):
9396
"""Returns _SkewDetectionConfig as a proto message."""
9497
skew_thresholds_mapping = {}
9598
attribution_score_skew_thresholds_mapping = {}
99+
default_skew_threshold = None
96100
if self.skew_thresholds is not None:
97-
for key in self.skew_thresholds.keys():
98-
skew_threshold = gca_model_monitoring.ThresholdConfig(
99-
value=self.skew_thresholds[key]
101+
if isinstance(self.skew_thresholds, float):
102+
default_skew_threshold = gca_model_monitoring.ThresholdConfig(
103+
value=self.skew_thresholds
100104
)
101-
skew_thresholds_mapping[key] = skew_threshold
105+
else:
106+
for key in self.skew_thresholds.keys():
107+
skew_threshold = gca_model_monitoring.ThresholdConfig(
108+
value=self.skew_thresholds[key]
109+
)
110+
skew_thresholds_mapping[key] = skew_threshold
102111
if self.attribute_skew_thresholds is not None:
103112
for key in self.attribute_skew_thresholds.keys():
104113
attribution_score_skew_threshold = gca_model_monitoring.ThresholdConfig(
@@ -110,6 +119,7 @@ def as_proto(self):
110119
return gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig(
111120
skew_thresholds=skew_thresholds_mapping,
112121
attribution_score_skew_thresholds=attribution_score_skew_thresholds_mapping,
122+
default_skew_threshold=default_skew_threshold,
113123
)
114124

115125

@@ -266,30 +276,33 @@ class SkewDetectionConfig(_SkewDetectionConfig):
266276

267277
def __init__(
268278
self,
269-
data_source: str,
270-
target_field: str,
271-
skew_thresholds: Optional[Dict[str, float]] = None,
279+
data_source: Optional[str] = None,
280+
target_field: Optional[str] = None,
281+
skew_thresholds: Union[Dict[str, float], float, None] = None,
272282
attribute_skew_thresholds: Optional[Dict[str, float]] = None,
273283
data_format: Optional[str] = None,
274284
):
275285
"""Initializer for SkewDetectionConfig.
276286
277287
Args:
278288
data_source (str):
279-
Required. Path to training dataset.
289+
Optional. Path to training dataset.
280290
281291
target_field (str):
282-
Required. The target field name the model is to
292+
Optional. The target field name the model is to
283293
predict. This field will be excluded when doing
284294
Predict and (or) Explain for the training data.
285295
286-
skew_thresholds (Dict[str, float]):
296+
skew_thresholds: Union[Dict[str, float], float, None]:
287297
Optional. Key is the feature name and value is the
288298
threshold. If a feature needs to be monitored
289299
for skew, a value threshold must be configured
290300
for that feature. The threshold here is against
291301
feature distribution distance between the
292-
training and prediction feature.
302+
training and prediction feature. If a float is passed,
303+
then all features will be monitored using the same
304+
threshold. If None is passed, all feature will be monitored
305+
using alert threshold 0.3 (Backend default).
293306
294307
attribute_skew_thresholds (Dict[str, float]):
295308
Optional. Key is the feature name and value is the
@@ -315,11 +328,11 @@ def __init__(
315328
ValueError for unsupported data formats.
316329
"""
317330
super().__init__(
318-
data_source,
319-
skew_thresholds,
320-
target_field,
321-
attribute_skew_thresholds,
322-
data_format,
331+
data_source=data_source,
332+
skew_thresholds=skew_thresholds,
333+
target_field=target_field,
334+
attribute_skew_thresholds=attribute_skew_thresholds,
335+
data_format=data_format,
323336
)
324337

325338

tests/unit/aiplatform/test_model_monitoring.py

+68-10
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,79 @@
2424
model_monitoring as gca_model_monitoring,
2525
)
2626

27-
_TEST_THRESHOLD = 0.1
2827
_TEST_TARGET_FIELD = "target"
2928
_TEST_BQ_DATASOURCE = "bq://test/data"
3029
_TEST_GCS_DATASOURCE = "gs://test/data"
3130
_TEST_OTHER_DATASOURCE = ""
32-
_TEST_KEY = "key"
31+
_TEST_DRIFT_TRESHOLD = {"key": 0.2}
3332
_TEST_EMAIL1 = "test1"
3433
_TEST_EMAIL2 = "test2"
3534
_TEST_VALID_DATA_FORMATS = ["tf-record", "csv", "jsonl"]
3635
_TEST_SAMPLING_RATE = 0.8
3736
_TEST_MONITORING_INTERVAL = 1
37+
_TEST_SKEW_THRESHOLDS = [None, 0.2, {"key": 0.1}]
38+
_TEST_ATTRIBUTE_SKEW_THRESHOLDS = [None, {"key": 0.1}]
3839

3940

4041
class TestModelMonitoringConfigs:
42+
"""Tests for model monitoring configs."""
43+
4144
@pytest.mark.parametrize(
4245
"data_source",
4346
[_TEST_BQ_DATASOURCE, _TEST_GCS_DATASOURCE, _TEST_OTHER_DATASOURCE],
4447
)
4548
@pytest.mark.parametrize("data_format", _TEST_VALID_DATA_FORMATS)
46-
def test_valid_configs(self, data_source, data_format):
49+
@pytest.mark.parametrize("skew_thresholds", _TEST_SKEW_THRESHOLDS)
50+
def test_skew_config_proto_value(self, data_source, data_format, skew_thresholds):
51+
"""Tests if skew config can be constrctued properly to gapic proto."""
52+
attribute_skew_thresholds = {"key": 0.1}
53+
skew_config = model_monitoring.SkewDetectionConfig(
54+
data_source=data_source,
55+
skew_thresholds=skew_thresholds,
56+
target_field=_TEST_TARGET_FIELD,
57+
attribute_skew_thresholds=attribute_skew_thresholds,
58+
data_format=data_format,
59+
)
60+
# data_format and data source are not used at
61+
# TrainingPredictionSkewDetectionConfig.
62+
if isinstance(skew_thresholds, dict):
63+
expected_gapic_proto = gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig(
64+
skew_thresholds={
65+
key: gca_model_monitoring.ThresholdConfig(value=val)
66+
for key, val in skew_thresholds.items()
67+
},
68+
attribution_score_skew_thresholds={
69+
key: gca_model_monitoring.ThresholdConfig(value=val)
70+
for key, val in attribute_skew_thresholds.items()
71+
},
72+
)
73+
else:
74+
expected_gapic_proto = gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingPredictionSkewDetectionConfig(
75+
default_skew_threshold=gca_model_monitoring.ThresholdConfig(
76+
value=skew_thresholds
77+
)
78+
if skew_thresholds is not None
79+
else None,
80+
attribution_score_skew_thresholds={
81+
key: gca_model_monitoring.ThresholdConfig(value=val)
82+
for key, val in attribute_skew_thresholds.items()
83+
},
84+
)
85+
assert skew_config.as_proto() == expected_gapic_proto
86+
87+
@pytest.mark.parametrize(
88+
"data_source",
89+
[_TEST_BQ_DATASOURCE, _TEST_GCS_DATASOURCE, _TEST_OTHER_DATASOURCE],
90+
)
91+
@pytest.mark.parametrize("data_format", _TEST_VALID_DATA_FORMATS)
92+
@pytest.mark.parametrize("skew_thresholds", _TEST_SKEW_THRESHOLDS)
93+
@pytest.mark.parametrize(
94+
"attribute_skew_thresholds", _TEST_ATTRIBUTE_SKEW_THRESHOLDS
95+
)
96+
def test_valid_configs(
97+
self, data_source, data_format, skew_thresholds, attribute_skew_thresholds
98+
):
99+
"""Test config creation validity."""
47100
random_sample_config = model_monitoring.RandomSampleConfig(
48101
sample_rate=_TEST_SAMPLING_RATE
49102
)
@@ -57,17 +110,16 @@ def test_valid_configs(self, data_source, data_format):
57110
)
58111

59112
prediction_drift_config = model_monitoring.DriftDetectionConfig(
60-
drift_thresholds={_TEST_KEY: _TEST_THRESHOLD}
113+
drift_thresholds=_TEST_DRIFT_TRESHOLD
61114
)
62115

63116
skew_config = model_monitoring.SkewDetectionConfig(
64117
data_source=data_source,
65-
skew_thresholds={_TEST_KEY: _TEST_THRESHOLD},
118+
skew_thresholds=skew_thresholds,
66119
target_field=_TEST_TARGET_FIELD,
67-
attribute_skew_thresholds={_TEST_KEY: _TEST_THRESHOLD},
120+
attribute_skew_thresholds=attribute_skew_thresholds,
68121
data_format=data_format,
69122
)
70-
71123
expected_training_dataset = (
72124
gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset(
73125
bigquery_source=gca_io.BigQuerySource(input_uri=_TEST_BQ_DATASOURCE),
@@ -110,15 +162,21 @@ def test_valid_configs(self, data_source, data_format):
110162

111163
@pytest.mark.parametrize("data_source", [_TEST_GCS_DATASOURCE])
112164
@pytest.mark.parametrize("data_format", ["other"])
113-
def test_invalid_data_format(self, data_source, data_format):
165+
@pytest.mark.parametrize("skew_thresholds", _TEST_SKEW_THRESHOLDS)
166+
@pytest.mark.parametrize(
167+
"attribute_skew_thresholds", _TEST_ATTRIBUTE_SKEW_THRESHOLDS
168+
)
169+
def test_invalid_data_format(
170+
self, data_source, data_format, skew_thresholds, attribute_skew_thresholds
171+
):
114172
if data_format == "other":
115173
with pytest.raises(ValueError) as e:
116174
model_monitoring.ObjectiveConfig(
117175
skew_detection_config=model_monitoring.SkewDetectionConfig(
118176
data_source=data_source,
119-
skew_thresholds={_TEST_KEY: _TEST_THRESHOLD},
177+
skew_thresholds=skew_thresholds,
120178
target_field=_TEST_TARGET_FIELD,
121-
attribute_skew_thresholds={_TEST_KEY: _TEST_THRESHOLD},
179+
attribute_skew_thresholds=attribute_skew_thresholds,
122180
data_format=data_format,
123181
)
124182
).as_proto()

0 commit comments

Comments
 (0)