Skip to content

Commit a2dc9aa

Browse files
committed
feat(kubernetes): check state before deferring KubernetesPodOperatorAsync
1 parent 5698a89 commit a2dc9aa

File tree

2 files changed

+167
-4
lines changed

2 files changed

+167
-4
lines changed

astronomer/providers/cncf/kubernetes/operators/kubernetes_pod.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
66
KubernetesPodOperator,
77
)
8+
from airflow.providers.cncf.kubernetes.triggers.kubernetes_pod import (
9+
ContainerState,
10+
KubernetesPodTrigger,
11+
)
812
from kubernetes.client import models as k8s
913
from pendulum import DateTime
1014

@@ -50,6 +54,21 @@ def raise_for_trigger_status(event: Dict[str, Any]) -> None:
5054
else:
5155
raise AirflowException(description)
5256

57+
def define_container_state(self) -> ContainerState:
    """Return the current state of the pod's base container.

    Looks up the container named ``self.BASE_CONTAINER_NAME`` in
    ``self.pod.status.container_statuses`` and maps its Kubernetes state
    onto a ``ContainerState`` member.

    :return: ``RUNNING`` or ``WAITING`` when that state is set on the
        container; ``TERMINATED`` when the container terminated with exit
        code 0 and ``FAILED`` for any other exit code; ``UNDEFINED`` when
        the pod reports no container statuses, the base container is not
        among them, or none of its states is populated.
    """
    pod_containers = self.pod.status.container_statuses
    if not pod_containers:
        return ContainerState.UNDEFINED

    # The status list may be populated without (yet) containing the base
    # container; indexing ``[0]`` on the filtered list would raise IndexError.
    container = next((c for c in pod_containers if c.name == self.BASE_CONTAINER_NAME), None)
    if container is None or container.state is None:
        return ContainerState.UNDEFINED

    for state in (ContainerState.RUNNING, ContainerState.WAITING, ContainerState.TERMINATED):
        state_obj = getattr(container.state, state)
        if not state_obj:
            continue
        if state != ContainerState.TERMINATED:
            return state
        # A terminated container only counts as successful on exit code 0.
        return ContainerState.TERMINATED if state_obj.exit_code == 0 else ContainerState.FAILED
    return ContainerState.UNDEFINED
71+
5372
def defer(self, last_log_time: Optional[DateTime] = None, **kwargs: Any) -> None:
5473
"""Defers to ``WaitContainerTrigger`` optionally with last log time."""
5574
if kwargs:
@@ -76,10 +95,30 @@ def defer(self, last_log_time: Optional[DateTime] = None, **kwargs: Any) -> None
7695
method_name=self.trigger_reentry.__name__,
7796
)
7897

79-
def execute(self, context: Context) -> None: # noqa: D102
98+
def execute(self, context: Context) -> Any:  # noqa: D102
    """Create (or find) the pod and defer only if it still needs waiting on.

    When ``KubernetesPodTrigger.should_wait`` reports that the pod must be
    waited for, the task defers and nothing is returned. Otherwise an event
    describing the already-finished pod is built and handed straight to
    ``trigger_reentry``, whose result is returned.
    """
    self.pod_request_obj = self.build_pod_request_obj(context)
    self.pod: k8s.V1Pod = self.get_or_create_pod(self.pod_request_obj, context)

    container_state = self.define_container_state()
    if KubernetesPodTrigger.should_wait(self.pod.status.phase, container_state):
        self.defer()
        return None

    # The pod already finished: report the outcome without a trigger round trip.
    if container_state == ContainerState.TERMINATED:
        status = "success"
        message = "All containers inside pod have started successfully."
    else:
        status = "failed"
        message = self.pod.status.message

    metadata = self.pod.metadata
    event = {
        "name": metadata.name,
        "namespace": metadata.namespace,
        "status": status,
        "message": message,
    }
    return self.trigger_reentry(context=context, event=event)
83122

84123
def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any: # type: ignore[override]
85124
"""Deprecated; replaced by trigger_reentry."""

tests/cncf/kubernetes/operators/test_kubernetes_pod.py

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
import pytest
55
from airflow.exceptions import TaskDeferred
6+
from airflow.providers.cncf.kubernetes.triggers.kubernetes_pod import ContainerState
67
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodLoggingStatus
8+
from kubernetes.client import models as k8s
79

810
from astronomer.providers.cncf.kubernetes.operators.kubernetes_pod import (
911
KubernetesPodOperatorAsync,
@@ -17,6 +19,24 @@
1719
KUBE_POD_MOD = "astronomer.providers.cncf.kubernetes.operators.kubernetes_pod"
1820

1921

22+
def _build_mock_pod(state: k8s.V1ContainerState) -> k8s.V1Pod:
    """Build a pod named ``base`` in ``default`` with one ``base`` container status.

    :param state: container state to attach to the single container status.
    :return: a ``V1Pod`` carrying only metadata and that container status.
    """
    base_container_status = k8s.V1ContainerStatus(
        name="base",
        image="alpine",
        image_id="1",
        ready=True,
        restart_count=1,
        state=state,
    )
    pod_metadata = k8s.V1ObjectMeta(name="base", namespace="default")
    pod_status = k8s.V1PodStatus(container_statuses=[base_container_status])
    return k8s.V1Pod(metadata=pod_metadata, status=pod_status)
38+
39+
2040
class TestKubernetesPodOperatorAsync:
2141
def test_raise_for_trigger_status_pending_timeout(self):
2242
"""Assert trigger raise exception in case of timeout"""
@@ -152,12 +172,79 @@ def test_defer_with_kwargs(self):
152172
with pytest.raises(ValueError):
153173
op.defer(kwargs={"timeout": 10})
154174

175+
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.trigger_reentry")
@mock.patch(
    f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.define_container_state",
    return_value=ContainerState.FAILED,
)
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.build_pod_request_obj")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.get_or_create_pod")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.defer")
def test_execute_failed_before_defer(
    self,
    mock_defer,
    mock_get_or_create_pod,
    mock_build_pod_request_obj,
    mock_define_container_state,
    mock_trigger_reentry,
):
    """Assert that a pod which already failed is reported without deferring"""
    # Build the state with the ``terminated`` keyword: passing a dict
    # positionally would store the whole dict as the ``running`` state.
    mock_get_or_create_pod.return_value = _build_mock_pod(
        k8s.V1ContainerState(terminated=k8s.V1ContainerStateTerminated(exit_code=1))
    )
    mock_build_pod_request_obj.return_value = {}
    mock_defer.return_value = {}
    op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)

    op.execute(context=create_context(op))

    mock_trigger_reentry.assert_called_once()
    _, reentry_kwargs = mock_trigger_reentry.call_args
    assert reentry_kwargs["event"]["status"] == "failed"
    assert not mock_defer.called
