@@ -437,6 +437,23 @@ def get(
437
437
except exceptions .NotFound :
438
438
return None
439
439
440
+ def _initialize_experiment_run (
441
+ self ,
442
+ node : Union [context .Context , execution .Execution ],
443
+ experiment : Optional [experiment_resources .Experiment ] = None ,
444
+ ):
445
+ self ._experiment = experiment
446
+ self ._run_name = node .display_name
447
+ self ._metadata_node = node
448
+ self ._largest_step = None
449
+
450
+ if self ._is_legacy_experiment_run ():
451
+ self ._metadata_metric_artifact = self ._v1_get_metric_artifact ()
452
+ self ._backing_tensorboard_run = None
453
+ else :
454
+ self ._metadata_metric_artifact = None
455
+ self ._backing_tensorboard_run = self ._lookup_tensorboard_run_artifact ()
456
+
440
457
@classmethod
441
458
def list (
442
459
cls ,
@@ -495,33 +512,17 @@ def list(
495
512
496
513
run_executions = execution .Execution .list (filter = filter_str , ** metadata_args )
497
514
498
- def _initialize_experiment_run (context : context .Context ) -> ExperimentRun :
515
+ def _create_experiment_run (context : context .Context ) -> ExperimentRun :
499
516
this_experiment_run = cls .__new__ (cls )
500
- this_experiment_run ._experiment = experiment
501
- this_experiment_run ._run_name = context .display_name
502
- this_experiment_run ._metadata_node = context
503
-
504
- with experiment_resources ._SetLoggerLevel (resource ):
505
- tb_run = this_experiment_run ._lookup_tensorboard_run_artifact ()
506
- if tb_run :
507
- this_experiment_run ._backing_tensorboard_run = tb_run
508
- else :
509
- this_experiment_run ._backing_tensorboard_run = None
510
-
511
- this_experiment_run ._largest_step = None
517
+ this_experiment_run ._initialize_experiment_run (context , experiment )
512
518
513
519
return this_experiment_run
514
520
515
- def _initialize_v1_experiment_run (
521
+ def _create_v1_experiment_run (
516
522
execution : execution .Execution ,
517
523
) -> ExperimentRun :
518
524
this_experiment_run = cls .__new__ (cls )
519
- this_experiment_run ._experiment = experiment
520
- this_experiment_run ._run_name = execution .display_name
521
- this_experiment_run ._metadata_node = execution
522
- this_experiment_run ._metadata_metric_artifact = (
523
- this_experiment_run ._v1_get_metric_artifact ()
524
- )
525
+ this_experiment_run ._initialize_experiment_run (execution , experiment )
525
526
526
527
return this_experiment_run
527
528
@@ -530,13 +531,13 @@ def _initialize_v1_experiment_run(
530
531
max_workers = max ([len (run_contexts ), len (run_executions )])
531
532
) as executor :
532
533
submissions = [
533
- executor .submit (_initialize_experiment_run , context )
534
+ executor .submit (_create_experiment_run , context )
534
535
for context in run_contexts
535
536
]
536
537
experiment_runs = [submission .result () for submission in submissions ]
537
538
538
539
submissions = [
539
- executor .submit (_initialize_v1_experiment_run , execution )
540
+ executor .submit (_create_v1_experiment_run , execution )
540
541
for execution in run_executions
541
542
]
542
543
@@ -560,30 +561,20 @@ def _query_experiment_row(
560
561
Experiment run row that represents this run.
561
562
"""
562
563
this_experiment_run = cls .__new__ (cls )
563
- this_experiment_run ._metadata_node = node
564
+ this_experiment_run ._initialize_experiment_run ( node )
564
565
565
566
row = experiment_resources ._ExperimentRow (
566
567
experiment_run_type = node .schema_title ,
567
568
name = node .display_name ,
568
569
)
569
570
570
- if isinstance (node , context .Context ):
571
- this_experiment_run ._backing_tensorboard_run = (
572
- this_experiment_run ._lookup_tensorboard_run_artifact ()
573
- )
574
- row .params = node .metadata [constants ._PARAM_KEY ]
575
- row .metrics = node .metadata [constants ._METRIC_KEY ]
576
- row .time_series_metrics = (
577
- this_experiment_run ._get_latest_time_series_metric_columns ()
578
- )
579
- row .state = node .metadata [constants ._STATE_KEY ]
580
- else :
581
- this_experiment_run ._metadata_metric_artifact = (
582
- this_experiment_run ._v1_get_metric_artifact ()
583
- )
584
- row .params = node .metadata
585
- row .metrics = this_experiment_run ._metadata_metric_artifact .metadata
586
- row .state = node .state .name
571
+ row .params = this_experiment_run .get_params ()
572
+ row .metrics = this_experiment_run .get_metrics ()
573
+ row .state = this_experiment_run .get_state ()
574
+ row .time_series_metrics = (
575
+ this_experiment_run ._get_latest_time_series_metric_columns ()
576
+ )
577
+
587
578
return row
588
579
589
580
def _get_logged_pipeline_runs (self ) -> List [context .Context ]:
@@ -659,7 +650,7 @@ def log(
659
650
660
651
@staticmethod
661
652
def _validate_run_id (run_id : str ):
662
- """Validates the run id
653
+ """Validates the run id.
663
654
664
655
Args:
665
656
run_id(str): Required. The run id to validate.
@@ -1455,6 +1446,13 @@ def get_metrics(self) -> Dict[str, Union[float, int, str]]:
1455
1446
else :
1456
1447
return self ._metadata_node .metadata [constants ._METRIC_KEY ]
1457
1448
1449
+ def get_state (self ) -> gca_execution .Execution .State :
1450
+ """The state of this run."""
1451
+ if self ._is_legacy_experiment_run ():
1452
+ return self ._metadata_node .state .name
1453
+ else :
1454
+ return self ._metadata_node .metadata [constants ._STATE_KEY ]
1455
+
1458
1456
@_v1_not_supported
1459
1457
def get_classification_metrics (self ) -> List [Dict [str , Union [str , List ]]]:
1460
1458
"""Get all the classification metrics logged to this run.
0 commit comments