15
15
# limitations under the License.
16
16
#
17
17
18
- from typing import Optional , Dict
18
+ from typing import Optional , Dict , Union
19
19
20
20
from google .cloud .aiplatform_v1 .types import (
21
21
io as gca_io_v1 ,
39
39
class _SkewDetectionConfig :
40
40
def __init__ (
41
41
self ,
42
- data_source : str ,
43
- skew_thresholds : Dict [str , float ],
44
- target_field : str ,
45
- attribute_skew_thresholds : Dict [str , float ],
42
+ data_source : Optional [ str ] = None ,
43
+ skew_thresholds : Union [ Dict [str , float ], float , None ] = None ,
44
+ target_field : Optional [ str ] = None ,
45
+ attribute_skew_thresholds : Optional [ Dict [str , float ]] = None ,
46
46
data_format : Optional [str ] = None ,
47
47
):
48
48
"""Base class for training-serving skew detection.
49
49
Args:
50
50
data_source (str):
51
- Required . Path to training dataset.
51
+ Optional . Path to training dataset.
52
52
53
- skew_thresholds ( Dict[str, float]) :
53
+ skew_thresholds: Union[ Dict[str, float], float, None] :
54
54
Optional. Key is the feature name and value is the
55
55
threshold. If a feature needs to be monitored
56
56
for skew, a value threshold must be configured
57
57
for that feature. The threshold here is against
58
58
feature distribution distance between the
59
- training and prediction feature.
59
+ training and prediction feature. If a float is passed,
60
+ then all features will be monitored using the same
61
+ threshold. If None is passed, all feature will be monitored
62
+ using alert threshold 0.3 (Backend default).
60
63
61
64
target_field (str):
62
- Required . The target field name the model is to
65
+ Optional . The target field name the model is to
63
66
predict. This field will be excluded when doing
64
67
Predict and (or) Explain for the training data.
65
68
@@ -93,12 +96,18 @@ def as_proto(self):
93
96
"""Returns _SkewDetectionConfig as a proto message."""
94
97
skew_thresholds_mapping = {}
95
98
attribution_score_skew_thresholds_mapping = {}
99
+ default_skew_threshold = None
96
100
if self .skew_thresholds is not None :
97
- for key in self .skew_thresholds . keys ( ):
98
- skew_threshold = gca_model_monitoring .ThresholdConfig (
99
- value = self .skew_thresholds [ key ]
101
+ if isinstance ( self .skew_thresholds , float ):
102
+ default_skew_threshold = gca_model_monitoring .ThresholdConfig (
103
+ value = self .skew_thresholds
100
104
)
101
- skew_thresholds_mapping [key ] = skew_threshold
105
+ else :
106
+ for key in self .skew_thresholds .keys ():
107
+ skew_threshold = gca_model_monitoring .ThresholdConfig (
108
+ value = self .skew_thresholds [key ]
109
+ )
110
+ skew_thresholds_mapping [key ] = skew_threshold
102
111
if self .attribute_skew_thresholds is not None :
103
112
for key in self .attribute_skew_thresholds .keys ():
104
113
attribution_score_skew_threshold = gca_model_monitoring .ThresholdConfig (
@@ -110,6 +119,7 @@ def as_proto(self):
110
119
return gca_model_monitoring .ModelMonitoringObjectiveConfig .TrainingPredictionSkewDetectionConfig (
111
120
skew_thresholds = skew_thresholds_mapping ,
112
121
attribution_score_skew_thresholds = attribution_score_skew_thresholds_mapping ,
122
+ default_skew_threshold = default_skew_threshold ,
113
123
)
114
124
115
125
@@ -266,30 +276,33 @@ class SkewDetectionConfig(_SkewDetectionConfig):
266
276
267
277
def __init__ (
268
278
self ,
269
- data_source : str ,
270
- target_field : str ,
271
- skew_thresholds : Optional [Dict [str , float ]] = None ,
279
+ data_source : Optional [ str ] = None ,
280
+ target_field : Optional [ str ] = None ,
281
+ skew_thresholds : Union [Dict [str , float ], float , None ] = None ,
272
282
attribute_skew_thresholds : Optional [Dict [str , float ]] = None ,
273
283
data_format : Optional [str ] = None ,
274
284
):
275
285
"""Initializer for SkewDetectionConfig.
276
286
277
287
Args:
278
288
data_source (str):
279
- Required . Path to training dataset.
289
+ Optional . Path to training dataset.
280
290
281
291
target_field (str):
282
- Required . The target field name the model is to
292
+ Optional . The target field name the model is to
283
293
predict. This field will be excluded when doing
284
294
Predict and (or) Explain for the training data.
285
295
286
- skew_thresholds ( Dict[str, float]) :
296
+ skew_thresholds: Union[ Dict[str, float], float, None] :
287
297
Optional. Key is the feature name and value is the
288
298
threshold. If a feature needs to be monitored
289
299
for skew, a value threshold must be configured
290
300
for that feature. The threshold here is against
291
301
feature distribution distance between the
292
- training and prediction feature.
302
+ training and prediction feature. If a float is passed,
303
+ then all features will be monitored using the same
304
+ threshold. If None is passed, all feature will be monitored
305
+ using alert threshold 0.3 (Backend default).
293
306
294
307
attribute_skew_thresholds (Dict[str, float]):
295
308
Optional. Key is the feature name and value is the
@@ -315,11 +328,11 @@ def __init__(
315
328
ValueError for unsupported data formats.
316
329
"""
317
330
super ().__init__ (
318
- data_source ,
319
- skew_thresholds ,
320
- target_field ,
321
- attribute_skew_thresholds ,
322
- data_format ,
331
+ data_source = data_source ,
332
+ skew_thresholds = skew_thresholds ,
333
+ target_field = target_field ,
334
+ attribute_skew_thresholds = attribute_skew_thresholds ,
335
+ data_format = data_format ,
323
336
)
324
337
325
338
0 commit comments