5
5
import threading
6
6
from abc import ABC , abstractmethod
7
7
from functools import lru_cache
8
+ from pathlib import Path
8
9
from typing import Dict , List , Optional
9
10
10
11
import git
36
37
)
37
38
from dstack ._internal .core .services import is_valid_dstack_resource_name
38
39
from dstack ._internal .utils .logging import get_logger
40
+ from dstack ._internal .utils .path import PathLike
39
41
40
42
logger = get_logger (__name__ )
41
43
42
- DSTACK_WORKING_DIR = "/root/.dstack"
43
44
DSTACK_SHIM_BINARY_NAME = "dstack-shim"
44
- DSTACK_SHIM_BINARY_PATH = f"/usr/local/bin/{ DSTACK_SHIM_BINARY_NAME } "
45
45
DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
46
- DSTACK_RUNNER_BINARY_PATH = f"/usr/local/bin/{ DSTACK_RUNNER_BINARY_NAME } "
47
46
48
47
49
48
class Compute (ABC ):
@@ -336,6 +335,24 @@ def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
336
335
return True
337
336
338
337
338
def get_dstack_working_dir(base_path: Optional[PathLike] = None) -> str:
    """Return the path of the ``.dstack`` directory located under ``base_path``.

    Args:
        base_path: Directory that contains the ``.dstack`` directory.
            ``None`` selects ``/root``.

    Returns:
        ``<base_path>/.dstack`` rendered as a string with forward slashes.
    """
    # This function only renders a string and never touches the local
    # filesystem, so a pure POSIX path is used: the separator is always "/",
    # independent of the OS this code runs on.
    from pathlib import PurePosixPath

    if base_path is None:
        base_path = "/root"
    return str(PurePosixPath(base_path, ".dstack"))
342
+
343
+
344
def get_dstack_shim_binary_path(bin_path: Optional[PathLike] = None) -> str:
    """Return the path of the shim binary inside ``bin_path``.

    Args:
        bin_path: Directory that holds the binary. ``None`` selects
            ``/usr/local/bin``.

    Returns:
        ``<bin_path>/<DSTACK_SHIM_BINARY_NAME>`` rendered as a string with
        forward slashes.
    """
    # Only a string is produced here; a pure POSIX path keeps the separator
    # "/" independent of the OS this code runs on.
    from pathlib import PurePosixPath

    if bin_path is None:
        bin_path = "/usr/local/bin"
    return str(PurePosixPath(bin_path, DSTACK_SHIM_BINARY_NAME))
348
+
349
+
350
def get_dstack_runner_binary_path(bin_path: Optional[PathLike] = None) -> str:
    """Return the path of the runner binary inside ``bin_path``.

    Args:
        bin_path: Directory that holds the binary. ``None`` selects
            ``/usr/local/bin``.

    Returns:
        ``<bin_path>/<DSTACK_RUNNER_BINARY_NAME>`` rendered as a string with
        forward slashes.
    """
    # Only a string is produced here; a pure POSIX path keeps the separator
    # "/" independent of the OS this code runs on.
    from pathlib import PurePosixPath

    if bin_path is None:
        bin_path = "/usr/local/bin"
    return str(PurePosixPath(bin_path, DSTACK_RUNNER_BINARY_NAME))
354
+
355
+
339
356
def get_job_instance_name(run: Run, job: Job) -> str:
    """Return the instance name for ``job``, i.e. the job name from its spec.

    ``run`` is accepted but not read by this function.
    """
    spec = job.job_spec
    return spec.job_name
