20
20
from collections import defaultdict
21
21
import functools
22
22
import logging
23
- import os
24
23
import re
25
24
import time
26
25
from typing import ContextManager , Dict , FrozenSet , Generator , Iterable , Optional , Tuple
27
26
import uuid
28
27
29
- from google .api_core import exceptions
30
28
from google .cloud import storage
31
29
from google .cloud .aiplatform import base
32
30
from google .cloud .aiplatform .compat .services import (
33
31
tensorboard_service_client ,
34
32
)
35
33
from google .cloud .aiplatform .compat .types import tensorboard_data
36
- from google .cloud .aiplatform .compat .types import tensorboard_experiment
37
34
from google .cloud .aiplatform .compat .types import tensorboard_service
38
35
from google .cloud .aiplatform .compat .types import tensorboard_time_series
36
+ from google .cloud .aiplatform .metadata import experiment_resources
37
+ from google .cloud .aiplatform .metadata import metadata
39
38
from google .cloud .aiplatform .tensorboard import logdir_loader
40
39
from google .cloud .aiplatform .tensorboard import upload_tracker
41
40
from google .cloud .aiplatform .tensorboard import uploader_constants
@@ -215,47 +214,45 @@ def active_filter(secs):
215
214
216
215
self ._create_additional_senders ()
217
216
218
- def _create_or_get_experiment (self ) -> tensorboard_experiment .TensorboardExperiment :
219
- """Create an experiment or get an experiment.
220
-
221
- Attempts to create an experiment. If the experiment already exists and
222
- creation fails then the experiment will be retrieved.
217
+ def create_experiment (self ):
218
+ """Creates an Experiment for this upload session.
223
219
224
- Returns:
225
- The created or retrieved experiment .
220
+ Sets the tensorboard resource and experiment, which will get or create a
221
+ Vertex Experiment and associate it with a Tensorboard Experiment .
226
222
"""
227
- logger . info ( "Creating experiment" )
223
+ m = self . _api . parse_tensorboard_path ( self . _tensorboard_resource_name )
228
224
229
- tb_experiment = tensorboard_experiment .TensorboardExperiment (
230
- description = self ._description , display_name = self ._experiment_display_name
225
+ existing_experiment = experiment_resources .Experiment .get (
226
+ experiment_name = self ._experiment_name ,
227
+ project = m ["project" ],
228
+ location = m ["location" ],
231
229
)
232
-
233
- try :
234
- experiment = self ._api .create_tensorboard_experiment (
235
- parent = self ._tensorboard_resource_name ,
236
- tensorboard_experiment = tb_experiment ,
237
- tensorboard_experiment_id = self ._experiment_name ,
238
- )
230
+ if not existing_experiment :
239
231
self ._is_brand_new_experiment = True
240
- except exceptions .AlreadyExists :
241
- logger .info ("Creating experiment failed. Retrieving experiment." )
242
- experiment_name = os .path .join (
243
- self ._tensorboard_resource_name , "experiments" , self ._experiment_name
244
- )
245
- experiment = self ._api .get_tensorboard_experiment (name = experiment_name )
246
- return experiment
247
232
248
- def create_experiment (self ):
249
- """Creates an Experiment for this upload session and returns the ID."""
233
+ metadata ._experiment_tracker .reset ()
234
+ metadata ._experiment_tracker .set_tensorboard (
235
+ tensorboard = self ._tensorboard_resource_name ,
236
+ project = m ["project" ],
237
+ location = m ["location" ],
238
+ )
239
+ metadata ._experiment_tracker .set_experiment (
240
+ project = m ["project" ],
241
+ location = m ["location" ],
242
+ experiment = self ._experiment_name ,
243
+ description = self ._description ,
244
+ backing_tensorboard = self ._tensorboard_resource_name ,
245
+ )
250
246
251
- experiment = self ._create_or_get_experiment ()
252
- self ._experiment = experiment
247
+ self ._tensorboard_experiment_resource_name = (
248
+ f"{ self ._tensorboard_resource_name } /experiments/{ self ._experiment_name } "
249
+ )
253
250
self ._one_platform_resource_manager = uploader_utils .OnePlatformResourceManager (
254
- self ._experiment . name , self ._api
251
+ self ._tensorboard_experiment_resource_name , self ._api
255
252
)
256
253
257
254
self ._request_sender = _BatchedRequestSender (
258
- self ._experiment . name ,
255
+ self ._tensorboard_experiment_resource_name ,
259
256
self ._api ,
260
257
allowed_plugins = self ._allowed_plugins ,
261
258
upload_limits = self ._upload_limits ,
@@ -271,7 +268,7 @@ def create_experiment(self):
271
268
# Update partials with experiment name
272
269
for sender in self ._additional_senders .keys ():
273
270
self ._additional_senders [sender ] = self ._additional_senders [sender ](
274
- experiment_resource_name = self ._experiment . name ,
271
+ experiment_resource_name = self ._tensorboard_experiment_resource_name ,
275
272
)
276
273
277
274
self ._dispatcher = _Dispatcher (
@@ -310,7 +307,7 @@ def _create_additional_senders(self) -> Dict[str, uploader_utils.RequestSender]:
310
307
)
311
308
312
309
def get_experiment_resource_name (self ):
313
- return self ._experiment . name
310
+ return self ._tensorboard_experiment_resource_name
314
311
315
312
def start_uploading (self ):
316
313
"""Blocks forever to continuously upload data from the logdir.
0 commit comments