Skip to content

Commit 3526b3e

Browse files
Feat: Add google.ClassificationMetrics, google.RegressionMetrics, and google.Forecasting Metrics (#1549)
* Add google.ClassificationMetrics, google.RegressionMetrics, and google.ForecastingMetrics Artifact types to metadata schema with unit tests. * fix implicit false * Fix typo * Running nox -s blacken and nox -s lint * fix typo in unit test Co-authored-by: sina chavoshi <[email protected]>
1 parent caebb47 commit 3526b3e

File tree

2 files changed

+393
-0
lines changed

2 files changed

+393
-0
lines changed

google/cloud/aiplatform/metadata/schema/google/artifact_schema.py

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,261 @@ def __init__(
268268
metadata=extended_metadata,
269269
state=state,
270270
)
271+
272+
273+
class ClassificationMetrics(base_artifact.BaseArtifactSchema):
274+
"""A Google artifact representing evaluation Classification Metrics."""
275+
276+
schema_title = "google.ClassificationMetrics"
277+
278+
def __init__(
279+
self,
280+
*,
281+
au_prc: Optional[float] = None,
282+
au_roc: Optional[float] = None,
283+
log_loss: Optional[float] = None,
284+
artifact_id: Optional[str] = None,
285+
uri: Optional[str] = None,
286+
display_name: Optional[str] = None,
287+
schema_version: Optional[str] = None,
288+
description: Optional[str] = None,
289+
metadata: Optional[Dict] = None,
290+
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
291+
):
292+
"""Args:
293+
au_prc (float):
294+
Optional. The Area Under Precision-Recall Curve metric.
295+
Micro-averaged for the overall evaluation.
296+
au_roc (float):
297+
Optional. The Area Under Receiver Operating Characteristic curve metric.
298+
Micro-averaged for the overall evaluation.
299+
log_loss (float):
300+
Optional. The Log Loss metric.
301+
artifact_id (str):
302+
Optional. The <resource_id> portion of the Artifact name with
303+
the format. This is globally unique in a metadataStore:
304+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
305+
uri (str):
306+
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
307+
artifact file.
308+
display_name (str):
309+
Optional. The user-defined name of the Artifact.
310+
schema_version (str):
311+
Optional. schema_version specifies the version used by the Artifact.
312+
If not set, defaults to use the latest version.
313+
description (str):
314+
Optional. Describes the purpose of the Artifact to be created.
315+
metadata (Dict):
316+
Optional. Contains the metadata information that will be stored in the Artifact.
317+
state (google.cloud.gapic.types.Artifact.State):
318+
Optional. The state of this Artifact. This is a
319+
property of the Artifact, and does not imply or
320+
capture any ongoing process. This property is
321+
managed by clients (such as Vertex AI
322+
Pipelines), and the system does not prescribe or
323+
check the validity of state transitions.
324+
"""
325+
extended_metadata = copy.deepcopy(metadata) if metadata else {}
326+
if au_prc:
327+
extended_metadata["auPrc"] = au_prc
328+
if au_roc:
329+
extended_metadata["auRoc"] = au_roc
330+
if log_loss:
331+
extended_metadata["logLoss"] = log_loss
332+
333+
super(ClassificationMetrics, self).__init__(
334+
uri=uri,
335+
artifact_id=artifact_id,
336+
display_name=display_name,
337+
schema_version=schema_version,
338+
description=description,
339+
metadata=extended_metadata,
340+
state=state,
341+
)
342+
343+
344+
class RegressionMetrics(base_artifact.BaseArtifactSchema):
345+
"""A Google artifact representing evaluation Regression Metrics."""
346+
347+
schema_title = "google.RegressionMetrics"
348+
349+
def __init__(
350+
self,
351+
*,
352+
root_mean_squared_error: Optional[float] = None,
353+
mean_absolute_error: Optional[float] = None,
354+
mean_absolute_percentage_error: Optional[float] = None,
355+
r_squared: Optional[float] = None,
356+
root_mean_squared_log_error: Optional[float] = None,
357+
artifact_id: Optional[str] = None,
358+
uri: Optional[str] = None,
359+
display_name: Optional[str] = None,
360+
schema_version: Optional[str] = None,
361+
description: Optional[str] = None,
362+
metadata: Optional[Dict] = None,
363+
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
364+
):
365+
"""Args:
366+
root_mean_squared_error (float):
367+
Optional. Root Mean Squared Error (RMSE).
368+
mean_absolute_error (float):
369+
Optional. Mean Absolute Error (MAE).
370+
mean_absolute_percentage_error (float):
371+
Optional. Mean absolute percentage error.
372+
r_squared (float):
373+
Optional. Coefficient of determination as Pearson correlation coefficient.
374+
root_mean_squared_log_error (float):
375+
Optional. Root mean squared log error.
376+
artifact_id (str):
377+
Optional. The <resource_id> portion of the Artifact name with
378+
the format. This is globally unique in a metadataStore:
379+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
380+
uri (str):
381+
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
382+
artifact file.
383+
display_name (str):
384+
Optional. The user-defined name of the Artifact.
385+
schema_version (str):
386+
Optional. schema_version specifies the version used by the Artifact.
387+
If not set, defaults to use the latest version.
388+
description (str):
389+
Optional. Describes the purpose of the Artifact to be created.
390+
metadata (Dict):
391+
Optional. Contains the metadata information that will be stored in the Artifact.
392+
state (google.cloud.gapic.types.Artifact.State):
393+
Optional. The state of this Artifact. This is a
394+
property of the Artifact, and does not imply or
395+
capture any ongoing process. This property is
396+
managed by clients (such as Vertex AI
397+
Pipelines), and the system does not prescribe or
398+
check the validity of state transitions.
399+
"""
400+
extended_metadata = copy.deepcopy(metadata) if metadata else {}
401+
if root_mean_squared_error:
402+
extended_metadata["rootMeanSquaredError"] = root_mean_squared_error
403+
if mean_absolute_error:
404+
extended_metadata["meanAbsoluteError"] = mean_absolute_error
405+
if mean_absolute_percentage_error:
406+
extended_metadata[
407+
"meanAbsolutePercentageError"
408+
] = mean_absolute_percentage_error
409+
if r_squared:
410+
extended_metadata["rSquared"] = r_squared
411+
if root_mean_squared_log_error:
412+
extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error
413+
414+
super(RegressionMetrics, self).__init__(
415+
uri=uri,
416+
artifact_id=artifact_id,
417+
display_name=display_name,
418+
schema_version=schema_version,
419+
description=description,
420+
metadata=extended_metadata,
421+
state=state,
422+
)
423+
424+
425+
class ForecastingMetrics(base_artifact.BaseArtifactSchema):
426+
"""A Google artifact representing evaluation Forecasting Metrics."""
427+
428+
schema_title = "google.ForecastingMetrics"
429+
430+
def __init__(
431+
self,
432+
*,
433+
root_mean_squared_error: Optional[float] = None,
434+
mean_absolute_error: Optional[float] = None,
435+
mean_absolute_percentage_error: Optional[float] = None,
436+
r_squared: Optional[float] = None,
437+
root_mean_squared_log_error: Optional[float] = None,
438+
weighted_absolute_percentage_error: Optional[float] = None,
439+
root_mean_squared_percentage_error: Optional[float] = None,
440+
symmetric_mean_absolute_percentage_error: Optional[float] = None,
441+
artifact_id: Optional[str] = None,
442+
uri: Optional[str] = None,
443+
display_name: Optional[str] = None,
444+
schema_version: Optional[str] = None,
445+
description: Optional[str] = None,
446+
metadata: Optional[Dict] = None,
447+
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
448+
):
449+
"""Args:
450+
root_mean_squared_error (float):
451+
Optional. Root Mean Squared Error (RMSE).
452+
mean_absolute_error (float):
453+
Optional. Mean Absolute Error (MAE).
454+
mean_absolute_percentage_error (float):
455+
Optional. Mean absolute percentage error.
456+
r_squared (float):
457+
Optional. Coefficient of determination as Pearson correlation coefficient.
458+
root_mean_squared_log_error (float):
459+
Optional. Root mean squared log error.
460+
weighted_absolute_percentage_error (float):
461+
Optional. Weighted Absolute Percentage Error.
462+
Does not use weights, this is just what the metric is called.
463+
Undefined if actual values sum to zero.
464+
Will be very large if actual values sum to a very small number.
465+
root_mean_squared_percentage_error (float):
466+
Optional. Root Mean Square Percentage Error. Square root of MSPE.
467+
Undefined/imaginary when MSPE is negative.
468+
symmetric_mean_absolute_percentage_error (float):
469+
Optional. Symmetric Mean Absolute Percentage Error.
470+
artifact_id (str):
471+
Optional. The <resource_id> portion of the Artifact name with
472+
the format. This is globally unique in a metadataStore:
473+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
474+
uri (str):
475+
Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual
476+
artifact file.
477+
display_name (str):
478+
Optional. The user-defined name of the Artifact.
479+
schema_version (str):
480+
Optional. schema_version specifies the version used by the Artifact.
481+
If not set, defaults to use the latest version.
482+
description (str):
483+
Optional. Describes the purpose of the Artifact to be created.
484+
metadata (Dict):
485+
Optional. Contains the metadata information that will be stored in the Artifact.
486+
state (google.cloud.gapic.types.Artifact.State):
487+
Optional. The state of this Artifact. This is a
488+
property of the Artifact, and does not imply or
489+
capture any ongoing process. This property is
490+
managed by clients (such as Vertex AI
491+
Pipelines), and the system does not prescribe or
492+
check the validity of state transitions.
493+
"""
494+
extended_metadata = copy.deepcopy(metadata) if metadata else {}
495+
if root_mean_squared_error:
496+
extended_metadata["rootMeanSquaredError"] = root_mean_squared_error
497+
if mean_absolute_error:
498+
extended_metadata["meanAbsoluteError"] = mean_absolute_error
499+
if mean_absolute_percentage_error:
500+
extended_metadata[
501+
"meanAbsolutePercentageError"
502+
] = mean_absolute_percentage_error
503+
if r_squared:
504+
extended_metadata["rSquared"] = r_squared
505+
if root_mean_squared_log_error:
506+
extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error
507+
if weighted_absolute_percentage_error:
508+
extended_metadata[
509+
"weightedAbsolutePercentageError"
510+
] = weighted_absolute_percentage_error
511+
if root_mean_squared_percentage_error:
512+
extended_metadata[
513+
"rootMeanSquaredPercentageError"
514+
] = root_mean_squared_percentage_error
515+
if symmetric_mean_absolute_percentage_error:
516+
extended_metadata[
517+
"symmetricMeanAbsolutePercentageError"
518+
] = symmetric_mean_absolute_percentage_error
519+
520+
super(ForecastingMetrics, self).__init__(
521+
uri=uri,
522+
artifact_id=artifact_id,
523+
display_name=display_name,
524+
schema_version=schema_version,
525+
description=description,
526+
metadata=extended_metadata,
527+
state=state,
528+
)

0 commit comments

Comments
 (0)