Skip to content

Check whether the task finishes before deferring it for KubernetesPodOperatorAsync #1104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
KubernetesPodOperator,
)
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
PodPhase,
container_is_running,
)
from kubernetes.client import models as k8s
from pendulum import DateTime

Expand Down Expand Up @@ -76,9 +80,20 @@ def defer(self, last_log_time: Optional[DateTime] = None, **kwargs: Any) -> None
method_name=self.trigger_reentry.__name__,
)

def execute(self, context: Context) -> None: # noqa: D102
def execute(self, context: Context) -> Any:  # noqa: D102
    """Create (or reuse) the pod and defer, unless the work is already done.

    Builds the pod request object and gets or creates the pod. If the pod is
    already in a terminal phase, or its base container is no longer running,
    there is nothing to wait for: skip deferral and hand a "done" event
    straight to ``trigger_reentry``. Otherwise defer to the async trigger.

    :param context: the Airflow task-instance context.
    :return: whatever ``trigger_reentry`` returns when the pod already
        finished; otherwise ``defer()`` raises ``TaskDeferred``.
    """
    self.pod_request_obj = self.build_pod_request_obj(context)
    self.pod: k8s.V1Pod = self.get_or_create_pod(self.pod_request_obj, context)
    pod_status = self.pod.status.phase
    # Short-circuit: if the pod reached a terminal state (or the base
    # container stopped) before we could defer, process the result now
    # instead of scheduling a trigger that would fire immediately.
    if pod_status in PodPhase.terminal_states or not container_is_running(
        pod=self.pod, container_name=self.BASE_CONTAINER_NAME
    ):
        event = {
            "status": "done",
            "namespace": self.pod.metadata.namespace,
            "pod_name": self.pod.metadata.name,
        }
        return self.trigger_reentry(context=context, event=event)

    self.defer()

def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any: # type: ignore[override]
Expand Down
18 changes: 10 additions & 8 deletions astronomer/providers/cncf/kubernetes/triggers/wait_container.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import asyncio
import traceback
from datetime import timedelta
from typing import Any, AsyncIterator, Dict, Optional, Tuple
from typing import Any, AsyncIterator

