24
24
ApiClient ,
25
25
CoreV1Api ,
26
26
CustomObjectsApi ,
27
- V1Pod ,
28
27
V1Job ,
28
+ V1Pod ,
29
29
)
30
30
from kubernetes_asyncio .client .exceptions import ApiException
31
31
from pydantic import Field , model_validator
@@ -161,12 +161,15 @@ async def _get_job(
161
161
job_name : str ,
162
162
namespace : str ,
163
163
client : "ApiClient" ,
164
- job_manifest : Optional [Dict [str , Any ]] = None
164
+ job_manifest : Optional [Dict [str , Any ]] = None ,
165
165
) -> Union [Dict [str , Any ], "V1Job" , None ]:
166
166
"""
167
167
Get a Kubernetes or Volcano job by name.
168
168
"""
169
- if job_manifest and job_manifest .get ("apiVersion" ) == "batch.volcano.sh/v1alpha1" :
169
+ if (
170
+ job_manifest
171
+ and job_manifest .get ("apiVersion" ) == "batch.volcano.sh/v1alpha1"
172
+ ):
170
173
# For Volcano Job, use CustomObjectsApi
171
174
custom_api = CustomObjectsApi (client )
172
175
try :
@@ -175,7 +178,7 @@ async def _get_job(
175
178
version = "v1alpha1" ,
176
179
namespace = namespace ,
177
180
plural = "jobs" ,
178
- name = job_name
181
+ name = job_name ,
179
182
)
180
183
except ApiException as e :
181
184
if e .status == 404 :
@@ -245,36 +248,29 @@ async def _watch_job(
245
248
246
249
# Get job and pod information
247
250
job = await self ._get_job (
248
- job_name = job_name ,
249
- namespace = configuration .namespace ,
251
+ job_name = job_name ,
252
+ namespace = configuration .namespace ,
250
253
client = client ,
251
- job_manifest = configuration .job_manifest
254
+ job_manifest = configuration .job_manifest ,
252
255
)
253
256
if not job :
254
257
return - 1
255
-
258
+
256
259
pod = await self ._get_job_pod (logger , job_name , configuration , client )
257
260
if not pod :
258
261
return - 1
259
262
260
263
# Volcano Job monitoring
261
264
tasks = [
262
265
self ._monitor_volcano_job_state (
263
- logger ,
264
- job_name ,
265
- configuration .namespace ,
266
- client
266
+ logger , job_name , configuration .namespace , client
267
267
)
268
268
]
269
-
269
+
270
270
if configuration .stream_output :
271
271
tasks .append (
272
272
self ._stream_job_logs (
273
- logger ,
274
- pod .metadata .name ,
275
- job_name ,
276
- configuration ,
277
- client
273
+ logger , pod .metadata .name , job_name , configuration , client
278
274
)
279
275
)
280
276
@@ -283,14 +279,17 @@ async def _watch_job(
283
279
results = await asyncio .gather (* tasks , return_exceptions = True )
284
280
for result in results :
285
281
if isinstance (result , Exception ):
286
- logger .error ("Error while monitoring Volcano job" , exc_info = result )
282
+ logger .error (
283
+ "Error while monitoring Volcano job" , exc_info = result
284
+ )
287
285
return - 1
288
286
except TimeoutError :
289
287
logger .error (f"Volcano job { job_name !r} timed out." )
290
288
return - 1
291
289
292
- return await self ._get_container_exit_code (logger , job_name , configuration , client )
293
-
290
+ return await self ._get_container_exit_code (
291
+ logger , job_name , configuration , client
292
+ )
294
293
295
294
# ------------------------------------------------------------------
296
295
# PID helper override (job is dict for Volcano)
@@ -404,7 +403,6 @@ async def run( # type: ignore[override]
404
403
405
404
return KubernetesWorkerResult (identifier = pid , status_code = status_code )
406
405
407
-
408
406
async def _monitor_volcano_job_state (
409
407
self ,
410
408
logger : logging .Logger ,
@@ -414,7 +412,7 @@ async def _monitor_volcano_job_state(
414
412
) -> None :
415
413
"""
416
414
Monitor the state of a Volcano job until completion.
417
-
415
+
418
416
Args:
419
417
logger: Logger to use for logging
420
418
job_name: Name of the Volcano job
@@ -435,14 +433,17 @@ async def _monitor_volcano_job_state(
435
433
logger .info (f"Volcano job { job_name !r} state: { volcano_state } " )
436
434
437
435
if volcano_state in ["Completed" , "Failed" , "Aborted" ]:
438
- logger .info (f"Volcano job { job_name !r} finished with state: { volcano_state } " )
436
+ logger .info (
437
+ f"Volcano job { job_name !r} finished with state: { volcano_state } "
438
+ )
439
439
return
440
440
441
441
await asyncio .sleep (5 ) # Poll every 5 seconds
442
442
except Exception as e :
443
443
logger .warning (f"Error monitoring Volcano job { job_name !r} : { e } " )
444
444
await asyncio .sleep (5 )
445
445
446
+
446
447
# ---------------------------------------------------------------------------
447
448
# Export for Prefect plugin system -----------------------------------------
448
449
# ---------------------------------------------------------------------------
0 commit comments