201+
202+
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.trigger_reentry")
@mock.patch(
    f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.define_container_state",
    return_value=ContainerState.TERMINATED,
)
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.build_pod_request_obj")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.get_or_create_pod")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.defer")
def test_execute_succeeded_before_defer(
    self,
    mock_defer,
    mock_get_or_create_pod,
    mock_build_pod_request_obj,
    mock_define_container_state,
    mock_trigger_reentry,
):
    """Assert that a pod which already succeeded is reported without deferring"""
    # Build the state with the ``terminated`` keyword: passing a dict
    # positionally would store the whole dict as the ``running`` state.
    mock_get_or_create_pod.return_value = _build_mock_pod(
        k8s.V1ContainerState(terminated=k8s.V1ContainerStateTerminated(exit_code=0))
    )
    mock_build_pod_request_obj.return_value = {}
    mock_defer.return_value = {}
    op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)

    result = op.execute(context=create_context(op))

    # Check the forwarded value itself rather than relying on a mock being truthy.
    assert result is mock_trigger_reentry.return_value
    mock_trigger_reentry.assert_called_once()
    _, reentry_kwargs = mock_trigger_reentry.call_args
    assert reentry_kwargs["event"]["status"] == "success"
    assert not mock_defer.called
227+
228+
@mock.patch(
229+
"astronomer.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperatorAsync.define_container_state",
230+
return_value=ContainerState.RUNNING,
231+
)
155232
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.build_pod_request_obj")
156233
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.get_or_create_pod")
157234
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.defer")
158-
def test_execute(self, mock_defer, mock_get_or_create_pod, mock_build_pod_request_obj):
235+
def test_execute(
236+
self,
237+
mock_defer,
238+
mock_get_or_create_pod,
239+
mock_build_pod_request_obj,
240+
mock_define_container_state,
241+
):
159242
"""Assert that execute succeeded"""
160-
mock_get_or_create_pod.return_value = {}
243+
mock_get_or_create_pod.return_value = _build_mock_pod(
244+
k8s.V1ContainerState(
245+
{"running": k8s.V1ContainerStateRunning(), "terminated": None, "waiting": None}
246+
)
247+
)
161248
mock_build_pod_request_obj.return_value = {}
162249
mock_defer.return_value = {}
163250
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
@@ -169,3 +256,40 @@ def test_execute_complete(self, mock_trigger_reentry):
169256
mock_trigger_reentry.return_value = {}
170257
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
171258
assert op.execute_complete(context=create_context(op), event={}) is None
259+
260+
@pytest.mark.parametrize(
    "container_state, expected_state",
    [
        (
            {"running": k8s.V1ContainerStateRunning(), "terminated": None, "waiting": None},
            ContainerState.RUNNING,
        ),
        (
            {"running": None, "terminated": k8s.V1ContainerStateTerminated(exit_code=0), "waiting": None},
            ContainerState.TERMINATED,
        ),
        # A non-zero exit code must be reported as FAILED rather than TERMINATED.
        (
            {"running": None, "terminated": k8s.V1ContainerStateTerminated(exit_code=1), "waiting": None},
            ContainerState.FAILED,
        ),
        (
            {"running": None, "terminated": None, "waiting": k8s.V1ContainerStateWaiting()},
            ContainerState.WAITING,
        ),
    ],
)
def test_define_container_state_should_execute_successfully(self, container_state, expected_state):
    """Assert that each populated container state maps to the expected ContainerState"""
    op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
    op.pod = _build_mock_pod(k8s.V1ContainerState(**container_state))
    assert expected_state == op.define_container_state()
281+
282+
@pytest.mark.parametrize(
    "pod",
    (
        # Base container present, but none of its states is populated.
        _build_mock_pod(k8s.V1ContainerState(running=None, terminated=None, waiting=None)),
        # Pod that reports no container statuses at all.
        k8s.V1Pod(
            status=k8s.V1PodStatus(container_statuses=[]),
            metadata=k8s.V1ObjectMeta(name="base", namespace="default"),
        ),
    ),
)
def test_define_container_state_with_undefined_state(self, pod):
    """Assert that a pod without a resolvable container state yields UNDEFINED"""
    operator = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
    operator.pod = pod
    resolved_state = operator.define_container_state()
    assert resolved_state == ContainerState.UNDEFINED

0 commit comments

Comments
 (0)