from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
Expand Down Expand Up @@ -43,12 +45,12 @@ def __init__(
container_name: str,
pod_name: str,
pod_namespace: str,
kubernetes_conn_id: Optional[str] = None,
hook_params: Optional[Dict[str, Any]] = None,
kubernetes_conn_id: str | None = None,
hook_params: dict[str, Any] | None = None,
pending_phase_timeout: float = 120,
poll_interval: float = 5,
logging_interval: Optional[int] = None,
last_log_time: Optional[DateTime] = None,
logging_interval: int | None = None,
last_log_time: DateTime | None = None,
):
super().__init__()
self.kubernetes_conn_id = kubernetes_conn_id
Expand All @@ -61,7 +63,7 @@ def __init__(
self.logging_interval = logging_interval
self.last_log_time = last_log_time

def serialize(self) -> Tuple[str, Dict[str, Any]]: # noqa: D102
def serialize(self) -> tuple[str, dict[str, Any]]: # noqa: D102
return (
"astronomer.providers.cncf.kubernetes.triggers.wait_container.WaitContainerTrigger",
{
Expand Down Expand Up @@ -94,7 +96,7 @@ async def wait_for_pod_start(self, v1_api: CoreV1Api) -> Any:
await asyncio.sleep(self.poll_interval)
raise PodLaunchTimeoutException("Pod did not leave 'Pending' phase within specified timeout")

async def wait_for_container_completion(self, v1_api: CoreV1Api) -> "TriggerEvent":
async def wait_for_container_completion(self, v1_api: CoreV1Api) -> TriggerEvent:
"""
Waits until container ``self.container_name`` is no longer in running state.
If trigger is configured with a logging period, then will emit an event to
Expand All @@ -114,7 +116,7 @@ async def wait_for_container_completion(self, v1_api: CoreV1Api) -> "TriggerEven
return TriggerEvent({"status": "running", "last_log_time": self.last_log_time})
await asyncio.sleep(self.poll_interval)

async def run(self) -> AsyncIterator["TriggerEvent"]: # noqa: D102
async def run(self) -> AsyncIterator[TriggerEvent]: # noqa: D102
self.log.debug("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace)
try:
hook = await self.get_hook()
Expand Down
53 changes: 50 additions & 3 deletions tests/cncf/kubernetes/operators/test_kubernetes_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

import pytest
from airflow.exceptions import TaskDeferred
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodLoggingStatus
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
PodLoggingStatus,
PodPhase,
)
from kubernetes.client import models as k8s

from astronomer.providers.cncf.kubernetes.operators.kubernetes_pod import (
KubernetesPodOperatorAsync,
Expand All @@ -17,6 +21,24 @@
KUBE_POD_MOD = "astronomer.providers.cncf.kubernetes.operators.kubernetes_pod"


def _build_mock_pod(state: k8s.V1ContainerState) -> k8s.V1Pod:
    """Return a minimal ``V1Pod`` whose single ``base`` container carries *state*."""
    base_container = k8s.V1ContainerStatus(
        name="base",
        image="alpine",
        image_id="1",
        ready=True,
        restart_count=1,
        state=state,
    )
    metadata = k8s.V1ObjectMeta(name="base", namespace="default")
    status = k8s.V1PodStatus(container_statuses=[base_container])
    return k8s.V1Pod(metadata=metadata, status=status)


class TestKubernetesPodOperatorAsync:
def test_raise_for_trigger_status_pending_timeout(self):
"""Assert trigger raise exception in case of timeout"""
Expand Down Expand Up @@ -152,12 +174,37 @@ def test_defer_with_kwargs(self):
with pytest.raises(ValueError):
op.defer(kwargs={"timeout": 10})

# NOTE: mock.patch decorators map to arguments bottom-up: defer -> mock_defer,
# get_or_create_pod -> mock_get_or_create_pod, build_pod_request_obj ->
# mock_build_pod_request_obj, trigger_reentry -> mock_trigger_reentry;
# pod_phase comes from the parametrize over every terminal pod phase.
@pytest.mark.parametrize("pod_phase", PodPhase.terminal_states)
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.trigger_reentry")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.build_pod_request_obj")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.get_or_create_pod")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.defer")
def test_execute_done_before_defer(
    self, mock_defer, mock_get_or_create_pod, mock_build_pod_request_obj, mock_trigger_reentry, pod_phase
):
    """Assert execute() skips deferral when the pod is already in a terminal phase."""
    # The mocked pod reports a terminal phase, so execute() should take the
    # short-circuit path and process the result immediately.
    mock_get_or_create_pod.return_value.status.phase = pod_phase
    mock_build_pod_request_obj.return_value = {}
    mock_defer.return_value = {}
    op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
    assert op.execute(context=create_context(op))
    # Finished work is handed to trigger_reentry directly; no trigger is scheduled.
    assert mock_trigger_reentry.called
    assert not mock_defer.called

@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.build_pod_request_obj")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.get_or_create_pod")
@mock.patch(f"{KUBE_POD_MOD}.KubernetesPodOperatorAsync.defer")
def test_execute(
self,
mock_defer,
mock_get_or_create_pod,
mock_build_pod_request_obj,
):
"""Assert that execute succeeded"""
mock_get_or_create_pod.return_value = {}
mock_get_or_create_pod.return_value = _build_mock_pod(
k8s.V1ContainerState(
{"running": k8s.V1ContainerStateRunning(), "terminated": None, "waiting": None}
)
)
mock_build_pod_request_obj.return_value = {}
mock_defer.return_value = {}
op = KubernetesPodOperatorAsync(task_id="test_task", name="test-pod", get_logs=True)
Expand Down