@@ -192,23 +192,22 @@ def init(
192
192
ValueError:
193
193
If experiment_description is provided but experiment is not.
194
194
"""
195
-
196
- if api_endpoint is not None :
197
- self ._api_endpoint = api_endpoint
198
-
195
+ # This method mutates state, so we need to be careful with the validation
196
+ # First, we need to validate all passed values
197
+ if api_transport :
198
+ VALID_TRANSPORT_TYPES = ["grpc" , "rest" ]
199
+ if api_transport not in VALID_TRANSPORT_TYPES :
200
+ raise ValueError (
201
+ f"{ api_transport } is not a valid transport type. "
202
+ + f"Valid transport types: { VALID_TRANSPORT_TYPES } "
203
+ )
204
+ if location :
205
+ utils .validate_region (location )
199
206
if experiment_description and experiment is None :
200
207
raise ValueError (
201
208
"Experiment needs to be set in `init` in order to add experiment descriptions."
202
209
)
203
210
204
- if experiment_tensorboard and not isinstance (experiment_tensorboard , bool ):
205
- metadata ._experiment_tracker .set_tensorboard (
206
- tensorboard = experiment_tensorboard ,
207
- project = project ,
208
- location = location ,
209
- credentials = credentials ,
210
- )
211
-
212
211
# reset metadata_service config if project or location is updated.
213
212
if (project and project != self ._project ) or (
214
213
location and location != self ._location
@@ -217,10 +216,14 @@ def init(
217
216
logging .info ("project/location updated, reset Experiment config." )
218
217
metadata ._experiment_tracker .reset ()
219
218
219
+ # Then we change the main state
220
+ if api_endpoint is not None :
221
+ self ._api_endpoint = api_endpoint
222
+ if api_transport :
223
+ self ._api_transport = api_transport
220
224
if project :
221
225
self ._project = project
222
226
if location :
223
- utils .validate_region (location )
224
227
self ._location = location
225
228
if staging_bucket :
226
229
self ._staging_bucket = staging_bucket
@@ -233,22 +236,22 @@ def init(
233
236
if service_account is not None :
234
237
self ._service_account = service_account
235
238
239
+ # Finally, perform secondary state updates
240
+ if experiment_tensorboard and not isinstance (experiment_tensorboard , bool ):
241
+ metadata ._experiment_tracker .set_tensorboard (
242
+ tensorboard = experiment_tensorboard ,
243
+ project = project ,
244
+ location = location ,
245
+ credentials = credentials ,
246
+ )
247
+
236
248
if experiment :
237
249
metadata ._experiment_tracker .set_experiment (
238
250
experiment = experiment ,
239
251
description = experiment_description ,
240
252
backing_tensorboard = experiment_tensorboard ,
241
253
)
242
254
243
- if api_transport :
244
- VALID_TRANSPORT_TYPES = ["grpc" , "rest" ]
245
- if api_transport not in VALID_TRANSPORT_TYPES :
246
- raise ValueError (
247
- f"{ api_transport } is not a valid transport type. "
248
- + f"Valid transport types: { VALID_TRANSPORT_TYPES } "
249
- )
250
- self ._api_transport = api_transport
251
-
252
255
def get_encryption_spec (
253
256
self ,
254
257
encryption_spec_key_name : Optional [str ],
0 commit comments