@@ -785,7 +785,11 @@ def batch_create_time_series(parent, requests):
785
785
with mock .patch .object (
786
786
uploader , "_logdir_loader_pre_create" , mock_logdir_loader_pre_create
787
787
):
788
- uploader .start_uploading ()
788
+ with mock .patch .object (
789
+ uploader , "_end_experiment_runs" , return_value = None
790
+ ):
791
+ uploader .start_uploading ()
792
+ uploader ._end_experiment_runs .assert_called_once ()
789
793
790
794
self .assertEqual (existing_experiment is None , uploader ._is_brand_new_experiment )
791
795
self .assertEqual (2 , mock_client .write_tensorboard_experiment_data .call_count )
@@ -797,6 +801,7 @@ def batch_create_time_series(parent, requests):
797
801
self .assertLen (mock_tracker .scalars_tracker .call_args [0 ], 1 )
798
802
self .assertEqual (mock_tracker .tensors_tracker .call_count , 0 )
799
803
self .assertEqual (mock_tracker .blob_tracker .call_count , 0 )
804
+ experiment_tracker_mock .set_experiment .assert_called_once ()
800
805
801
806
@patch .object (metadata , "_experiment_tracker" , autospec = True )
802
807
@patch .object (experiment_resources , "Experiment" , autospec = True )
@@ -814,6 +819,7 @@ def test_upload_empty_logdir(
814
819
uploader .create_experiment ()
815
820
uploader ._upload_once ()
816
821
mock_client .write_tensorboard_experiment_data .assert_not_called ()
822
+ experiment_tracker_mock .set_experiment .assert_called_once ()
817
823
818
824
@patch .object (metadata , "_experiment_tracker" , autospec = True )
819
825
@patch .object (experiment_resources , "Experiment" , autospec = True )
@@ -847,6 +853,7 @@ def mock_upload_once():
847
853
uploader .create_experiment ()
848
854
with self .assertRaises (SuccessError ):
849
855
uploader .start_uploading ()
856
+ experiment_tracker_mock .set_experiment .assert_called_once ()
850
857
851
858
@patch .object (
852
859
uploader_utils .OnePlatformResourceManager ,
@@ -874,6 +881,7 @@ def test_upload_swallows_rpc_failure(
874
881
mock_client .write_tensorboard_experiment_data .side_effect = error
875
882
uploader ._upload_once ()
876
883
mock_client .write_tensorboard_experiment_data .assert_called_once ()
884
+ experiment_tracker_mock .set_experiment .assert_called_once ()
877
885
878
886
@patch .object (
879
887
uploader_utils .OnePlatformResourceManager ,
@@ -1006,10 +1014,12 @@ def test_upload_full_logdir(
1006
1014
self .assertProtoEquals (expected_request3 [1 ], request3 [1 ])
1007
1015
self .assertProtoEquals (expected_request4 [0 ], request4 [0 ])
1008
1016
mock_client .write_tensorboard_experiment_data .reset_mock ()
1017
+ experiment_tracker_mock .set_experiment .assert_called_once ()
1009
1018
1010
1019
# Empty third round
1011
1020
uploader ._upload_once ()
1012
1021
mock_client .write_tensorboard_experiment_data .assert_not_called ()
1022
+ experiment_tracker_mock .set_experiment .assert_called_once ()
1013
1023
1014
1024
@patch .object (
1015
1025
uploader_utils .OnePlatformResourceManager ,
@@ -1057,6 +1067,7 @@ def test_verbosity_zero_creates_upload_tracker_with_verbosity_zero(
1057
1067
self .assertEqual (mock_constructor .call_count , 1 )
1058
1068
self .assertEqual (mock_constructor .call_args [1 ], {"verbosity" : 0 })
1059
1069
self .assertEqual (mock_tracker .scalars_tracker .call_count , 1 )
1070
+ experiment_tracker_mock .set_experiment .assert_called_once ()
1060
1071
1061
1072
@patch .object (
1062
1073
uploader_utils .OnePlatformResourceManager ,
@@ -1160,6 +1171,7 @@ def create_time_series(tensorboard_time_series, parent=None):
1160
1171
self .assertEqual (mock_tracker .scalars_tracker .call_count , 0 )
1161
1172
self .assertEqual (mock_tracker .tensors_tracker .call_count , 0 )
1162
1173
self .assertEqual (mock_tracker .blob_tracker .call_count , 12 )
1174
+ experiment_tracker_mock .set_experiment .assert_called_once ()
1163
1175
1164
1176
@patch .object (
1165
1177
uploader_utils .OnePlatformResourceManager ,
@@ -1282,6 +1294,27 @@ def test_profile_plugin_included_by_default(
1282
1294
profile_sender = senders ["profile" ]
1283
1295
self .assertIn (run_name , profile_sender ._run_to_profile_loaders )
1284
1296
self .assertIn (run_name , profile_sender ._run_to_file_request_sender )
1297
+ experiment_tracker_mock .set_experiment .assert_called_once ()
1298
+
1299
+ @patch .object (metadata , "_experiment_tracker" , autospec = True )
1300
+ @patch .object (experiment_resources , "Experiment" , autospec = True )
1301
+ def test_active_experiment_set_experiment_not_called (
1302
+ self , experiment_resources_mock , experiment_tracker_mock
1303
+ ):
1304
+ experiment_resources_mock .get .return_value = _TEST_EXPERIMENT_NAME
1305
+ experiment_tracker_mock .set_experiment .return_value = _TEST_EXPERIMENT_NAME
1306
+ experiment_tracker_mock .experiment_name = _TEST_EXPERIMENT_NAME
1307
+ experiment_tracker_mock .set_tensorboard .return_value = (
1308
+ _TEST_TENSORBOARD_RESOURCE_NAME
1309
+ )
1310
+ logdir = self .get_temp_dir ()
1311
+ mock_client = _create_mock_client ()
1312
+
1313
+ uploader = _create_uploader (mock_client , logdir )
1314
+ uploader .create_experiment ()
1315
+ uploader ._upload_once ()
1316
+
1317
+ experiment_tracker_mock .set_experiment .assert_not_called ()
1285
1318
1286
1319
1287
1320
# TODO(b/276368161)
@@ -1387,6 +1420,7 @@ def test_thread_continuously_uploads(
1387
1420
self .assertEqual (b"12345" , request2 .plugin_data )
1388
1421
self .assertEqual ("scalars" , request3 .plugin_name )
1389
1422
self .assertEqual ("profile" , request4 .plugin_name )
1423
+ experiment_tracker_mock .set_experiment .assert_called_once ()
1390
1424
1391
1425
# Check write_tensorboard_experiment_data calls
1392
1426
self .assertEqual (1 , mock_client .write_tensorboard_experiment_data .call_count )
@@ -1425,17 +1459,22 @@ def test_thread_continuously_uploads(
1425
1459
self .assertProtoEquals (expected_request1 [1 ], request1 [1 ])
1426
1460
self .assertProtoEquals (expected_request2 [0 ], request2 [0 ])
1427
1461
1428
- uploader ._end_uploading ()
1462
+ with mock .patch .object (uploader , "_end_experiment_runs" , return_value = None ):
1463
+ uploader ._end_uploading ()
1464
+ uploader ._end_experiment_runs .assert_called_once ()
1429
1465
time .sleep (1 )
1430
1466
self .assertFalse (uploader_thread .is_alive ())
1431
1467
mock_client .write_tensorboard_experiment_data .reset_mock ()
1432
1468
1433
1469
# Empty directory
1434
1470
uploader ._upload_once ()
1435
1471
mock_client .write_tensorboard_experiment_data .assert_not_called ()
1436
- uploader ._end_uploading ()
1472
+ with mock .patch .object (uploader , "_end_experiment_runs" , return_value = None ):
1473
+ uploader ._end_uploading ()
1474
+ uploader ._end_experiment_runs .assert_called_once ()
1437
1475
time .sleep (1 )
1438
1476
self .assertFalse (uploader_thread .is_alive ())
1477
+ experiment_tracker_mock .set_experiment .assert_called_once ()
1439
1478
1440
1479
1441
1480
@pytest .mark .usefixtures ("google_auth_mock" )
0 commit comments