341
358
@@ -442,39 +459,74 @@ def get_cloud_config(**config) -> str:
442
459
443
460
444
461
def get_user_data(
    authorized_keys: List[str],
    backend_specific_commands: Optional[List[str]] = None,
    base_path: Optional[PathLike] = None,
    bin_path: Optional[PathLike] = None,
    backend_shim_env: Optional[Dict[str, str]] = None,
) -> str:
    """Build the cloud-config user data that runs the shim start-up commands.

    Args:
        authorized_keys: SSH public keys; passed to the shim commands and to
            the cloud-config ``ssh_authorized_keys`` entry.
        backend_specific_commands: Commands placed before the shim commands.
            ``None`` means no extra commands.
        base_path: Forwarded to ``get_shim_commands``.
        bin_path: Forwarded to ``get_shim_commands``.
        backend_shim_env: Forwarded to ``get_shim_commands``.

    Returns:
        The string produced by ``get_cloud_config``.
    """
    shim_commands = get_shim_commands(
        authorized_keys=authorized_keys,
        base_path=base_path,
        bin_path=bin_path,
        backend_shim_env=backend_shim_env,
    )
    leading_commands = backend_specific_commands or []
    # All commands run as one `sh -c` script chained with `&&`.
    script = " && ".join(leading_commands + shim_commands)
    return get_cloud_config(
        runcmd=[["sh", "-c", script]],
        ssh_authorized_keys=authorized_keys,
    )
