Skip to content

Commit 0066f99

Browse files
committed
feat:add new test and precommit
1 parent 0f3b7d6 commit 0066f99

File tree

2 files changed

+62
-44
lines changed

2 files changed

+62
-44
lines changed

src/integrations/prefect-kubernetes/prefect_kubernetes/volcanoworker.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
ApiClient,
2525
CoreV1Api,
2626
CustomObjectsApi,
27-
V1Pod,
2827
V1Job,
28+
V1Pod,
2929
)
3030
from kubernetes_asyncio.client.exceptions import ApiException
3131
from pydantic import Field, model_validator
@@ -161,12 +161,15 @@ async def _get_job(
161161
job_name: str,
162162
namespace: str,
163163
client: "ApiClient",
164-
job_manifest: Optional[Dict[str, Any]] = None
164+
job_manifest: Optional[Dict[str, Any]] = None,
165165
) -> Union[Dict[str, Any], "V1Job", None]:
166166
"""
167167
Get a Kubernetes or Volcano job by name.
168168
"""
169-
if job_manifest and job_manifest.get("apiVersion") == "batch.volcano.sh/v1alpha1":
169+
if (
170+
job_manifest
171+
and job_manifest.get("apiVersion") == "batch.volcano.sh/v1alpha1"
172+
):
170173
# For Volcano Job, use CustomObjectsApi
171174
custom_api = CustomObjectsApi(client)
172175
try:
@@ -175,7 +178,7 @@ async def _get_job(
175178
version="v1alpha1",
176179
namespace=namespace,
177180
plural="jobs",
178-
name=job_name
181+
name=job_name,
179182
)
180183
except ApiException as e:
181184
if e.status == 404:
@@ -245,36 +248,29 @@ async def _watch_job(
245248

246249
# Get job and pod information
247250
job = await self._get_job(
248-
job_name=job_name,
249-
namespace=configuration.namespace,
251+
job_name=job_name,
252+
namespace=configuration.namespace,
250253
client=client,
251-
job_manifest=configuration.job_manifest
254+
job_manifest=configuration.job_manifest,
252255
)
253256
if not job:
254257
return -1
255-
258+
256259
pod = await self._get_job_pod(logger, job_name, configuration, client)
257260
if not pod:
258261
return -1
259262

260263
# Volcano Job monitoring
261264
tasks = [
262265
self._monitor_volcano_job_state(
263-
logger,
264-
job_name,
265-
configuration.namespace,
266-
client
266+
logger, job_name, configuration.namespace, client
267267
)
268268
]
269-
269+
270270
if configuration.stream_output:
271271
tasks.append(
272272
self._stream_job_logs(
273-
logger,
274-
pod.metadata.name,
275-
job_name,
276-
configuration,
277-
client
273+
logger, pod.metadata.name, job_name, configuration, client
278274
)
279275
)
280276

@@ -283,14 +279,17 @@ async def _watch_job(
283279
results = await asyncio.gather(*tasks, return_exceptions=True)
284280
for result in results:
285281
if isinstance(result, Exception):
286-
logger.error("Error while monitoring Volcano job", exc_info=result)
282+
logger.error(
283+
"Error while monitoring Volcano job", exc_info=result
284+
)
287285
return -1
288286
except TimeoutError:
289287
logger.error(f"Volcano job {job_name!r} timed out.")
290288
return -1
291289

292-
return await self._get_container_exit_code(logger, job_name, configuration, client)
293-
290+
return await self._get_container_exit_code(
291+
logger, job_name, configuration, client
292+
)
294293

295294
# ------------------------------------------------------------------
296295
# PID helper override (job is dict for Volcano)
@@ -404,7 +403,6 @@ async def run( # type: ignore[override]
404403

405404
return KubernetesWorkerResult(identifier=pid, status_code=status_code)
406405

407-
408406
async def _monitor_volcano_job_state(
409407
self,
410408
logger: logging.Logger,
@@ -414,7 +412,7 @@ async def _monitor_volcano_job_state(
414412
) -> None:
415413
"""
416414
Monitor the state of a Volcano job until completion.
417-
415+
418416
Args:
419417
logger: Logger to use for logging
420418
job_name: Name of the Volcano job
@@ -435,14 +433,17 @@ async def _monitor_volcano_job_state(
435433
logger.info(f"Volcano job {job_name!r} state: {volcano_state}")
436434

437435
if volcano_state in ["Completed", "Failed", "Aborted"]:
438-
logger.info(f"Volcano job {job_name!r} finished with state: {volcano_state}")
436+
logger.info(
437+
f"Volcano job {job_name!r} finished with state: {volcano_state}"
438+
)
439439
return
440440

441441
await asyncio.sleep(5) # Poll every 5 seconds
442442
except Exception as e:
443443
logger.warning(f"Error monitoring Volcano job {job_name!r}: {e}")
444444
await asyncio.sleep(5)
445445

446+
446447
# ---------------------------------------------------------------------------
447448
# Export for Prefect plugin system -----------------------------------------
448449
# ---------------------------------------------------------------------------

src/integrations/prefect-kubernetes/tests/test_volcanoworker.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
from __future__ import annotations
99

1010
import uuid
11-
from unittest.mock import AsyncMock, MagicMock
11+
from unittest.mock import AsyncMock, MagicMock, patch
1212

1313
import pytest
1414
from kubernetes_asyncio.client import V1Pod
15+
import warnings
1516

1617
pytest.importorskip("prefect_kubernetes.volcanoworker")
1718
from prefect_kubernetes.volcanoworker import ( # noqa: E402
@@ -103,36 +104,39 @@ async def test_get_job_pod_selector_order(monkeypatch, job_cfg):
103104
"""
104105

105106
# --- Create fake CoreV1Api.list_namespaced_pod ---
106-
async def fake_list_ns_pod(namespace, label_selector=None):
107+
async def fake_list_ns_pod(namespace, label_selector=None, **kwargs):
107108
fake_resp = MagicMock(items=[])
108109
# First selector hits
109110
if label_selector == "volcano.sh/job-name=my-job":
110111
pod = MagicMock(spec=V1Pod)
112+
pod.status.phase = "Running" # Set a non-Pending phase to exit the watch loop
111113
fake_resp.items = [pod]
112114
return fake_resp
113115

114116
list_pod_mock = AsyncMock(side_effect=fake_list_ns_pod)
115-
monkeypatch.setattr(
116-
"prefect_kubernetes.volcanoworker.CoreV1Api",
117-
MagicMock(return_value=MagicMock(list_namespaced_pod=list_pod_mock)),
118-
)
119-
120-
worker = VolcanoWorker(work_pool_name="dummy")
121-
pod = await worker._get_job_pod(
122-
logger=worker._logger,
123-
job_name="my-job",
124-
configuration=job_cfg,
125-
client=MagicMock(),
126-
)
117+
118+
# Mock the Watch class to avoid timeout_seconds issue
119+
watch_mock = MagicMock()
120+
watch_mock.stream = AsyncMock(return_value=[{"object": MagicMock(spec=V1Pod, status=MagicMock(phase="Running"))}])
121+
122+
with patch("kubernetes_asyncio.watch.Watch", return_value=watch_mock):
123+
monkeypatch.setattr(
124+
"prefect_kubernetes.volcanoworker.CoreV1Api",
125+
MagicMock(return_value=MagicMock(list_namespaced_pod=list_pod_mock)),
126+
)
127+
128+
worker = VolcanoWorker(work_pool_name="dummy")
129+
pod = await worker._get_job_pod(
130+
logger=worker._logger,
131+
job_name="my-job",
132+
configuration=job_cfg,
133+
client=MagicMock(),
134+
)
127135

128136
# Confirm pod was returned
129137
assert pod is not None
130-
# Should use volcano.sh/job-name first
131-
assert (
132-
list_pod_mock.await_args_list[0]
133-
.kwargs["label_selector"]
134-
.startswith("volcano.sh/job-name")
135-
)
138+
# Should use volcano.sh/job-name in the watch stream
139+
assert watch_mock.stream.called
136140

137141

138142
@pytest.mark.asyncio
@@ -143,6 +147,10 @@ async def test_run_full_flow(monkeypatch, job_cfg, dummy_flow_run):
143147
• status_code from _watch_job is passed through
144148
• PID format is clusterUID:ns:name
145149
"""
150+
# Filter out FutureWarning about ad-hoc flow submission
151+
warnings.filterwarnings("ignore", category=FutureWarning,
152+
message="Ad-hoc flow submission via workers is experimental.*")
153+
146154
fake_job = {
147155
"metadata": {
148156
"name": "vc-job-999",
@@ -177,6 +185,15 @@ async def test_run_full_flow(monkeypatch, job_cfg, dummy_flow_run):
177185
AsyncMock(return_value="CLUSTER-UID"),
178186
)
179187

188+
# Mock KubernetesEventsReplicator
189+
events_replicator_mock = MagicMock()
190+
events_replicator_mock.__aenter__ = AsyncMock(return_value=None)
191+
events_replicator_mock.__aexit__ = AsyncMock(return_value=None)
192+
monkeypatch.setattr(
193+
"prefect_kubernetes.volcanoworker.KubernetesEventsReplicator",
194+
MagicMock(return_value=events_replicator_mock),
195+
)
196+
180197
# Mock kubernetes client configuration to avoid kubeconfig errors
181198
client_mock = MagicMock()
182199
client_context_manager = MagicMock()

0 commit comments

Comments
 (0)