Skip to content

Support A3 High/Edge GCP clusters with GPUDirect-TCPX #2549

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions runner/docs/shim.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,19 @@ components:
type: string
default: ""
description: Mount point inside container

GPUDevice:
title: shim.GPUDevice
type: object
properties:
path_on_host:
type: string
default: ""
description: Instance (host) path
path_in_container:
type: string
default: ""
description: Path inside container

HealthcheckResponse:
title: shim.api.HealthcheckResponse
Expand Down Expand Up @@ -438,6 +451,11 @@ components:
items:
$ref: "#/components/schemas/InstanceMountPoint"
default: []
gpu_devices:
type: array
items:
$ref: "#/components/schemas/GPUDevice"
default: []
host_ssh_user:
type: string
default: ""
Expand Down
19 changes: 18 additions & 1 deletion runner/internal/shim/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,11 @@ func (d *DockerRunner) createContainer(ctx context.Context, task *Task) error {
hostConfig.Resources.NanoCPUs = int64(task.config.CPU * 1000000000)
hostConfig.Resources.Memory = task.config.Memory
if len(task.gpuIDs) > 0 {
configureGpus(containerConfig, hostConfig, d.gpuVendor, task.gpuIDs)
if len(task.config.GPUDevices) > 0 {
configureGpuDevices(hostConfig, task.config.GPUDevices)
} else {
configureGpus(containerConfig, hostConfig, d.gpuVendor, task.gpuIDs)
}
}
configureHpcNetworkingIfAvailable(hostConfig)

Expand Down Expand Up @@ -988,6 +992,19 @@ func getNetworkMode(networkMode NetworkMode) container.NetworkMode {
return "default"
}

func configureGpuDevices(hostConfig *container.HostConfig, gpuDevices []GPUDevice) {
for _, gpuDevice := range gpuDevices {
hostConfig.Resources.Devices = append(
hostConfig.Resources.Devices,
container.DeviceMapping{
PathOnHost: gpuDevice.PathOnHost,
PathInContainer: gpuDevice.PathInContainer,
CgroupPermissions: "rwm",
},
)
}
}

func configureGpus(config *container.Config, hostConfig *container.HostConfig, vendor host.GpuVendor, ids []string) {
// NVIDIA: ids are identifiers reported by nvidia-smi, GPU-<UUID> strings
// AMD: ids are DRI render node paths, e.g., /dev/dri/renderD128
Expand Down
12 changes: 10 additions & 2 deletions runner/internal/shim/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ type PortMapping struct {
Container int `json:"container"`
}

type GPUDevice struct {
PathOnHost string `json:"path_on_host"`
PathInContainer string `json:"path_in_container"`
}

type TaskConfig struct {
ID string `json:"id"`
Name string `json:"name"`
Expand All @@ -86,8 +91,11 @@ type TaskConfig struct {
Volumes []VolumeInfo `json:"volumes"`
VolumeMounts []VolumeMountPoint `json:"volume_mounts"`
InstanceMounts []InstanceMountPoint `json:"instance_mounts"`
HostSshUser string `json:"host_ssh_user"`
HostSshKeys []string `json:"host_ssh_keys"`
// GPUDevices allows the server to set gpu devices instead of relying on the runner default logic.
// E.g. passing nvidia devices directly instead of using nvidia-container-toolkit.
GPUDevices []GPUDevice `json:"gpu_devices"`
HostSshUser string `json:"host_ssh_user"`
HostSshKeys []string `json:"host_ssh_keys"`
// TODO: submit keys to runner, not to shim
ContainerSshKeys []string `json:"container_ssh_keys"`
}
Expand Down
112 changes: 88 additions & 24 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import threading
from abc import ABC, abstractmethod
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional

import git
Expand Down Expand Up @@ -36,14 +37,12 @@
)
from dstack._internal.core.services import is_valid_dstack_resource_name
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.path import PathLike

logger = get_logger(__name__)

DSTACK_WORKING_DIR = "/root/.dstack"
DSTACK_SHIM_BINARY_NAME = "dstack-shim"
DSTACK_SHIM_BINARY_PATH = f"/usr/local/bin/{DSTACK_SHIM_BINARY_NAME}"
DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
DSTACK_RUNNER_BINARY_PATH = f"/usr/local/bin/{DSTACK_RUNNER_BINARY_NAME}"


class Compute(ABC):
Expand Down Expand Up @@ -336,6 +335,24 @@ def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
return True


def get_dstack_working_dir(base_path: Optional[PathLike] = None) -> str:
if base_path is None:
base_path = "/root"
return str(Path(base_path, ".dstack"))


def get_dstack_shim_binary_path(bin_path: Optional[PathLike] = None) -> str:
if bin_path is None:
bin_path = "/usr/local/bin"
return str(Path(bin_path, DSTACK_SHIM_BINARY_NAME))


def get_dstack_runner_binary_path(bin_path: Optional[PathLike] = None) -> str:
if bin_path is None:
bin_path = "/usr/local/bin"
return str(Path(bin_path, DSTACK_RUNNER_BINARY_NAME))


def get_job_instance_name(run: Run, job: Job) -> str:
return job.job_spec.job_name

Expand Down Expand Up @@ -442,39 +459,74 @@ def get_cloud_config(**config) -> str:


def get_user_data(
authorized_keys: List[str], backend_specific_commands: Optional[List[str]] = None
authorized_keys: List[str],
backend_specific_commands: Optional[List[str]] = None,
base_path: Optional[PathLike] = None,
bin_path: Optional[PathLike] = None,
backend_shim_env: Optional[Dict[str, str]] = None,
) -> str:
shim_commands = get_shim_commands(authorized_keys)
shim_commands = get_shim_commands(
authorized_keys=authorized_keys,
base_path=base_path,
bin_path=bin_path,
backend_shim_env=backend_shim_env,
)
commands = (backend_specific_commands or []) + shim_commands
return get_cloud_config(
runcmd=[["sh", "-c", " && ".join(commands)]],
ssh_authorized_keys=authorized_keys,
)


def get_shim_env(authorized_keys: List[str]) -> Dict[str, str]:
def get_shim_env(
authorized_keys: List[str],
base_path: Optional[PathLike] = None,
bin_path: Optional[PathLike] = None,
backend_shim_env: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
log_level = "6" # Trace
envs = {
"DSTACK_SHIM_HOME": DSTACK_WORKING_DIR,
"DSTACK_SHIM_HOME": get_dstack_working_dir(base_path),
"DSTACK_SHIM_HTTP_PORT": str(DSTACK_SHIM_HTTP_PORT),
"DSTACK_SHIM_LOG_LEVEL": log_level,
"DSTACK_RUNNER_DOWNLOAD_URL": get_dstack_runner_download_url(),
"DSTACK_RUNNER_BINARY_PATH": DSTACK_RUNNER_BINARY_PATH,
"DSTACK_RUNNER_BINARY_PATH": get_dstack_runner_binary_path(bin_path),
"DSTACK_RUNNER_HTTP_PORT": str(DSTACK_RUNNER_HTTP_PORT),
"DSTACK_RUNNER_SSH_PORT": str(DSTACK_RUNNER_SSH_PORT),
"DSTACK_RUNNER_LOG_LEVEL": log_level,
"DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys),
}
if backend_shim_env is not None:
envs |= backend_shim_env
return envs


def get_shim_commands(
authorized_keys: List[str], *, is_privileged: bool = False, pjrt_device: Optional[str] = None
authorized_keys: List[str],
*,
is_privileged: bool = False,
pjrt_device: Optional[str] = None,
base_path: Optional[PathLike] = None,
bin_path: Optional[PathLike] = None,
backend_shim_env: Optional[Dict[str, str]] = None,
) -> List[str]:
commands = get_shim_pre_start_commands()
for k, v in get_shim_env(authorized_keys).items():
commands = get_shim_pre_start_commands(
base_path=base_path,
bin_path=bin_path,
)
shim_env = get_shim_env(
authorized_keys=authorized_keys,
base_path=base_path,
bin_path=bin_path,
backend_shim_env=backend_shim_env,
)
for k, v in shim_env.items():
commands += [f'export "{k}={v}"']
commands += get_run_shim_script(is_privileged, pjrt_device)
commands += get_run_shim_script(
is_privileged=is_privileged,
pjrt_device=pjrt_device,
bin_path=bin_path,
)
return commands


Expand Down Expand Up @@ -511,25 +563,33 @@ def get_dstack_shim_download_url() -> str:
return f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64"


def get_shim_pre_start_commands() -> List[str]:
def get_shim_pre_start_commands(
base_path: Optional[PathLike] = None,
bin_path: Optional[PathLike] = None,
) -> List[str]:
url = get_dstack_shim_download_url()

dstack_shim_binary_path = get_dstack_shim_binary_path(bin_path)
dstack_working_dir = get_dstack_working_dir(base_path)
return [
f"dlpath=$(sudo mktemp -t {DSTACK_SHIM_BINARY_NAME}.XXXXXXXXXX)",
# -sS -- disable progress meter and warnings, but still show errors (unlike bare -s)
f'sudo curl -sS --compressed --connect-timeout 60 --max-time 240 --retry 1 --output "$dlpath" "{url}"',
f'sudo mv "$dlpath" {DSTACK_SHIM_BINARY_PATH}',
f"sudo chmod +x {DSTACK_SHIM_BINARY_PATH}",
f"sudo mkdir {DSTACK_WORKING_DIR} -p",
f'sudo mv "$dlpath" {dstack_shim_binary_path}',
f"sudo chmod +x {dstack_shim_binary_path}",
f"sudo mkdir {dstack_working_dir} -p",
]


def get_run_shim_script(is_privileged: bool, pjrt_device: Optional[str]) -> List[str]:
def get_run_shim_script(
is_privileged: bool,
pjrt_device: Optional[str],
bin_path: Optional[PathLike] = None,
) -> List[str]:
dstack_shim_binary_path = get_dstack_shim_binary_path(bin_path)
privileged_flag = "--privileged" if is_privileged else ""
pjrt_device_env = f"--pjrt-device={pjrt_device}" if pjrt_device else ""

return [
f"nohup {DSTACK_SHIM_BINARY_PATH} {privileged_flag} {pjrt_device_env} &",
f"nohup {dstack_shim_binary_path} {privileged_flag} {pjrt_device_env} &",
]


Expand All @@ -555,7 +615,11 @@ def get_gateway_user_data(authorized_key: str) -> str:
)


def get_docker_commands(authorized_keys: list[str]) -> list[str]:
def get_docker_commands(
authorized_keys: list[str],
bin_path: Optional[PathLike] = None,
) -> list[str]:
dstack_runner_binary_path = get_dstack_runner_binary_path(bin_path)
authorized_keys_content = "\n".join(authorized_keys).strip()
commands = [
# save and unset ld.so variables
Expand Down Expand Up @@ -606,10 +670,10 @@ def get_docker_commands(authorized_keys: list[str]) -> list[str]:

url = get_dstack_runner_download_url()
commands += [
f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {DSTACK_RUNNER_BINARY_PATH} {url}",
f"chmod +x {DSTACK_RUNNER_BINARY_PATH}",
f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {dstack_runner_binary_path} {url}",
f"chmod +x {dstack_runner_binary_path}",
(
f"{DSTACK_RUNNER_BINARY_PATH}"
f"{dstack_runner_binary_path}"
" --log-level 6"
" start"
f" --http-port {DSTACK_RUNNER_HTTP_PORT}"
Expand Down
48 changes: 40 additions & 8 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,9 @@ def create_instance(
gpus=instance_offer.instance.resources.gpus,
),
spot=instance_offer.instance.resources.spot,
user_data=get_user_data(
authorized_keys,
backend_specific_commands=_get_backend_specific_commands(
instance_offer.instance.name
),
user_data=_get_user_data(
authorized_keys=authorized_keys,
instance_type_name=instance_offer.instance.name,
),
authorized_keys=authorized_keys,
labels=labels,
Expand Down Expand Up @@ -841,10 +839,14 @@ def _get_extra_subnets(
) -> List[Tuple[str, str]]:
if config.extra_vpcs is None:
return []
if instance_type_name != "a3-megagpu-8g":
if instance_type_name == "a3-megagpu-8g":
subnets_num = 8
elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
subnets_num = 4
else:
return []
extra_subnets = []
for vpc_name in config.extra_vpcs:
for vpc_name in config.extra_vpcs[:subnets_num]:
subnet = gcp_resources.get_vpc_subnet_or_error(
subnetworks_client=subnetworks_client,
vpc_project_id=config.vpc_project_id or config.project_id,
Expand All @@ -856,12 +858,14 @@ def _get_extra_subnets(
vpc_name=vpc_name,
)
extra_subnets.append((vpc_resource_name, subnet))
return extra_subnets[:8]
return extra_subnets


def _get_image_id(instance_type_name: str, cuda: bool) -> str:
if instance_type_name == "a3-megagpu-8g":
image_name = "dstack-a3mega-5"
elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
return "projects/cos-cloud/global/images/cos-105-17412-535-78"
elif cuda:
image_name = f"dstack-cuda-{version.base_image}"
else:
Expand All @@ -874,9 +878,37 @@ def _get_gateway_image_id() -> str:
return "projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20230714"


def _get_user_data(authorized_keys: List[str], instance_type_name: str) -> str:
base_path = None
bin_path = None
backend_shim_env = None
if instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
# In the COS image the / file system is not writable.
# /home and /var are writable but not executable.
# Only /etc is both writable and executable, so use it for shim/runner binaries.
# See: https://cloud.google.com/container-optimized-os/docs/concepts/disks-and-filesystem
base_path = bin_path = "/etc"
backend_shim_env = {
# In COS nvidia binaries are not installed on PATH by default.
# Set so that shim can run nvidia-smi.
"PATH": "/var/lib/nvidia/bin:$PATH",
}
return get_user_data(
authorized_keys=authorized_keys,
backend_specific_commands=_get_backend_specific_commands(
instance_type_name=instance_type_name,
),
base_path=base_path,
bin_path=bin_path,
backend_shim_env=backend_shim_env,
)


def _get_backend_specific_commands(instance_type_name: str) -> List[str]:
if instance_type_name == "a3-megagpu-8g":
return tcpx_features.get_backend_specific_commands_tcpxo()
if instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
return tcpx_features.get_backend_specific_commands_tcpx()
return []


Expand Down
Loading
Loading