Skip to content

Commit 2503200

Browse files
committed
fix(cncf): align how KubernetesPodOperatorAsync handle event and WaitContainerTrigger
1 parent 819ac0c commit 2503200

File tree

2 files changed

+19
-120
lines changed

2 files changed

+19
-120
lines changed

astronomer/providers/cncf/kubernetes/operators/kubernetes_pod.py

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
66
KubernetesPodOperator,
77
)
8-
from airflow.providers.cncf.kubernetes.triggers.kubernetes_pod import (
9-
ContainerState,
10-
KubernetesPodTrigger,
8+
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
9+
PodPhase,
10+
container_is_running,
1111
)
1212
from kubernetes.client import models as k8s
1313
from pendulum import DateTime
@@ -54,21 +54,6 @@ def raise_for_trigger_status(event: Dict[str, Any]) -> None:
5454
else:
5555
raise AirflowException(description)
5656

57-
def define_container_state(self) -> ContainerState:
58-
"""Define the state of container"""
59-
pod_containers = self.pod.status.container_statuses
60-
if not pod_containers:
61-
return ContainerState.UNDEFINED
62-
container = [c for c in pod_containers if c.name == self.BASE_CONTAINER_NAME][0]
63-
for state in (ContainerState.RUNNING, ContainerState.WAITING, ContainerState.TERMINATED):
64-
state_obj = getattr(container.state, state)
65-
if state_obj:
66-
if state != ContainerState.TERMINATED:
67-
return state
68-
else:
69-
return ContainerState.TERMINATED if state_obj.exit_code == 0 else ContainerState.FAILED
70-
return ContainerState.UNDEFINED
71-
7257
def defer(self, last_log_time: Optional[DateTime] = None, **kwargs: Any) -> None:
7358
"""Defers to ``WaitContainerTrigger`` optionally with last log time."""
7459
if kwargs:
@@ -98,27 +83,18 @@ def defer(self, last_log_time: Optional[DateTime] = None, **kwargs: Any) -> None
9883
def execute(self, context: Context) -> Any: # noqa: D102
9984
self.pod_request_obj = self.build_pod_request_obj(context)
10085
self.pod: k8s.V1Pod = self.get_or_create_pod(self.pod_request_obj, context)
101-
container_state = self.define_container_state()
10286
pod_status = self.pod.status.phase
103-
if KubernetesPodTrigger.should_wait(pod_status, container_state):
104-
self.defer()
105-
return
106-
107-
if container_state == ContainerState.TERMINATED:
108-
event = {
109-
"name": self.pod.metadata.name,
110-
"namespace": self.pod.metadata.namespace,
111-
"status": "success",
112-
"message": "All containers inside pod have started successfully.",
113-
}
114-
else:
87+
if pod_status in PodPhase.terminal_states or not container_is_running(
88+
pod=self.pod, container_name=self.BASE_CONTAINER_NAME
89+
):
11590
event = {
116-
"name": self.pod.metadata.name,
91+
"status": "done",
11792
"namespace": self.pod.metadata.namespace,
118-
"status": "failed",
119-
"message": self.pod.status.message,
93+
"pod_name": self.pod.metadata.name,
12094
}
121-
return self.trigger_reentry(context=context, event=event)
95+
return self.trigger_reentry(context=context, event=event)
96+
97+
self.defer()
12298

12399
def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any: # type: ignore[override]
124100
"""Deprecated; replaced by trigger_reentry."""

tests/cncf/kubernetes/operators/test_kubernetes_pod.py

Lines changed: 8 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33

44
import pytest
55
from airflow.exceptions import TaskDeferred
6-
from airflow.providers.cncf.kubernetes.triggers.kubernetes_pod import ContainerState
7-
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodLoggingStatus
6+
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
7+
PodLoggingStatus,
8+
PodPhase,
9+
)
810
from kubernetes.client import models as k8s
911

