diff --git a/astronomer/providers/cncf/kubernetes/operators/kubernetes_pod.py b/astronomer/providers/cncf/kubernetes/operators/kubernetes_pod.py
index 406a02b3d..81ae502fd 100644
--- a/astronomer/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/astronomer/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -5,6 +5,10 @@
 from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
     KubernetesPodOperator,
 )
+from airflow.providers.cncf.kubernetes.utils.pod_manager import (
+    PodPhase,
+    container_is_running,
+)
 from kubernetes.client import models as k8s
 from pendulum import DateTime
 
@@ -76,9 +80,20 @@ def defer(self, last_log_time: Optional[DateTime] = None, **kwargs: Any) -> None
             method_name=self.trigger_reentry.__name__,
         )
 
-    def execute(self, context: Context) -> None:  # noqa: D102
+    def execute(self, context: Context) -> Any:  # noqa: D102
         self.pod_request_obj = self.build_pod_request_obj(context)
         self.pod: k8s.V1Pod = self.get_or_create_pod(self.pod_request_obj, context)
+        pod_status = self.pod.status.phase
+        if pod_status in PodPhase.terminal_states or not container_is_running(
+            pod=self.pod, container_name=self.BASE_CONTAINER_NAME
+        ):
+            event = {
+                "status": "done",
+                "namespace": self.pod.metadata.namespace,
+                "pod_name": self.pod.metadata.name,
+            }
+            return self.trigger_reentry(context=context, event=event)
+
         self.defer()
 
     def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any:  # type: ignore[override]
diff --git a/astronomer/providers/cncf/kubernetes/triggers/wait_container.py b/astronomer/providers/cncf/kubernetes/triggers/wait_container.py
index 70f3bc348..53b8f86b0 100644
--- a/astronomer/providers/cncf/kubernetes/triggers/wait_container.py
+++ b/astronomer/providers/cncf/kubernetes/triggers/wait_container.py
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import asyncio
 import traceback
 from datetime import timedelta
