|
18 | 18 | from typing import Optional, Dict
|
19 | 19 |
|
20 | 20 | from google.cloud.aiplatform_v1.types import (
|
21 |
| - io as gca_io, |
22 |
| - ThresholdConfig as gca_threshold_config, |
23 |
| - model_monitoring as gca_model_monitoring, |
| 21 | + io as gca_io_v1, |
| 22 | + model_monitoring as gca_model_monitoring_v1, |
24 | 23 | )
|
25 | 24 |
|
| 25 | +# TODO: b/242108750 |
| 26 | +from google.cloud.aiplatform_v1beta1.types import ( |
| 27 | + io as gca_io_v1beta1, |
| 28 | + model_monitoring as gca_model_monitoring_v1beta1, |
| 29 | +) |
| 30 | + |
| 31 | +gca_model_monitoring = gca_model_monitoring_v1 |
| 32 | +gca_io = gca_io_v1 |
| 33 | + |
26 | 34 | TF_RECORD = "tf-record"
|
27 | 35 | CSV = "csv"
|
28 | 36 | JSONL = "jsonl"
|
@@ -80,19 +88,20 @@ def __init__(
|
80 | 88 | self.attribute_skew_thresholds = attribute_skew_thresholds
|
81 | 89 | self.data_format = data_format
|
82 | 90 | self.target_field = target_field
|
83 |
| - self.training_dataset = None |
84 | 91 |
|
85 | 92 | def as_proto(self):
|
86 | 93 | """Returns _SkewDetectionConfig as a proto message."""
|
87 | 94 | skew_thresholds_mapping = {}
|
88 | 95 | attribution_score_skew_thresholds_mapping = {}
|
89 | 96 | if self.skew_thresholds is not None:
|
90 | 97 | for key in self.skew_thresholds.keys():
|
91 |
| - skew_threshold = gca_threshold_config(value=self.skew_thresholds[key]) |
| 98 | + skew_threshold = gca_model_monitoring.ThresholdConfig( |
| 99 | + value=self.skew_thresholds[key] |
| 100 | + ) |
92 | 101 | skew_thresholds_mapping[key] = skew_threshold
|
93 | 102 | if self.attribute_skew_thresholds is not None:
|
94 | 103 | for key in self.attribute_skew_thresholds.keys():
|
95 |
| - attribution_score_skew_threshold = gca_threshold_config( |
| 104 | + attribution_score_skew_threshold = gca_model_monitoring.ThresholdConfig( |
96 | 105 | value=self.attribute_skew_thresholds[key]
|
97 | 106 | )
|
98 | 107 | attribution_score_skew_thresholds_mapping[
|
@@ -134,12 +143,16 @@ def as_proto(self):
|
134 | 143 | attribution_score_drift_thresholds_mapping = {}
|
135 | 144 | if self.drift_thresholds is not None:
|
136 | 145 | for key in self.drift_thresholds.keys():
|
137 |
| - drift_threshold = gca_threshold_config(value=self.drift_thresholds[key]) |
| 146 | + drift_threshold = gca_model_monitoring.ThresholdConfig( |
| 147 | + value=self.drift_thresholds[key] |
| 148 | + ) |
138 | 149 | drift_thresholds_mapping[key] = drift_threshold
|
139 | 150 | if self.attribute_drift_thresholds is not None:
|
140 | 151 | for key in self.attribute_drift_thresholds.keys():
|
141 |
| - attribution_score_drift_threshold = gca_threshold_config( |
142 |
| - value=self.attribute_drift_thresholds[key] |
| 152 | + attribution_score_drift_threshold = ( |
| 153 | + gca_model_monitoring.ThresholdConfig( |
| 154 | + value=self.attribute_drift_thresholds[key] |
| 155 | + ) |
143 | 156 | )
|
144 | 157 | attribution_score_drift_thresholds_mapping[
|
145 | 158 | key
|
@@ -186,11 +199,49 @@ def __init__(
|
186 | 199 | self.drift_detection_config = drift_detection_config
|
187 | 200 | self.explanation_config = explanation_config
|
188 | 201 |
|
189 |
| - def as_proto(self): |
190 |
| - """Returns _ObjectiveConfig as a proto message.""" |
| 202 | + # TODO: b/242108750 |
| 203 | + def as_proto(self, config_for_bp: bool = False): |
| 204 | + """Returns _SkewDetectionConfig as a proto message. |
| 205 | +
|
| 206 | + Args: |
| 207 | + config_for_bp (bool): |
| 208 | + Optional. Set this parameter to True if the config object |
| 209 | + is used for model monitoring on a batch prediction job. |
| 210 | + """ |
| 211 | + if config_for_bp: |
| 212 | + gca_io = gca_io_v1beta1 |
| 213 | + gca_model_monitoring = gca_model_monitoring_v1beta1 |
| 214 | + else: |
| 215 | + gca_io = gca_io_v1 |
| 216 | + gca_model_monitoring = gca_model_monitoring_v1 |
191 | 217 | training_dataset = None
|
192 | 218 | if self.skew_detection_config is not None:
|
193 |
| - training_dataset = self.skew_detection_config.training_dataset |
| 219 | + training_dataset = ( |
| 220 | + gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset( |
| 221 | + target_field=self.skew_detection_config.target_field |
| 222 | + ) |
| 223 | + ) |
| 224 | + if self.skew_detection_config.data_source.startswith("bq:/"): |
| 225 | + training_dataset.bigquery_source = gca_io.BigQuerySource( |
| 226 | + input_uri=self.skew_detection_config.data_source |
| 227 | + ) |
| 228 | + elif self.skew_detection_config.data_source.startswith("gs:/"): |
| 229 | + training_dataset.gcs_source = gca_io.GcsSource( |
| 230 | + uris=[self.skew_detection_config.data_source] |
| 231 | + ) |
| 232 | + if ( |
| 233 | + self.skew_detection_config.data_format is not None |
| 234 | + and self.skew_detection_config.data_format |
| 235 | + not in [TF_RECORD, CSV, JSONL] |
| 236 | + ): |
| 237 | + raise ValueError( |
| 238 | + "Unsupported value in skew detection config. `data_format` must be one of %s, %s, or %s" |
| 239 | + % (TF_RECORD, CSV, JSONL) |
| 240 | + ) |
| 241 | + training_dataset.data_format = self.skew_detection_config.data_format |
| 242 | + else: |
| 243 | + training_dataset.dataset = self.skew_detection_config.data_source |
| 244 | + |
194 | 245 | return gca_model_monitoring.ModelMonitoringObjectiveConfig(
|
195 | 246 | training_dataset=training_dataset,
|
196 | 247 | training_prediction_skew_detection_config=self.skew_detection_config.as_proto()
|
@@ -271,27 +322,6 @@ def __init__(
|
271 | 322 | data_format,
|
272 | 323 | )
|
273 | 324 |
|
274 |
| - training_dataset = ( |
275 |
| - gca_model_monitoring.ModelMonitoringObjectiveConfig.TrainingDataset( |
276 |
| - target_field=target_field |
277 |
| - ) |
278 |
| - ) |
279 |
| - if data_source.startswith("bq:/"): |
280 |
| - training_dataset.bigquery_source = gca_io.BigQuerySource( |
281 |
| - input_uri=data_source |
282 |
| - ) |
283 |
| - elif data_source.startswith("gs:/"): |
284 |
| - training_dataset.gcs_source = gca_io.GcsSource(uris=[data_source]) |
285 |
| - if data_format is not None and data_format not in [TF_RECORD, CSV, JSONL]: |
286 |
| - raise ValueError( |
287 |
| - "Unsupported value. `data_format` must be one of %s, %s, or %s" |
288 |
| - % (TF_RECORD, CSV, JSONL) |
289 |
| - ) |
290 |
| - training_dataset.data_format = data_format |
291 |
| - else: |
292 |
| - training_dataset.dataset = data_source |
293 |
| - self.training_dataset = training_dataset |
294 |
| - |
295 | 325 |
|
296 | 326 | class DriftDetectionConfig(_DriftDetectionConfig):
|
297 | 327 | """A class that configures prediction drift detection for models deployed to an endpoint.
|
|
0 commit comments