Skip to content

Commit 16ddda8

Browse files
authored
Support A3 High/Edge GCP clusters with GPUDirect-TCPX (#2549)
* Support gpu_devices in task config
* Implement tcpx prototype
* Parametrize shim and runner host paths
1 parent a2cc68a commit 16ddda8

File tree

12 files changed

+291
-55
lines changed

12 files changed

+291
-55
lines changed

runner/docs/shim.openapi.yaml

+18
Original file line number | Diff line number | Diff line change
@@ -270,6 +270,19 @@ components:
270270
type: string
271271
default: ""
272272
description: Mount point inside container
273+
274+
GPUDevice:
275+
title: shim.GPUDevice
276+
type: object
277+
properties:
278+
path_on_host:
279+
type: string
280+
default: ""
281+
description: Instance (host) path
282+
path_in_container:
283+
type: string
284+
default: ""
285+
description: Path inside container
273286

274287
HealthcheckResponse:
275288
title: shim.api.HealthcheckResponse
@@ -438,6 +451,11 @@ components:
438451
items:
439452
$ref: "#/components/schemas/InstanceMountPoint"
440453
default: []
454+
gpu_devices:
455+
type: array
456+
items:
457+
$ref: "#/components/schemas/GPUDevice"
458+
default: []
441459
host_ssh_user:
442460
type: string
443461
default: ""

runner/internal/shim/docker.go

+18-1
Original file line number | Diff line number | Diff line change
@@ -814,7 +814,11 @@ func (d *DockerRunner) createContainer(ctx context.Context, task *Task) error {
814814
hostConfig.Resources.NanoCPUs = int64(task.config.CPU * 1000000000)
815815
hostConfig.Resources.Memory = task.config.Memory
816816
if len(task.gpuIDs) > 0 {
817-
configureGpus(containerConfig, hostConfig, d.gpuVendor, task.gpuIDs)
817+
if len(task.config.GPUDevices) > 0 {
818+
configureGpuDevices(hostConfig, task.config.GPUDevices)
819+
} else {
820+
configureGpus(containerConfig, hostConfig, d.gpuVendor, task.gpuIDs)
821+
}
818822
}
819823
configureHpcNetworkingIfAvailable(hostConfig)
820824

@@ -988,6 +992,19 @@ func getNetworkMode(networkMode NetworkMode) container.NetworkMode {
988992
return "default"
989993
}
990994

995+
func configureGpuDevices(hostConfig *container.HostConfig, gpuDevices []GPUDevice) {
996+
for _, gpuDevice := range gpuDevices {
997+
hostConfig.Resources.Devices = append(
998+
hostConfig.Resources.Devices,
999+
container.DeviceMapping{
1000+
PathOnHost: gpuDevice.PathOnHost,
1001+
PathInContainer: gpuDevice.PathInContainer,
1002+
CgroupPermissions: "rwm",
1003+
},
1004+
)
1005+
}
1006+
}
1007+
9911008
func configureGpus(config *container.Config, hostConfig *container.HostConfig, vendor host.GpuVendor, ids []string) {
9921009
// NVIDIA: ids are identifiers reported by nvidia-smi, GPU-<UUID> strings
9931010
// AMD: ids are DRI render node paths, e.g., /dev/dri/renderD128

runner/internal/shim/models.go

+10-2
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,11 @@ type PortMapping struct {
7070
Container int `json:"container"`
7171
}
7272

73+
type GPUDevice struct {
74+
PathOnHost string `json:"path_on_host"`
75+
PathInContainer string `json:"path_in_container"`
76+
}
77+
7378
type TaskConfig struct {
7479
ID string `json:"id"`
7580
Name string `json:"name"`
@@ -86,8 +91,11 @@ type TaskConfig struct {
8691
Volumes []VolumeInfo `json:"volumes"`
8792
VolumeMounts []VolumeMountPoint `json:"volume_mounts"`
8893
InstanceMounts []InstanceMountPoint `json:"instance_mounts"`
89-
HostSshUser string `json:"host_ssh_user"`
90-
HostSshKeys []string `json:"host_ssh_keys"`
94+
// GPUDevices allows the server to set gpu devices instead of relying on the runner default logic.
95+
// E.g. passing nvidia devices directly instead of using nvidia-container-toolkit.
96+
GPUDevices []GPUDevice `json:"gpu_devices"`
97+
HostSshUser string `json:"host_ssh_user"`
98+
HostSshKeys []string `json:"host_ssh_keys"`
9199
// TODO: submit keys to runner, not to shim
92100
ContainerSshKeys []string `json:"container_ssh_keys"`
93101
}

src/dstack/_internal/core/backends/base/compute.py

+88-24
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import threading
66
from abc import ABC, abstractmethod
77
from functools import lru_cache
8+
from pathlib import Path
89
from typing import Dict, List, Optional
910

1011
import git
@@ -36,14 +37,12 @@
3637
)
3738
from dstack._internal.core.services import is_valid_dstack_resource_name
3839
from dstack._internal.utils.logging import get_logger
40+
from dstack._internal.utils.path import PathLike
3941

4042
logger = get_logger(__name__)
4143

42-
DSTACK_WORKING_DIR = "/root/.dstack"
4344
DSTACK_SHIM_BINARY_NAME = "dstack-shim"
44-
DSTACK_SHIM_BINARY_PATH = f"/usr/local/bin/{DSTACK_SHIM_BINARY_NAME}"
4545
DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
46-
DSTACK_RUNNER_BINARY_PATH = f"/usr/local/bin/{DSTACK_RUNNER_BINARY_NAME}"
4746

4847

4948
class Compute(ABC):
@@ -336,6 +335,24 @@ def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
336335
return True
337336

338337

338+
def get_dstack_working_dir(base_path: Optional[PathLike] = None) -> str:
339+
if base_path is None:
340+
base_path = "/root"
341+
return str(Path(base_path, ".dstack"))
342+
343+
344+
def get_dstack_shim_binary_path(bin_path: Optional[PathLike] = None) -> str:
345+
if bin_path is None:
346+
bin_path = "/usr/local/bin"
347+
return str(Path(bin_path, DSTACK_SHIM_BINARY_NAME))
348+
349+
350+
def get_dstack_runner_binary_path(bin_path: Optional[PathLike] = None) -> str:
351+
if bin_path is None:
352+
bin_path = "/usr/local/bin"
353+
return str(Path(bin_path, DSTACK_RUNNER_BINARY_NAME))
354+
355+
339356
def get_job_instance_name(run: Run, job: Job) -> str:
340357
return job.job_spec.job_name
341358

@@ -442,39 +459,74 @@ def get_cloud_config(**config) -> str:
442459

443460

444461
def get_user_data(
445-
authorized_keys: List[str], backend_specific_commands: Optional[List[str]] = None
462+
authorized_keys: List[str],
463+
backend_specific_commands: Optional[List[str]] = None,
464+
base_path: Optional[PathLike] = None,
465+
bin_path: Optional[PathLike] = None,
466+
backend_shim_env: Optional[Dict[str, str]] = None,
446467
) -> str:
447-
shim_commands = get_shim_commands(authorized_keys)
468+
shim_commands = get_shim_commands(
469+
authorized_keys=authorized_keys,
470+
base_path=base_path,
471+
bin_path=bin_path,
472+
backend_shim_env=backend_shim_env,
473+
)
448474
commands = (backend_specific_commands or []) + shim_commands
449475
return get_cloud_config(
450476
runcmd=[["sh", "-c", " && ".join(commands)]],
451477
ssh_authorized_keys=authorized_keys,
452478
)
453479

454480

455-
def get_shim_env(authorized_keys: List[str]) -> Dict[str, str]:
481+
def get_shim_env(
482+
authorized_keys: List[str],
483+
base_path: Optional[PathLike] = None,
484+
bin_path: Optional[PathLike] = None,
485+
backend_shim_env: Optional[Dict[str, str]] = None,
486+
) -> Dict[str, str]:
456487
log_level = "6" # Trace
457488
envs = {
458-
"DSTACK_SHIM_HOME": DSTACK_WORKING_DIR,
489+
"DSTACK_SHIM_HOME": get_dstack_working_dir(base_path),
459490
"DSTACK_SHIM_HTTP_PORT": str(DSTACK_SHIM_HTTP_PORT),
460491
"DSTACK_SHIM_LOG_LEVEL": log_level,
461492
"DSTACK_RUNNER_DOWNLOAD_URL": get_dstack_runner_download_url(),
462-
"DSTACK_RUNNER_BINARY_PATH": DSTACK_RUNNER_BINARY_PATH,
493+
"DSTACK_RUNNER_BINARY_PATH": get_dstack_runner_binary_path(bin_path),
463494
"DSTACK_RUNNER_HTTP_PORT": str(DSTACK_RUNNER_HTTP_PORT),
464495
"DSTACK_RUNNER_SSH_PORT": str(DSTACK_RUNNER_SSH_PORT),
465496
"DSTACK_RUNNER_LOG_LEVEL": log_level,
466497
"DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys),
467498
}
499+
if backend_shim_env is not None:
500+
envs |= backend_shim_env
468501
return envs
469502

470503

471504
def get_shim_commands(
472-
authorized_keys: List[str], *, is_privileged: bool = False, pjrt_device: Optional[str] = None
505+
authorized_keys: List[str],
506+
*,
507+
is_privileged: bool = False,
508+
pjrt_device: Optional[str] = None,
509+
base_path: Optional[PathLike] = None,
510+
bin_path: Optional[PathLike] = None,
511+
backend_shim_env: Optional[Dict[str, str]] = None,
473512
) -> List[str]:
474-
commands = get_shim_pre_start_commands()
475-
for k, v in get_shim_env(authorized_keys).items():
513+
commands = get_shim_pre_start_commands(
514+
base_path=base_path,
515+
bin_path=bin_path,
516+
)
517+
shim_env = get_shim_env(
518+
authorized_keys=authorized_keys,
519+
base_path=base_path,
520+
bin_path=bin_path,
521+
backend_shim_env=backend_shim_env,
522+
)
523+
for k, v in shim_env.items():
476524
commands += [f'export "{k}={v}"']
477-
commands += get_run_shim_script(is_privileged, pjrt_device)
525+
commands += get_run_shim_script(
526+
is_privileged=is_privileged,
527+
pjrt_device=pjrt_device,
528+
bin_path=bin_path,
529+
)
478530
return commands
479531

480532

@@ -511,25 +563,33 @@ def get_dstack_shim_download_url() -> str:
511563
return f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64"
512564

513565

514-
def get_shim_pre_start_commands() -> List[str]:
566+
def get_shim_pre_start_commands(
567+
base_path: Optional[PathLike] = None,
568+
bin_path: Optional[PathLike] = None,
569+
) -> List[str]:
515570
url = get_dstack_shim_download_url()
516-
571+
dstack_shim_binary_path = get_dstack_shim_binary_path(bin_path)
572+
dstack_working_dir = get_dstack_working_dir(base_path)
517573
return [
518574
f"dlpath=$(sudo mktemp -t {DSTACK_SHIM_BINARY_NAME}.XXXXXXXXXX)",
519575
# -sS -- disable progress meter and warnings, but still show errors (unlike bare -s)
520576
f'sudo curl -sS --compressed --connect-timeout 60 --max-time 240 --retry 1 --output "$dlpath" "{url}"',
521-
f'sudo mv "$dlpath" {DSTACK_SHIM_BINARY_PATH}',
522-
f"sudo chmod +x {DSTACK_SHIM_BINARY_PATH}",
523-
f"sudo mkdir {DSTACK_WORKING_DIR} -p",
577+
f'sudo mv "$dlpath" {dstack_shim_binary_path}',
578+
f"sudo chmod +x {dstack_shim_binary_path}",
579+
f"sudo mkdir {dstack_working_dir} -p",
524580
]
525581

526582

527-
def get_run_shim_script(is_privileged: bool, pjrt_device: Optional[str]) -> List[str]:
583+
def get_run_shim_script(
584+
is_privileged: bool,
585+
pjrt_device: Optional[str],
586+
bin_path: Optional[PathLike] = None,
587+
) -> List[str]:
588+
dstack_shim_binary_path = get_dstack_shim_binary_path(bin_path)
528589
privileged_flag = "--privileged" if is_privileged else ""
529590
pjrt_device_env = f"--pjrt-device={pjrt_device}" if pjrt_device else ""
530-
531591
return [
532-
f"nohup {DSTACK_SHIM_BINARY_PATH} {privileged_flag} {pjrt_device_env} &",
592+
f"nohup {dstack_shim_binary_path} {privileged_flag} {pjrt_device_env} &",
533593
]
534594

535595

@@ -555,7 +615,11 @@ def get_gateway_user_data(authorized_key: str) -> str:
555615
)
556616

557617

558-
def get_docker_commands(authorized_keys: list[str]) -> list[str]:
618+
def get_docker_commands(
619+
authorized_keys: list[str],
620+
bin_path: Optional[PathLike] = None,
621+
) -> list[str]:
622+
dstack_runner_binary_path = get_dstack_runner_binary_path(bin_path)
559623
authorized_keys_content = "\n".join(authorized_keys).strip()
560624
commands = [
561625
# save and unset ld.so variables
@@ -606,10 +670,10 @@ def get_docker_commands(authorized_keys: list[str]) -> list[str]:
606670

607671
url = get_dstack_runner_download_url()
608672
commands += [
609-
f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {DSTACK_RUNNER_BINARY_PATH} {url}",
610-
f"chmod +x {DSTACK_RUNNER_BINARY_PATH}",
673+
f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {dstack_runner_binary_path} {url}",
674+
f"chmod +x {dstack_runner_binary_path}",
611675
(
612-
f"{DSTACK_RUNNER_BINARY_PATH}"
676+
f"{dstack_runner_binary_path}"
613677
" --log-level 6"
614678
" start"
615679
f" --http-port {DSTACK_RUNNER_HTTP_PORT}"

src/dstack/_internal/core/backends/gcp/compute.py

+40-8
Original file line number | Diff line number | Diff line change
@@ -296,11 +296,9 @@ def create_instance(
296296
gpus=instance_offer.instance.resources.gpus,
297297
),
298298
spot=instance_offer.instance.resources.spot,
299-
user_data=get_user_data(
300-
authorized_keys,
301-
backend_specific_commands=_get_backend_specific_commands(
302-
instance_offer.instance.name
303-
),
299+
user_data=_get_user_data(
300+
authorized_keys=authorized_keys,
301+
instance_type_name=instance_offer.instance.name,
304302
),
305303
authorized_keys=authorized_keys,
306304
labels=labels,
@@ -841,10 +839,14 @@ def _get_extra_subnets(
841839
) -> List[Tuple[str, str]]:
842840
if config.extra_vpcs is None:
843841
return []
844-
if instance_type_name != "a3-megagpu-8g":
842+
if instance_type_name == "a3-megagpu-8g":
843+
subnets_num = 8
844+
elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
845+
subnets_num = 4
846+
else:
845847
return []
846848
extra_subnets = []
847-
for vpc_name in config.extra_vpcs:
849+
for vpc_name in config.extra_vpcs[:subnets_num]:
848850
subnet = gcp_resources.get_vpc_subnet_or_error(
849851
subnetworks_client=subnetworks_client,
850852
vpc_project_id=config.vpc_project_id or config.project_id,
@@ -856,12 +858,14 @@ def _get_extra_subnets(
856858
vpc_name=vpc_name,
857859
)
858860
extra_subnets.append((vpc_resource_name, subnet))
859-
return extra_subnets[:8]
861+
return extra_subnets
860862

861863

862864
def _get_image_id(instance_type_name: str, cuda: bool) -> str:
863865
if instance_type_name == "a3-megagpu-8g":
864866
image_name = "dstack-a3mega-5"
867+
elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
868+
return "projects/cos-cloud/global/images/cos-105-17412-535-78"
865869
elif cuda:
866870
image_name = f"dstack-cuda-{version.base_image}"
867871
else:
@@ -874,9 +878,37 @@ def _get_gateway_image_id() -> str:
874878
return "projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20230714"
875879

876880

881+
def _get_user_data(authorized_keys: List[str], instance_type_name: str) -> str:
882+
base_path = None
883+
bin_path = None
884+
backend_shim_env = None
885+
if instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
886+
# In the COS image the / file system is not writable.
887+
# /home and /var are writable but not executable.
888+
# Only /etc is both writable and executable, so use it for shim/runner binaries.
889+
# See: https://cloud.google.com/container-optimized-os/docs/concepts/disks-and-filesystem
890+
base_path = bin_path = "/etc"
891+
backend_shim_env = {
892+
# In COS nvidia binaries are not installed on PATH by default.
893+
# Set PATH so that shim can run nvidia-smi.
894+
"PATH": "/var/lib/nvidia/bin:$PATH",
895+
}
896+
return get_user_data(
897+
authorized_keys=authorized_keys,
898+
backend_specific_commands=_get_backend_specific_commands(
899+
instance_type_name=instance_type_name,
900+
),
901+
base_path=base_path,
902+
bin_path=bin_path,
903+
backend_shim_env=backend_shim_env,
904+
)
905+
906+
877907
def _get_backend_specific_commands(instance_type_name: str) -> List[str]:
878908
if instance_type_name == "a3-megagpu-8g":
879909
return tcpx_features.get_backend_specific_commands_tcpxo()
910+
if instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
911+
return tcpx_features.get_backend_specific_commands_tcpx()
880912
return []
881913

882914

0 commit comments

Comments
 (0)