76
76
_TEST_BQ_JOB_ID = "123459876"
77
77
_TEST_BQ_MAX_RESULTS = 100
78
78
_TEST_GCS_BUCKET_NAME = "my-bucket"
79
+ _TEST_SERVICE_ACCOUNT = "[email protected] "
80
+
79
81
80
82
_TEST_BQ_PATH = f"bq://{ _TEST_BQ_PROJECT_ID } .{ _TEST_BQ_DATASET_ID } "
81
83
_TEST_GCS_BUCKET_PATH = f"gs://{ _TEST_GCS_BUCKET_NAME } "
@@ -719,6 +721,7 @@ def test_batch_predict_gcs_source_and_dest(
719
721
gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
720
722
sync = sync ,
721
723
create_request_timeout = None ,
724
+ service_account = _TEST_SERVICE_ACCOUNT ,
722
725
)
723
726
724
727
batch_prediction_job .wait_for_resource_creation ()
@@ -741,6 +744,7 @@ def test_batch_predict_gcs_source_and_dest(
741
744
),
742
745
predictions_format = "jsonl" ,
743
746
),
747
+ service_account = _TEST_SERVICE_ACCOUNT ,
744
748
)
745
749
746
750
create_batch_prediction_job_mock .assert_called_once_with (
@@ -766,6 +770,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout(
766
770
gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
767
771
sync = sync ,
768
772
create_request_timeout = 180.0 ,
773
+ service_account = _TEST_SERVICE_ACCOUNT ,
769
774
)
770
775
771
776
batch_prediction_job .wait_for_resource_creation ()
@@ -788,6 +793,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout(
788
793
),
789
794
predictions_format = "jsonl" ,
790
795
),
796
+ service_account = _TEST_SERVICE_ACCOUNT ,
791
797
)
792
798
793
799
create_batch_prediction_job_mock .assert_called_once_with (
@@ -812,6 +818,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set(
812
818
gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
813
819
gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
814
820
sync = sync ,
821
+ service_account = _TEST_SERVICE_ACCOUNT ,
815
822
)
816
823
817
824
batch_prediction_job .wait_for_resource_creation ()
@@ -834,6 +841,7 @@ def test_batch_predict_gcs_source_and_dest_with_timeout_not_explicitly_set(
834
841
),
835
842
predictions_format = "jsonl" ,
836
843
),
844
+ service_account = _TEST_SERVICE_ACCOUNT ,
837
845
)
838
846
839
847
create_batch_prediction_job_mock .assert_called_once_with (
@@ -855,6 +863,7 @@ def test_batch_predict_job_done_create(self, create_batch_prediction_job_mock):
855
863
gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
856
864
gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
857
865
sync = False ,
866
+ service_account = _TEST_SERVICE_ACCOUNT ,
858
867
)
859
868
860
869
batch_prediction_job .wait_for_resource_creation ()
@@ -881,6 +890,7 @@ def test_batch_predict_gcs_source_bq_dest(
881
890
bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
882
891
sync = sync ,
883
892
create_request_timeout = None ,
893
+ service_account = _TEST_SERVICE_ACCOUNT ,
884
894
)
885
895
886
896
batch_prediction_job .wait_for_resource_creation ()
@@ -908,6 +918,7 @@ def test_batch_predict_gcs_source_bq_dest(
908
918
),
909
919
predictions_format = "bigquery" ,
910
920
),
921
+ service_account = _TEST_SERVICE_ACCOUNT ,
911
922
)
912
923
913
924
create_batch_prediction_job_mock .assert_called_once_with (
@@ -946,6 +957,7 @@ def test_batch_predict_with_all_args(
946
957
sync = sync ,
947
958
create_request_timeout = None ,
948
959
batch_size = _TEST_BATCH_SIZE ,
960
+ service_account = _TEST_SERVICE_ACCOUNT ,
949
961
)
950
962
951
963
batch_prediction_job .wait_for_resource_creation ()
@@ -986,6 +998,7 @@ def test_batch_predict_with_all_args(
986
998
parameters = _TEST_EXPLANATION_PARAMETERS ,
987
999
),
988
1000
labels = _TEST_LABEL ,
1001
+ service_account = _TEST_SERVICE_ACCOUNT ,
989
1002
)
990
1003
991
1004
create_batch_prediction_job_with_explanations_mock .assert_called_once_with (
@@ -1047,6 +1060,7 @@ def test_batch_predict_with_all_args_and_model_monitoring(
1047
1060
model_monitoring_objective_config = mm_obj_cfg ,
1048
1061
model_monitoring_alert_config = mm_alert_cfg ,
1049
1062
analysis_instance_schema_uri = "" ,
1063
+ service_account = _TEST_SERVICE_ACCOUNT ,
1050
1064
)
1051
1065
1052
1066
batch_prediction_job .wait_for_resource_creation ()
@@ -1086,6 +1100,7 @@ def test_batch_predict_with_all_args_and_model_monitoring(
1086
1100
generate_explanation = True ,
1087
1101
model_monitoring_config = _TEST_MODEL_MONITORING_CFG ,
1088
1102
labels = _TEST_LABEL ,
1103
+ service_account = _TEST_SERVICE_ACCOUNT ,
1089
1104
)
1090
1105
create_batch_prediction_job_v1beta1_mock .assert_called_once_with (
1091
1106
parent = f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } " ,
@@ -1103,6 +1118,7 @@ def test_batch_predict_create_fails(self):
1103
1118
gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
1104
1119
bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
1105
1120
sync = False ,
1121
+ service_account = _TEST_SERVICE_ACCOUNT ,
1106
1122
)
1107
1123
1108
1124
with pytest .raises (RuntimeError ) as e :
@@ -1143,6 +1159,7 @@ def test_batch_predict_no_source(self, create_batch_prediction_job_mock):
1143
1159
model_name = _TEST_MODEL_NAME ,
1144
1160
job_display_name = _TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME ,
1145
1161
bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
1162
+ service_account = _TEST_SERVICE_ACCOUNT ,
1146
1163
)
1147
1164
1148
1165
assert e .match (regexp = r"source" )
@@ -1159,6 +1176,7 @@ def test_batch_predict_two_sources(self, create_batch_prediction_job_mock):
1159
1176
gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
1160
1177
bigquery_source = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
1161
1178
bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
1179
+ service_account = _TEST_SERVICE_ACCOUNT ,
1162
1180
)
1163
1181
1164
1182
assert e .match (regexp = r"source" )
@@ -1173,6 +1191,7 @@ def test_batch_predict_no_destination(self):
1173
1191
model_name = _TEST_MODEL_NAME ,
1174
1192
job_display_name = _TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME ,
1175
1193
gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
1194
+ service_account = _TEST_SERVICE_ACCOUNT ,
1176
1195
)
1177
1196
1178
1197
assert e .match (regexp = r"destination" )
@@ -1189,6 +1208,7 @@ def test_batch_predict_wrong_instance_format(self):
1189
1208
gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
1190
1209
instances_format = "wrong" ,
1191
1210
bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
1211
+ service_account = _TEST_SERVICE_ACCOUNT ,
1192
1212
)
1193
1213
1194
1214
assert e .match (regexp = r"accepted instances format" )
@@ -1205,6 +1225,7 @@ def test_batch_predict_wrong_prediction_format(self):
1205
1225
gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
1206
1226
predictions_format = "wrong" ,
1207
1227
bigquery_destination_prefix = _TEST_BATCH_PREDICTION_BQ_PREFIX ,
1228
+ service_account = _TEST_SERVICE_ACCOUNT ,
1208
1229
)
1209
1230
1210
1231
assert e .match (regexp = r"accepted prediction format" )
@@ -1222,6 +1243,7 @@ def test_batch_predict_job_with_versioned_model(
1222
1243
gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
1223
1244
gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
1224
1245
sync = True ,
1246
+ service_account = _TEST_SERVICE_ACCOUNT ,
1225
1247
)
1226
1248
assert (
1227
1249
create_batch_prediction_job_mock .call_args_list [0 ][1 ][
@@ -1237,6 +1259,7 @@ def test_batch_predict_job_with_versioned_model(
1237
1259
gcs_source = _TEST_BATCH_PREDICTION_GCS_SOURCE ,
1238
1260
gcs_destination_prefix = _TEST_BATCH_PREDICTION_GCS_DEST_PREFIX ,
1239
1261
sync = True ,
1262
+ service_account = _TEST_SERVICE_ACCOUNT ,
1240
1263
)
1241
1264
assert (
1242
1265
create_batch_prediction_job_mock .call_args_list [0 ][1 ][
0 commit comments