1012
from astronomer.providers.cncf.kubernetes.operators.kubernetes_pod import (
@@ -172,63 +174,22 @@ def test_defer_with_kwargs(self):
172174
with pytest.raises(ValueError):
173175
op.defer(kwargs={"timeout": 10})
174176

177+
@pytest.mark.parametrize("pod_phase", PodPhase.terminal_states)
175178
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.trigger_reentry")
176-
@mock.patch(
177-
f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.define_container_state",
178-
return_value=ContainerState.FAILED,
179-
)
180179
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.build_pod_request_obj")
181180
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.get_or_create_pod")
182181
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.defer")
183-
def test_execute_failed_before_defer(
184-
self,
185-
mock_defer,
186-
mock_get_or_create_pod,
187-
mock_build_pod_request_obj,
188-
mock_define_container_state,
189-
mock_trigger_reentry,
182+
def test_execute_done_before_defer(
183+
self, mock_defer, mock_get_or_create_pod, mock_build_pod_request_obj, mock_trigger_reentry, pod_phase
190184
):
191-
mock_get_or_create_pod.return_value = _build_mock_pod(
192-
k8s.V1ContainerState({"running": k8s.V1ContainerStateTerminated(exit_code=1), "waiting": None})
193-
)
194-
mock_build_pod_request_obj.return_value = {}
195-
mock_defer.return_value = {}
196-
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
197-
198-
op.execute(context=create_context(op))
199-
assert mock_trigger_reentry.called
200-
assert not mock_defer.called
201-
202-
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.trigger_reentry")
203-
@mock.patch(
204-
f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.define_container_state",
205-
return_value=ContainerState.TERMINATED,
206-
)
207-
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.build_pod_request_obj")
208-
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.get_or_create_pod")
209-
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.defer")
210-
def test_execute_succeeded_before_defer(
211-
self,
212-
mock_defer,
213-
mock_get_or_create_pod,
214-
mock_build_pod_request_obj,
215-
mock_define_container_state,
216-
mock_trigger_reentry,
217-
):
218-
mock_get_or_create_pod.return_value = _build_mock_pod(
219-
k8s.V1ContainerState({"running": k8s.V1ContainerStateTerminated(exit_code=0), "waiting": None})
220-
)
185+
mock_get_or_create_pod.return_value.status.phase = pod_phase
221186
mock_build_pod_request_obj.return_value = {}
222187
mock_defer.return_value = {}
223188
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
224189
assert op.execute(context=create_context(op))
225190
assert mock_trigger_reentry.called
226191
assert not mock_defer.called
227192

228-
@mock.patch(
229-
"astronomer.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperatorAsync.define_container_state",
230-
return_value=ContainerState.RUNNING,
231-
)
232193
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.build_pod_request_obj")
233194
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.get_or_create_pod")
234195
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.defer")
@@ -237,7 +198,6 @@ def test_execute(
237198
mock_defer,
238199
mock_get_or_create_pod,
239200
mock_build_pod_request_obj,
240-
mock_define_container_state,
241201
):
242202
"""Assert that execute succeeded"""
243203
mock_get_or_create_pod.return_value = _build_mock_pod(
@@ -256,40 +216,3 @@ def test_execute_complete(self, mock_trigger_reentry):
256216
mock_trigger_reentry.return_value = {}
257217
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
258218
assert op.execute_complete(context=create_context(op), event={}) is None
259-
260-
@pytest.mark.parametrize(
261-
"container_state, expected_state",
262-
[
263-
(
264-
{"running": k8s.V1ContainerStateRunning(), "terminated": None, "waiting": None},
265-
ContainerState.RUNNING,
266-
),
267-
(
268-
{"running": None, "terminated": k8s.V1ContainerStateTerminated(exit_code=0), "waiting": None},
269-
ContainerState.TERMINATED,
270-
),
271-
(
272-
{"running": None, "terminated": None, "waiting": k8s.V1ContainerStateWaiting()},
273-
ContainerState.WAITING,
274-
),
275-
],
276-
)
277-
def test_define_container_state_should_execute_successfully(self, container_state, expected_state):
278-
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
279-
op.pod = _build_mock_pod(k8s.V1ContainerState(**container_state))
280-
assert expected_state == op.define_container_state()
281-
282-
@pytest.mark.parametrize(
283-
"pod",
284-
(
285-
_build_mock_pod(k8s.V1ContainerState(running=None, terminated=None, waiting=None)),
286-
k8s.V1Pod(
287-
metadata=k8s.V1ObjectMeta(name="base", namespace="default"),
288-
status=k8s.V1PodStatus(container_statuses=[]),
289-
),
290-
),
291-
)
292-
def test_define_container_state_with_undefined_state(self, pod):
293-
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
294-
op.pod = pod
295-
assert op.define_container_state() == ContainerState.UNDEFINED

0 commit comments

Comments
 (0)