453
479
454
480
455
def get_shim_env(
    authorized_keys: List[str],
    base_path: Optional[PathLike] = None,
    bin_path: Optional[PathLike] = None,
    backend_shim_env: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Assemble the ``DSTACK_*`` environment variables for the shim.

    Args:
        authorized_keys: SSH public keys, joined with newlines into
            ``DSTACK_PUBLIC_SSH_KEY``.
        base_path: Passed to ``get_dstack_working_dir`` for
            ``DSTACK_SHIM_HOME``.
        bin_path: Passed to ``get_dstack_runner_binary_path`` for
            ``DSTACK_RUNNER_BINARY_PATH``.
        backend_shim_env: Extra variables merged in last, so an entry with
            the same key replaces the value built here.

    Returns:
        A new dict mapping variable names to string values.
    """
    trace_level = "6"  # Trace
    shim_env = {
        "DSTACK_SHIM_HOME": get_dstack_working_dir(base_path),
        "DSTACK_SHIM_HTTP_PORT": str(DSTACK_SHIM_HTTP_PORT),
        "DSTACK_SHIM_LOG_LEVEL": trace_level,
        "DSTACK_RUNNER_DOWNLOAD_URL": get_dstack_runner_download_url(),
        "DSTACK_RUNNER_BINARY_PATH": get_dstack_runner_binary_path(bin_path),
        "DSTACK_RUNNER_HTTP_PORT": str(DSTACK_RUNNER_HTTP_PORT),
        "DSTACK_RUNNER_SSH_PORT": str(DSTACK_RUNNER_SSH_PORT),
        "DSTACK_RUNNER_LOG_LEVEL": trace_level,
        "DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys),
    }
    if backend_shim_env is not None:
        shim_env.update(backend_shim_env)
    return shim_env
469
502
470
503
471
504
def get_shim_commands(
    authorized_keys: List[str],
    *,
    is_privileged: bool = False,
    pjrt_device: Optional[str] = None,
    base_path: Optional[PathLike] = None,
    bin_path: Optional[PathLike] = None,
    backend_shim_env: Optional[Dict[str, str]] = None,
) -> List[str]:
    """Return the shell commands that install, configure and start the shim.

    The list is, in order: the pre-start commands, one ``export`` per
    variable from ``get_shim_env``, and the commands from
    ``get_run_shim_script``.

    Args:
        authorized_keys: Forwarded to ``get_shim_env``.
        is_privileged: Forwarded to ``get_run_shim_script``.
        pjrt_device: Forwarded to ``get_run_shim_script``.
        base_path: Forwarded to the pre-start commands and ``get_shim_env``.
        bin_path: Forwarded to the pre-start commands, ``get_shim_env`` and
            ``get_run_shim_script``.
        backend_shim_env: Forwarded to ``get_shim_env``.

    Returns:
        The list of shell command strings.
    """
    script = get_shim_pre_start_commands(
        base_path=base_path,
        bin_path=bin_path,
    )
    env_vars = get_shim_env(
        authorized_keys=authorized_keys,
        base_path=base_path,
        bin_path=bin_path,
        backend_shim_env=backend_shim_env,
    )
    script.extend(f'export "{name}={value}"' for name, value in env_vars.items())
    script.extend(
        get_run_shim_script(
            is_privileged=is_privileged,
            pjrt_device=pjrt_device,
            bin_path=bin_path,
        )
    )
    return script
479
531
480
532
@@ -511,25 +563,33 @@ def get_dstack_shim_download_url() -> str:
511
563
return f"https://{ bucket } .s3.eu-west-1.amazonaws.com/{ build } /binaries/dstack-shim-linux-amd64"
512
564
513
565
514
def get_shim_pre_start_commands(
    base_path: Optional[PathLike] = None,
    bin_path: Optional[PathLike] = None,
) -> List[str]:
    """Return the shell commands that download the shim and prepare its directory.

    The commands download the shim binary to a temporary file, move it to its
    final location, make it executable and create the working directory.

    Args:
        base_path: Passed to ``get_dstack_working_dir``.
        bin_path: Passed to ``get_dstack_shim_binary_path``.

    Returns:
        The list of shell command strings, in execution order.
    """
    from shlex import quote

    url = get_dstack_shim_download_url()
    # Both paths can be supplied by the caller, so they are shell-quoted.
    # `quote` returns plain paths such as the defaults unchanged.
    shim_binary = quote(get_dstack_shim_binary_path(bin_path))
    working_dir = quote(get_dstack_working_dir(base_path))
    return [
        f"dlpath=$(sudo mktemp -t {DSTACK_SHIM_BINARY_NAME}.XXXXXXXXXX)",
        # -sS -- disable progress meter and warnings, but still show errors (unlike bare -s)
        f'sudo curl -sS --compressed --connect-timeout 60 --max-time 240 --retry 1 --output "$dlpath" "{url}"',
        f'sudo mv "$dlpath" {shim_binary}',
        f"sudo chmod +x {shim_binary}",
        f"sudo mkdir {working_dir} -p",
    ]
525
581
526
582
527
def get_run_shim_script(
    is_privileged: bool,
    pjrt_device: Optional[str],
    bin_path: Optional[PathLike] = None,
) -> List[str]:
    """Return the shell command that starts the shim in the background.

    Args:
        is_privileged: When true, ``--privileged`` is added to the command.
        pjrt_device: When truthy, ``--pjrt-device=<value>`` is added.
        bin_path: Passed to ``get_dstack_shim_binary_path``.

    Returns:
        A one-element list holding the ``nohup ... &`` command.
    """
    from shlex import quote

    # The binary path can be supplied by the caller, so it is shell-quoted.
    # `quote` returns plain paths such as the default unchanged.
    shim_binary = quote(get_dstack_shim_binary_path(bin_path))
    privileged_flag = "--privileged" if is_privileged else ""
    pjrt_device_env = f"--pjrt-device={pjrt_device}" if pjrt_device else ""
    return [
        f"nohup {shim_binary} {privileged_flag} {pjrt_device_env} &",
    ]
534
594
535
595
@@ -555,7 +615,11 @@ def get_gateway_user_data(authorized_key: str) -> str:
555
615
)
556
616
557
617
558
- def get_docker_commands (authorized_keys : list [str ]) -> list [str ]:
618
+ def get_docker_commands (
619
+ authorized_keys : list [str ],
620
+ bin_path : Optional [PathLike ] = None ,
621
+ ) -> list [str ]:
622
+ dstack_runner_binary_path = get_dstack_runner_binary_path (bin_path )
559
623
authorized_keys_content = "\n " .join (authorized_keys ).strip ()
560
624
commands = [
561
625
# save and unset ld.so variables
@@ -606,10 +670,10 @@ def get_docker_commands(authorized_keys: list[str]) -> list[str]:
606
670
607
671
url = get_dstack_runner_download_url ()
608
672
commands += [
609
- f"curl --connect-timeout 60 --max-time 240 --retry 1 --output { DSTACK_RUNNER_BINARY_PATH } { url } " ,
610
- f"chmod +x { DSTACK_RUNNER_BINARY_PATH } " ,
673
+ f"curl --connect-timeout 60 --max-time 240 --retry 1 --output { dstack_runner_binary_path } { url } " ,
674
+ f"chmod +x { dstack_runner_binary_path } " ,
611
675
(
612
- f"{ DSTACK_RUNNER_BINARY_PATH } "
676
+ f"{ dstack_runner_binary_path } "
613
677
" --log-level 6"
614
678
" start"
615
679
f" --http-port { DSTACK_RUNNER_HTTP_PORT } "
0 commit comments