-from typing import Any, AsyncIterator, Dict, Optional, Tuple
+from typing import Any, AsyncIterator
 
 from airflow.exceptions import AirflowException
 from airflow.providers.cncf.kubernetes.utils.pod_manager import (
@@ -43,12 +45,12 @@ def __init__(
         container_name: str,
         pod_name: str,
         pod_namespace: str,
-        kubernetes_conn_id: Optional[str] = None,
-        hook_params: Optional[Dict[str, Any]] = None,
+        kubernetes_conn_id: str | None = None,
+        hook_params: dict[str, Any] | None = None,
         pending_phase_timeout: float = 120,
         poll_interval: float = 5,
-        logging_interval: Optional[int] = None,
-        last_log_time: Optional[DateTime] = None,
+        logging_interval: int | None = None,
+        last_log_time: DateTime | None = None,
     ):
         super().__init__()
         self.kubernetes_conn_id = kubernetes_conn_id
@@ -61,7 +63,7 @@ def __init__(
         self.logging_interval = logging_interval
         self.last_log_time = last_log_time
 
-    def serialize(self) -> Tuple[str, Dict[str, Any]]:  # noqa: D102
+    def serialize(self) -> tuple[str, dict[str, Any]]:  # noqa: D102
         return (
             "astronomer.providers.cncf.kubernetes.triggers.wait_container.WaitContainerTrigger",
             {
@@ -94,7 +96,7 @@ async def wait_for_pod_start(self, v1_api: CoreV1Api) -> Any:
             await asyncio.sleep(self.poll_interval)
         raise PodLaunchTimeoutException("Pod did not leave 'Pending' phase within specified timeout")
 
-    async def wait_for_container_completion(self, v1_api: CoreV1Api) -> "TriggerEvent":
+    async def wait_for_container_completion(self, v1_api: CoreV1Api) -> TriggerEvent:
         """
         Waits until container ``self.container_name`` is no longer in running state.
         If trigger is configured with a logging period, then will emit an event to
@@ -114,7 +116,7 @@ async def wait_for_container_completion(self, v1_api: CoreV1Api) -> "TriggerEven
                 return TriggerEvent({"status": "running", "last_log_time": self.last_log_time})
             await asyncio.sleep(self.poll_interval)
 
-    async def run(self) -> AsyncIterator["TriggerEvent"]:  # noqa: D102
+    async def run(self) -> AsyncIterator[TriggerEvent]:  # noqa: D102
         self.log.debug("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
         try:
             hook = await self.get_hook()
diff --git a/tests/cncf/kubernetes/operators/test_kubernetes_pod.py b/tests/cncf/kubernetes/operators/test_kubernetes_pod.py
index ce2f399ca..c470b0560 100644
--- a/tests/cncf/kubernetes/operators/test_kubernetes_pod.py
+++ b/tests/cncf/kubernetes/operators/test_kubernetes_pod.py
@@ -3,7 +3,11 @@
 import pytest
 from airflow.exceptions import TaskDeferred
-from airflow.providers.cncf.kubernetes.utils.pod_manager import PodLoggingStatus
+from airflow.providers.cncf.kubernetes.utils.pod_manager import (
+    PodLoggingStatus,
+    PodPhase,
+)
+from kubernetes.client import models as k8s
 
 from astronomer.providers.cncf.kubernetes.operators.kubernetes_pod import (
     KubernetesPodOperatorAsync,
 )
@@ -17,6 +21,24 @@
 KUBE_POD_MOD = "astronomer.providers.cncf.kubernetes.operators.kubernetes_pod"
 
 
+def _build_mock_pod(state: k8s.V1ContainerState) -> k8s.V1Pod:
+    return k8s.V1Pod(
+        metadata=k8s.V1ObjectMeta(name="base", namespace="default"),
+        status=k8s.V1PodStatus(
+            container_statuses=[
+                k8s.V1ContainerStatus(
+                    name="base",
+                    image="alpine",
+                    image_id="1",
+                    ready=True,
+                    restart_count=1,
+                    state=state,
+                )
+            ]
+        ),
+    )
+
+
 class TestKubernetesPodOperatorAsync:
     def test_raise_for_trigger_status_pending_timeout(self):
         """Assert trigger raise exception in case of timeout"""
@@ -152,12 +174,37 @@ def test_defer_with_kwargs(self):
         with pytest.raises(ValueError):
             op.defer(kwargs={"timeout": 10})
 
+    @pytest.mark.parametrize("pod_phase", PodPhase.terminal_states)
+    @mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.trigger_reentry")
     @mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.build_pod_request_obj")
     @mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.get_or_create_pod")
     @mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.defer")
-    def test_execute(self, mock_defer, mock_get_or_create_pod, mock_build_pod_request_obj):
+    def test_execute_done_before_defer(
+        self, mock_defer, mock_get_or_create_pod, mock_build_pod_request_obj, mock_trigger_reentry, pod_phase
+    ):
+        mock_get_or_create_pod.return_value.status.phase = pod_phase
+        mock_build_pod_request_obj.return_value = {}
+        mock_defer.return_value = {}
+        op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
+        assert op.execute(context=create_context(op))
+        assert mock_trigger_reentry.called
+        assert not mock_defer.called
+
+    @mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.build_pod_request_obj")
+    @mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.get_or_create_pod")
+    @mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.defer")
+    def test_execute(
+        self,
+        mock_defer,
+        mock_get_or_create_pod,
+        mock_build_pod_request_obj,
+    ):
         """Assert that execute succeeded"""
-        mock_get_or_create_pod.return_value = {}
+        mock_get_or_create_pod.return_value = _build_mock_pod(
+            k8s.V1ContainerState(
+                {"running": k8s.V1ContainerStateRunning(), "terminated": None, "waiting": None}
+            )
+        )
         mock_build_pod_request_obj.return_value = {}
         mock_defer.return_value = {}
        op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)