Skip to content

Commit 2e56acc

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Improve get_experiment_df execution speed
PiperOrigin-RevId: 619596185
1 parent 57bb955 commit 2e56acc

File tree

3 files changed

+75
-63
lines changed

3 files changed

+75
-63
lines changed

google/cloud/aiplatform/metadata/experiment_resources.py

+35-21
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
#
1717

1818
import abc
19+
import concurrent.futures
1920
from dataclasses import dataclass
2021
import logging
2122
from typing import Dict, List, NamedTuple, Optional, Tuple, Type, Union
@@ -448,28 +449,41 @@ def get_data_frame(self) -> "pd.DataFrame": # noqa: F821
448449
executions = execution.Execution.list(filter_str, **service_request_args)
449450

450451
rows = []
451-
for metadata_context in contexts:
452-
row_dict = (
453-
_SUPPORTED_LOGGABLE_RESOURCES[context.Context][
454-
metadata_context.schema_title
452+
if contexts or executions:
453+
with concurrent.futures.ThreadPoolExecutor(
454+
max_workers=max([len(contexts), len(executions)])
455+
) as executor:
456+
futures = [
457+
executor.submit(
458+
_SUPPORTED_LOGGABLE_RESOURCES[context.Context][
459+
metadata_context.schema_title
460+
]._query_experiment_row,
461+
metadata_context,
462+
)
463+
for metadata_context in contexts
455464
]
456-
._query_experiment_row(metadata_context)
457-
.to_dict()
458-
)
459-
row_dict.update({"experiment_name": self.name})
460-
rows.append(row_dict)
461-
462-
# backward compatibility
463-
for metadata_execution in executions:
464-
row_dict = (
465-
_SUPPORTED_LOGGABLE_RESOURCES[execution.Execution][
466-
metadata_execution.schema_title
467-
]
468-
._query_experiment_row(metadata_execution)
469-
.to_dict()
470-
)
471-
row_dict.update({"experiment_name": self.name})
472-
rows.append(row_dict)
465+
466+
# backward compatibility
467+
futures.extend(
468+
executor.submit(
469+
_SUPPORTED_LOGGABLE_RESOURCES[execution.Execution][
470+
metadata_execution.schema_title
471+
]._query_experiment_row,
472+
metadata_execution,
473+
)
474+
for metadata_execution in executions
475+
)
476+
477+
for future in futures:
478+
try:
479+
row_dict = future.result().to_dict()
480+
except Exception as exc:
481+
raise ValueError(
482+
f"Failed to get experiment row for {self.name}"
483+
) from exc
484+
else:
485+
row_dict.update({"experiment_name": self.name})
486+
rows.append(row_dict)
473487

474488
df = pd.DataFrame(rows)
475489

google/cloud/aiplatform/metadata/experiment_run_resource.py

+39-41
Original file line number | Diff line number | Diff line change
@@ -437,6 +437,23 @@ def get(
437437
except exceptions.NotFound:
438438
return None
439439

440+
def _initialize_experiment_run(
441+
self,
442+
node: Union[context.Context, execution.Execution],
443+
experiment: Optional[experiment_resources.Experiment] = None,
444+
):
445+
self._experiment = experiment
446+
self._run_name = node.display_name
447+
self._metadata_node = node
448+
self._largest_step = None
449+
450+
if self._is_legacy_experiment_run():
451+
self._metadata_metric_artifact = self._v1_get_metric_artifact()
452+
self._backing_tensorboard_run = None
453+
else:
454+
self._metadata_metric_artifact = None
455+
self._backing_tensorboard_run = self._lookup_tensorboard_run_artifact()
456+
440457
@classmethod
441458
def list(
442459
cls,
@@ -495,33 +512,17 @@ def list(
495512

496513
run_executions = execution.Execution.list(filter=filter_str, **metadata_args)
497514

498-
def _initialize_experiment_run(context: context.Context) -> ExperimentRun:
515+
def _create_experiment_run(context: context.Context) -> ExperimentRun:
499516
this_experiment_run = cls.__new__(cls)
500-
this_experiment_run._experiment = experiment
501-
this_experiment_run._run_name = context.display_name
502-
this_experiment_run._metadata_node = context
503-
504-
with experiment_resources._SetLoggerLevel(resource):
505-
tb_run = this_experiment_run._lookup_tensorboard_run_artifact()
506-
if tb_run:
507-
this_experiment_run._backing_tensorboard_run = tb_run
508-
else:
509-
this_experiment_run._backing_tensorboard_run = None
510-
511-
this_experiment_run._largest_step = None
517+
this_experiment_run._initialize_experiment_run(context, experiment)
512518

513519
return this_experiment_run
514520

515-
def _initialize_v1_experiment_run(
521+
def _create_v1_experiment_run(
516522
execution: execution.Execution,
517523
) -> ExperimentRun:
518524
this_experiment_run = cls.__new__(cls)
519-
this_experiment_run._experiment = experiment
520-
this_experiment_run._run_name = execution.display_name
521-
this_experiment_run._metadata_node = execution
522-
this_experiment_run._metadata_metric_artifact = (
523-
this_experiment_run._v1_get_metric_artifact()
524-
)
525+
this_experiment_run._initialize_experiment_run(execution, experiment)
525526

526527
return this_experiment_run
527528

@@ -530,13 +531,13 @@ def _initialize_v1_experiment_run(
530531
max_workers=max([len(run_contexts), len(run_executions)])
531532
) as executor:
532533
submissions = [
533-
executor.submit(_initialize_experiment_run, context)
534+
executor.submit(_create_experiment_run, context)
534535
for context in run_contexts
535536
]
536537
experiment_runs = [submission.result() for submission in submissions]
537538

538539
submissions = [
539-
executor.submit(_initialize_v1_experiment_run, execution)
540+
executor.submit(_create_v1_experiment_run, execution)
540541
for execution in run_executions
541542
]
542543

@@ -560,30 +561,20 @@ def _query_experiment_row(
560561
Experiment run row that represents this run.
561562
"""
562563
this_experiment_run = cls.__new__(cls)
563-
this_experiment_run._metadata_node = node
564+
this_experiment_run._initialize_experiment_run(node)
564565

565566
row = experiment_resources._ExperimentRow(
566567
experiment_run_type=node.schema_title,
567568
name=node.display_name,
568569
)
569570

570-
if isinstance(node, context.Context):
571-
this_experiment_run._backing_tensorboard_run = (
572-
this_experiment_run._lookup_tensorboard_run_artifact()
573-
)
574-
row.params = node.metadata[constants._PARAM_KEY]
575-
row.metrics = node.metadata[constants._METRIC_KEY]
576-
row.time_series_metrics = (
577-
this_experiment_run._get_latest_time_series_metric_columns()
578-
)
579-
row.state = node.metadata[constants._STATE_KEY]
580-
else:
581-
this_experiment_run._metadata_metric_artifact = (
582-
this_experiment_run._v1_get_metric_artifact()
583-
)
584-
row.params = node.metadata
585-
row.metrics = this_experiment_run._metadata_metric_artifact.metadata
586-
row.state = node.state.name
571+
row.params = this_experiment_run.get_params()
572+
row.metrics = this_experiment_run.get_metrics()
573+
row.state = this_experiment_run.get_state()
574+
row.time_series_metrics = (
575+
this_experiment_run._get_latest_time_series_metric_columns()
576+
)
577+
587578
return row
588579

589580
def _get_logged_pipeline_runs(self) -> List[context.Context]:
@@ -659,7 +650,7 @@ def log(
659650

660651
@staticmethod
661652
def _validate_run_id(run_id: str):
662-
"""Validates the run id
653+
"""Validates the run id.
663654
664655
Args:
665656
run_id(str): Required. The run id to validate.
@@ -1455,6 +1446,13 @@ def get_metrics(self) -> Dict[str, Union[float, int, str]]:
14551446
else:
14561447
return self._metadata_node.metadata[constants._METRIC_KEY]
14571448

1449+
def get_state(self) -> gca_execution.Execution.State:
1450+
"""The state of this run."""
1451+
if self._is_legacy_experiment_run():
1452+
return self._metadata_node.state.name
1453+
else:
1454+
return self._metadata_node.metadata[constants._STATE_KEY]
1455+
14581456
@_v1_not_supported
14591457
def get_classification_metrics(self) -> List[Dict[str, Union[str, List]]]:
14601458
"""Get all the classification metrics logged to this run.

google/cloud/aiplatform/metadata/metadata.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -780,7 +780,7 @@ def get_experiment_df(
780780
aiplatform.log_params({'learning_rate': 0.2})
781781
aiplatform.log_metrics({'accuracy': 0.95})
782782
783-
aiplatform.get_experiments_df()
783+
aiplatform.get_experiment_df()
784784
```
785785
786786
Will result in the following DataFrame:

0 commit comments

Comments
 (0)