|
198 | 198 | }
|
199 | 199 | )
|
200 | 200 |
|
| 201 | +_TEST_MODEL_EVAL_PIPELINE_SPEC_WITH_CACHING_OPTIONS_JSON = json.dumps( |
| 202 | + { |
| 203 | + "pipelineInfo": {"name": "evaluation-default-pipeline"}, |
| 204 | + "root": { |
| 205 | + "dag": { |
| 206 | + "tasks": { |
| 207 | + "model-evaluation-text-generation": { |
| 208 | + "taskInfo": {"name": "model-evaluation-text-generation"}, |
| 209 | + "cachingOptions": {"enableCache": False}, |
| 210 | + } |
| 211 | + } |
| 212 | + }, |
| 213 | + "inputDefinitions": { |
| 214 | + "parameters": { |
| 215 | + "batch_predict_gcs_source_uris": {"type": "STRING"}, |
| 216 | + "dataflow_service_account": {"type": "STRING"}, |
| 217 | + "batch_predict_instances_format": {"type": "STRING"}, |
| 218 | + "batch_predict_machine_type": {"type": "STRING"}, |
| 219 | + "evaluation_class_labels": {"type": "STRING"}, |
| 220 | + "location": {"type": "STRING"}, |
| 221 | + "model_name": {"type": "STRING"}, |
| 222 | + "project": {"type": "STRING"}, |
| 223 | + "batch_predict_gcs_destination_output_uri": {"type": "STRING"}, |
| 224 | + "target_field_name": {"type": "STRING"}, |
| 225 | + } |
| 226 | + }, |
| 227 | + }, |
| 228 | + "schemaVersion": "2.0.0", |
| 229 | + "sdkVersion": "kfp-1.8.12", |
| 230 | + "components": {}, |
| 231 | + } |
| 232 | +) |
| 233 | + |
201 | 234 | _TEST_INVALID_MODEL_EVAL_PIPELINE_SPEC = json.dumps(
|
202 | 235 | {
|
203 | 236 | "pipelineInfo": {"name": "my-pipeline"},
|
@@ -1083,6 +1116,100 @@ def test_model_evaluation_job_submit(
|
1083 | 1116 |
|
1084 | 1117 | assert mock_model_eval_job_get.called_once
|
1085 | 1118 |
|
| 1119 | + @pytest.mark.parametrize( |
| 1120 | + "job_spec", |
| 1121 | + [_TEST_MODEL_EVAL_PIPELINE_SPEC_WITH_CACHING_OPTIONS_JSON], |
| 1122 | + ) |
| 1123 | + @pytest.mark.usefixtures("mock_pipeline_service_create") |
| 1124 | + def test_model_evaluation_job_submit_with_caching_disabled( |
| 1125 | + self, |
| 1126 | + job_spec, |
| 1127 | + mock_load_yaml_and_json, |
| 1128 | + mock_model, |
| 1129 | + get_model_mock, |
| 1130 | + mock_model_eval_get, |
| 1131 | + mock_model_eval_job_get, |
| 1132 | + mock_pipeline_service_get, |
| 1133 | + mock_model_eval_job_create, |
| 1134 | + mock_pipeline_bucket_exists, |
| 1135 | + mock_request_urlopen, |
| 1136 | + ): |
| 1137 | + test_model_eval_job = model_evaluation_job._ModelEvaluationJob.submit( |
| 1138 | + model_name=_TEST_MODEL_RESOURCE_NAME, |
| 1139 | + prediction_type=_TEST_MODEL_EVAL_PREDICTION_TYPE, |
| 1140 | + instances_format=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ |
| 1141 | + "batch_predict_instances_format" |
| 1142 | + ], |
| 1143 | + model_type="automl_tabular", |
| 1144 | + pipeline_root=_TEST_GCS_BUCKET_NAME, |
| 1145 | + target_field_name=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ |
| 1146 | + "target_field_name" |
| 1147 | + ], |
| 1148 | + evaluation_pipeline_display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME, |
| 1149 | + gcs_source_uris=_TEST_MODEL_EVAL_PIPELINE_PARAMETER_VALUES[ |
| 1150 | + "batch_predict_gcs_source_uris" |
| 1151 | + ], |
| 1152 | + job_id=_TEST_PIPELINE_JOB_ID, |
| 1153 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 1154 | + network=_TEST_NETWORK, |
| 1155 | + enable_caching=False, |
| 1156 | + ) |
| 1157 | + |
| 1158 | + test_model_eval_job.wait() |
| 1159 | + |
| 1160 | + expected_runtime_config_dict = { |
| 1161 | + "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, |
| 1162 | + "parameters": { |
| 1163 | + "batch_predict_gcs_source_uris": { |
| 1164 | + "stringValue": '["gs://my-bucket/my-prediction-data.csv"]' |
| 1165 | + }, |
| 1166 | + "dataflow_service_account": {"stringValue": _TEST_SERVICE_ACCOUNT}, |
| 1167 | + "batch_predict_instances_format": {"stringValue": "csv"}, |
| 1168 | + "model_name": {"stringValue": _TEST_MODEL_RESOURCE_NAME}, |
| 1169 | + "project": {"stringValue": _TEST_PROJECT}, |
| 1170 | + "location": {"stringValue": _TEST_LOCATION}, |
| 1171 | + "batch_predict_gcs_destination_output_uri": { |
| 1172 | + "stringValue": _TEST_GCS_BUCKET_NAME |
| 1173 | + }, |
| 1174 | + "target_field_name": {"stringValue": "predict_class"}, |
| 1175 | + }, |
| 1176 | + } |
| 1177 | + |
| 1178 | + runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb |
| 1179 | + json_format.ParseDict(expected_runtime_config_dict, runtime_config) |
| 1180 | + |
| 1181 | + job_spec = yaml.safe_load(job_spec) |
| 1182 | + pipeline_spec = job_spec.get("pipelineSpec") or job_spec |
| 1183 | + |
| 1184 | + # Construct expected request |
| 1185 | + expected_gapic_pipeline_job = gca_pipeline_job.PipelineJob( |
| 1186 | + display_name=_TEST_MODEL_EVAL_PIPELINE_JOB_DISPLAY_NAME, |
| 1187 | + pipeline_spec={ |
| 1188 | + "components": {}, |
| 1189 | + "pipelineInfo": pipeline_spec["pipelineInfo"], |
| 1190 | + "root": pipeline_spec["root"], |
| 1191 | + "schemaVersion": "2.0.0", |
| 1192 | + "sdkVersion": "kfp-1.8.12", |
| 1193 | + }, |
| 1194 | + runtime_config=runtime_config, |
| 1195 | + service_account=_TEST_SERVICE_ACCOUNT, |
| 1196 | + network=_TEST_NETWORK, |
| 1197 | + template_uri=_TEST_KFP_TEMPLATE_URI, |
| 1198 | + ) |
| 1199 | + |
| 1200 | + mock_model_eval_job_create.assert_called_with( |
| 1201 | + parent=_TEST_PARENT, |
| 1202 | + pipeline_job=expected_gapic_pipeline_job, |
| 1203 | + pipeline_job_id=_TEST_PIPELINE_JOB_ID, |
| 1204 | + timeout=None, |
| 1205 | + ) |
| 1206 | + |
| 1207 | + assert mock_model_eval_job_get.called_once |
| 1208 | + |
| 1209 | + assert mock_pipeline_service_get.called_once |
| 1210 | + |
| 1211 | + assert mock_model_eval_job_get.called_once |
| 1212 | + |
1086 | 1213 | @pytest.mark.parametrize(
|
1087 | 1214 | "job_spec",
|
1088 | 1215 | [_TEST_MODEL_EVAL_PIPELINE_SPEC_JSON],
|
|
0 commit comments