|
24 | 24 | from datetime import datetime, timezone
|
25 | 25 | from pathlib import Path
|
26 | 26 | from subprocess import check_call, check_output
|
| 27 | +from typing import Literal |
27 | 28 |
|
28 | 29 | import pytest
|
29 | 30 | import requests
|
@@ -65,7 +66,7 @@ def base_tests_setup(self, request):
|
65 | 66 | # Replacement for unittests.TestCase.id()
|
66 | 67 | self.test_id = f"{request.node.cls.__name__}_{request.node.name}"
|
67 | 68 | # Ensure the api-server deployment is healthy at kubernetes level before calling the any API
|
68 |
| - self.ensure_deployment_health("airflow-api-server") |
| 69 | + self.ensure_resource_health("airflow-api-server") |
69 | 70 | try:
|
70 | 71 | self.session = self._get_session_with_retries()
|
71 | 72 | self._ensure_airflow_api_server_is_healthy()
|
@@ -227,12 +228,24 @@ def monitor_task(self, host, dag_run_id, dag_id, task_id, expected_final_state,
|
227 | 228 | assert state == expected_final_state
|
228 | 229 |
|
229 | 230 | @staticmethod
|
230 |
| - def ensure_deployment_health(deployment_name: str, namespace: str = "airflow"): |
231 |
| - """Watch the deployment until it is healthy.""" |
232 |
| - deployment_rollout_status = check_output( |
233 |
| - ["kubectl", "rollout", "status", "deployment", deployment_name, "-n", namespace, "--watch"] |
| 231 | + def ensure_resource_health( |
| 232 | + resource_name: str, |
| 233 | + namespace: str = "airflow", |
| 234 | + resource_type: Literal["deployment", "statefulset"] = "deployment", |
| 235 | + ): |
| 236 | + """Watch the resource until it is healthy. |
| 237 | + Args: |
| 238 | + resource_name (str): Name of the resource to check. |
| 239 | + resource_type (str): Type of the resource (e.g., deployment, statefulset). |
| 240 | + namespace (str): Kubernetes namespace where the resource is located. |
| 241 | + """ |
| 242 | + rollout_status = check_output( |
| 243 | + ["kubectl", "rollout", "status", f"{resource_type}/{resource_name}", "-n", namespace, "--watch"], |
234 | 244 | ).decode()
|
235 |
| - assert "successfully rolled out" in deployment_rollout_status |
| 245 | + if resource_type == "deployment": |
| 246 | + assert "successfully rolled out" in rollout_status |
| 247 | + else: |
| 248 | + assert "roll out complete" in rollout_status |
236 | 249 |
|
237 | 250 | def ensure_dag_expected_state(self, host, logical_date, dag_id, expected_final_state, timeout):
|
238 | 251 | tries = 0
|
|
0 commit comments