diff --git a/runner/docs/shim.openapi.yaml b/runner/docs/shim.openapi.yaml
index 9cd818b84..bc199181b 100644
--- a/runner/docs/shim.openapi.yaml
+++ b/runner/docs/shim.openapi.yaml
@@ -270,6 +270,19 @@ components:
         type: string
         default: ""
         description: Mount point inside container
+
+    GPUDevice:
+      title: shim.GPUDevice
+      type: object
+      properties:
+        path_on_host:
+          type: string
+          default: ""
+          description: Instance (host) path
+        path_in_container:
+          type: string
+          default: ""
+          description: Path inside container

     HealthcheckResponse:
       title: shim.api.HealthcheckResponse
@@ -438,6 +451,11 @@
          items:
            $ref: "#/components/schemas/InstanceMountPoint"
          default: []
+        gpu_devices:
+          type: array
+          items:
+            $ref: "#/components/schemas/GPUDevice"
+          default: []
        host_ssh_user:
          type: string
          default: ""
diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go
index 0e037261c..626325c8e 100644
--- a/runner/internal/shim/docker.go
+++ b/runner/internal/shim/docker.go
@@ -814,7 +814,11 @@ func (d *DockerRunner) createContainer(ctx context.Context, task *Task) error {
 	hostConfig.Resources.NanoCPUs = int64(task.config.CPU * 1000000000)
 	hostConfig.Resources.Memory = task.config.Memory
 	if len(task.gpuIDs) > 0 {
-		configureGpus(containerConfig, hostConfig, d.gpuVendor, task.gpuIDs)
+		if len(task.config.GPUDevices) > 0 {
+			configureGpuDevices(hostConfig, task.config.GPUDevices)
+		} else {
+			configureGpus(containerConfig, hostConfig, d.gpuVendor, task.gpuIDs)
+		}
 	}
 	configureHpcNetworkingIfAvailable(hostConfig)

@@ -988,6 +992,19 @@ func getNetworkMode(networkMode NetworkMode) container.NetworkMode {
 	return "default"
 }

+func configureGpuDevices(hostConfig *container.HostConfig, gpuDevices []GPUDevice) {
+	for _, gpuDevice := range gpuDevices {
+		hostConfig.Resources.Devices = append(
+			hostConfig.Resources.Devices,
+			container.DeviceMapping{
+				PathOnHost:        gpuDevice.PathOnHost,
+				PathInContainer:   gpuDevice.PathInContainer,
+				CgroupPermissions: "rwm",
+			},
+		)
+	}
+}
+
 func configureGpus(config *container.Config, hostConfig *container.HostConfig, vendor host.GpuVendor, ids []string) {
 	// NVIDIA: ids are identifiers reported by nvidia-smi, GPU- strings
 	// AMD: ids are DRI render node paths, e.g., /dev/dri/renderD128
diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go
index 9ad1f67b8..78c7a4a3e 100644
--- a/runner/internal/shim/models.go
+++ b/runner/internal/shim/models.go
@@ -70,6 +70,11 @@ type PortMapping struct {
 	Container int `json:"container"`
 }

+type GPUDevice struct {
+	PathOnHost      string `json:"path_on_host"`
+	PathInContainer string `json:"path_in_container"`
+}
+
 type TaskConfig struct {
 	ID   string `json:"id"`
 	Name string `json:"name"`
@@ -86,8 +91,11 @@ type TaskConfig struct {
 	Volumes        []VolumeInfo         `json:"volumes"`
 	VolumeMounts   []VolumeMountPoint   `json:"volume_mounts"`
 	InstanceMounts []InstanceMountPoint `json:"instance_mounts"`
-	HostSshUser    string               `json:"host_ssh_user"`
-	HostSshKeys    []string             `json:"host_ssh_keys"`
+	// GPUDevices allows the server to set GPU devices explicitly instead of relying on the runner's default logic,
+	// e.g., passing NVIDIA devices directly instead of using nvidia-container-toolkit.
+	GPUDevices  []GPUDevice `json:"gpu_devices"`
+	HostSshUser string      `json:"host_ssh_user"`
+	HostSshKeys []string    `json:"host_ssh_keys"`
 	// TODO: submit keys to runner, not to shim
 	ContainerSshKeys []string `json:"container_ssh_keys"`
 }
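Note: when the server supplies `gpu_devices`, the shim skips `nvidia-container-toolkit` and passes the device nodes straight through the Docker API. A minimal docker-py sketch of the equivalent passthrough (illustrative only; the shim itself uses the Go Docker SDK as shown above, and the image and command here are placeholders):

```python
import docker  # docker-py

client = docker.from_env()
# The same device set the server sends for a3-edgegpu-8g / a3-highgpu-8g (see below);
# each entry mirrors container.DeviceMapping{PathOnHost, PathInContainer, CgroupPermissions}.
devices = [f"/dev/nvidia{i}" for i in range(8)] + ["/dev/nvidia-uvm", "/dev/nvidiactl"]
container = client.containers.run(
    "ubuntu:22.04",          # placeholder image
    "ls -l /dev/nvidiactl",  # placeholder command
    devices=[f"{d}:{d}:rwm" for d in devices],  # host:container:cgroup-permissions
    detach=True,
)
```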
diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py
index accef8368..471f1aefe 100644
--- a/src/dstack/_internal/core/backends/base/compute.py
+++ b/src/dstack/_internal/core/backends/base/compute.py
@@ -5,6 +5,7 @@
 import threading
 from abc import ABC, abstractmethod
 from functools import lru_cache
+from pathlib import Path
 from typing import Dict, List, Optional

 import git
@@ -36,14 +37,12 @@
 )
 from dstack._internal.core.services import is_valid_dstack_resource_name
 from dstack._internal.utils.logging import get_logger
+from dstack._internal.utils.path import PathLike

 logger = get_logger(__name__)

-DSTACK_WORKING_DIR = "/root/.dstack"
 DSTACK_SHIM_BINARY_NAME = "dstack-shim"
-DSTACK_SHIM_BINARY_PATH = f"/usr/local/bin/{DSTACK_SHIM_BINARY_NAME}"
 DSTACK_RUNNER_BINARY_NAME = "dstack-runner"
-DSTACK_RUNNER_BINARY_PATH = f"/usr/local/bin/{DSTACK_RUNNER_BINARY_NAME}"


 class Compute(ABC):
@@ -336,6 +335,24 @@ def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
         return True


+def get_dstack_working_dir(base_path: Optional[PathLike] = None) -> str:
+    if base_path is None:
+        base_path = "/root"
+    return str(Path(base_path, ".dstack"))
+
+
+def get_dstack_shim_binary_path(bin_path: Optional[PathLike] = None) -> str:
+    if bin_path is None:
+        bin_path = "/usr/local/bin"
+    return str(Path(bin_path, DSTACK_SHIM_BINARY_NAME))
+
+
+def get_dstack_runner_binary_path(bin_path: Optional[PathLike] = None) -> str:
+    if bin_path is None:
+        bin_path = "/usr/local/bin"
+    return str(Path(bin_path, DSTACK_RUNNER_BINARY_NAME))
+
+
 def get_job_instance_name(run: Run, job: Job) -> str:
     return job.job_spec.job_name
@@ -442,9 +459,18 @@ def get_cloud_config(**config) -> str:


 def get_user_data(
-    authorized_keys: List[str], backend_specific_commands: Optional[List[str]] = None
+    authorized_keys: List[str],
+    backend_specific_commands: Optional[List[str]] = None,
+    base_path: Optional[PathLike] = None,
+    bin_path: Optional[PathLike] = None,
+    backend_shim_env: Optional[Dict[str, str]] = None,
 ) -> str:
-    shim_commands = get_shim_commands(authorized_keys)
+    shim_commands = get_shim_commands(
+        authorized_keys=authorized_keys,
+        base_path=base_path,
+        bin_path=bin_path,
+        backend_shim_env=backend_shim_env,
+    )
     commands = (backend_specific_commands or []) + shim_commands
     return get_cloud_config(
         runcmd=[["sh", "-c", " && ".join(commands)]],
"DSTACK_PUBLIC_SSH_KEY": "\n".join(authorized_keys), } + if backend_shim_env is not None: + envs |= backend_shim_env return envs def get_shim_commands( - authorized_keys: List[str], *, is_privileged: bool = False, pjrt_device: Optional[str] = None + authorized_keys: List[str], + *, + is_privileged: bool = False, + pjrt_device: Optional[str] = None, + base_path: Optional[PathLike] = None, + bin_path: Optional[PathLike] = None, + backend_shim_env: Optional[Dict[str, str]] = None, ) -> List[str]: - commands = get_shim_pre_start_commands() - for k, v in get_shim_env(authorized_keys).items(): + commands = get_shim_pre_start_commands( + base_path=base_path, + bin_path=bin_path, + ) + shim_env = get_shim_env( + authorized_keys=authorized_keys, + base_path=base_path, + bin_path=bin_path, + backend_shim_env=backend_shim_env, + ) + for k, v in shim_env.items(): commands += [f'export "{k}={v}"'] - commands += get_run_shim_script(is_privileged, pjrt_device) + commands += get_run_shim_script( + is_privileged=is_privileged, + pjrt_device=pjrt_device, + bin_path=bin_path, + ) return commands @@ -511,25 +563,33 @@ def get_dstack_shim_download_url() -> str: return f"https://{bucket}.s3.eu-west-1.amazonaws.com/{build}/binaries/dstack-shim-linux-amd64" -def get_shim_pre_start_commands() -> List[str]: +def get_shim_pre_start_commands( + base_path: Optional[PathLike] = None, + bin_path: Optional[PathLike] = None, +) -> List[str]: url = get_dstack_shim_download_url() - + dstack_shim_binary_path = get_dstack_shim_binary_path(bin_path) + dstack_working_dir = get_dstack_working_dir(base_path) return [ f"dlpath=$(sudo mktemp -t {DSTACK_SHIM_BINARY_NAME}.XXXXXXXXXX)", # -sS -- disable progress meter and warnings, but still show errors (unlike bare -s) f'sudo curl -sS --compressed --connect-timeout 60 --max-time 240 --retry 1 --output "$dlpath" "{url}"', - f'sudo mv "$dlpath" {DSTACK_SHIM_BINARY_PATH}', - f"sudo chmod +x {DSTACK_SHIM_BINARY_PATH}", - f"sudo mkdir {DSTACK_WORKING_DIR} -p", + f'sudo mv "$dlpath" {dstack_shim_binary_path}', + f"sudo chmod +x {dstack_shim_binary_path}", + f"sudo mkdir {dstack_working_dir} -p", ] -def get_run_shim_script(is_privileged: bool, pjrt_device: Optional[str]) -> List[str]: +def get_run_shim_script( + is_privileged: bool, + pjrt_device: Optional[str], + bin_path: Optional[PathLike] = None, +) -> List[str]: + dstack_shim_binary_path = get_dstack_shim_binary_path(bin_path) privileged_flag = "--privileged" if is_privileged else "" pjrt_device_env = f"--pjrt-device={pjrt_device}" if pjrt_device else "" - return [ - f"nohup {DSTACK_SHIM_BINARY_PATH} {privileged_flag} {pjrt_device_env} &", + f"nohup {dstack_shim_binary_path} {privileged_flag} {pjrt_device_env} &", ] @@ -555,7 +615,11 @@ def get_gateway_user_data(authorized_key: str) -> str: ) -def get_docker_commands(authorized_keys: list[str]) -> list[str]: +def get_docker_commands( + authorized_keys: list[str], + bin_path: Optional[PathLike] = None, +) -> list[str]: + dstack_runner_binary_path = get_dstack_runner_binary_path(bin_path) authorized_keys_content = "\n".join(authorized_keys).strip() commands = [ # save and unset ld.so variables @@ -606,10 +670,10 @@ def get_docker_commands(authorized_keys: list[str]) -> list[str]: url = get_dstack_runner_download_url() commands += [ - f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {DSTACK_RUNNER_BINARY_PATH} {url}", - f"chmod +x {DSTACK_RUNNER_BINARY_PATH}", + f"curl --connect-timeout 60 --max-time 240 --retry 1 --output {dstack_runner_binary_path} {url}", + f"chmod +x 
{dstack_runner_binary_path}", ( - f"{DSTACK_RUNNER_BINARY_PATH}" + f"{dstack_runner_binary_path}" " --log-level 6" " start" f" --http-port {DSTACK_RUNNER_HTTP_PORT}" diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 8b44866bb..3846e4f4d 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -296,11 +296,9 @@ def create_instance( gpus=instance_offer.instance.resources.gpus, ), spot=instance_offer.instance.resources.spot, - user_data=get_user_data( - authorized_keys, - backend_specific_commands=_get_backend_specific_commands( - instance_offer.instance.name - ), + user_data=_get_user_data( + authorized_keys=authorized_keys, + instance_type_name=instance_offer.instance.name, ), authorized_keys=authorized_keys, labels=labels, @@ -841,10 +839,14 @@ def _get_extra_subnets( ) -> List[Tuple[str, str]]: if config.extra_vpcs is None: return [] - if instance_type_name != "a3-megagpu-8g": + if instance_type_name == "a3-megagpu-8g": + subnets_num = 8 + elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: + subnets_num = 4 + else: return [] extra_subnets = [] - for vpc_name in config.extra_vpcs: + for vpc_name in config.extra_vpcs[:subnets_num]: subnet = gcp_resources.get_vpc_subnet_or_error( subnetworks_client=subnetworks_client, vpc_project_id=config.vpc_project_id or config.project_id, @@ -856,12 +858,14 @@ def _get_extra_subnets( vpc_name=vpc_name, ) extra_subnets.append((vpc_resource_name, subnet)) - return extra_subnets[:8] + return extra_subnets def _get_image_id(instance_type_name: str, cuda: bool) -> str: if instance_type_name == "a3-megagpu-8g": image_name = "dstack-a3mega-5" + elif instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: + return "projects/cos-cloud/global/images/cos-105-17412-535-78" elif cuda: image_name = f"dstack-cuda-{version.base_image}" else: @@ -874,9 +878,37 @@ def _get_gateway_image_id() -> str: return "projects/ubuntu-os-cloud/global/images/ubuntu-2204-jammy-v20230714" +def _get_user_data(authorized_keys: List[str], instance_type_name: str) -> str: + base_path = None + bin_path = None + backend_shim_env = None + if instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: + # In the COS image the / file system is not writable. + # /home and /var are writable but not executable. + # Only /etc is both writable and executable, so use it for shim/runner binaries. + # See: https://cloud.google.com/container-optimized-os/docs/concepts/disks-and-filesystem + base_path = bin_path = "/etc" + backend_shim_env = { + # In COS nvidia binaries are not installed on PATH by default. + # Set so that shim can run nvidia-smi. 
+ "PATH": "/var/lib/nvidia/bin:$PATH", + } + return get_user_data( + authorized_keys=authorized_keys, + backend_specific_commands=_get_backend_specific_commands( + instance_type_name=instance_type_name, + ), + base_path=base_path, + bin_path=bin_path, + backend_shim_env=backend_shim_env, + ) + + def _get_backend_specific_commands(instance_type_name: str) -> List[str]: if instance_type_name == "a3-megagpu-8g": return tcpx_features.get_backend_specific_commands_tcpxo() + if instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]: + return tcpx_features.get_backend_specific_commands_tcpx() return [] diff --git a/src/dstack/_internal/core/backends/gcp/features/tcpx.py b/src/dstack/_internal/core/backends/gcp/features/tcpx.py index 1edaa7440..2d1c34013 100644 --- a/src/dstack/_internal/core/backends/gcp/features/tcpx.py +++ b/src/dstack/_internal/core/backends/gcp/features/tcpx.py @@ -32,3 +32,34 @@ def get_backend_specific_commands_tcpxo() -> List[str]: "--num_hops=2 --num_nics=8 --uid= --alsologtostderr" ), ] + + +def get_backend_specific_commands_tcpx() -> List[str]: + return [ + "cos-extensions install gpu -- --version=latest", + "sudo mount --bind /var/lib/nvidia /var/lib/nvidia", + "sudo mount -o remount,exec /var/lib/nvidia", + ( + "docker run " + "--detach " + "--pull=always " + "--name receive-datapath-manager " + "--privileged " + "--cap-add=NET_ADMIN --network=host " + "--volume /var/lib/nvidia/lib64:/usr/local/nvidia/lib64 " + "--device /dev/nvidia0:/dev/nvidia0 --device /dev/nvidia1:/dev/nvidia1 " + "--device /dev/nvidia2:/dev/nvidia2 --device /dev/nvidia3:/dev/nvidia3 " + "--device /dev/nvidia4:/dev/nvidia4 --device /dev/nvidia5:/dev/nvidia5 " + "--device /dev/nvidia6:/dev/nvidia6 --device /dev/nvidia7:/dev/nvidia7 " + "--device /dev/nvidia-uvm:/dev/nvidia-uvm --device /dev/nvidiactl:/dev/nvidiactl " + "--env LD_LIBRARY_PATH=/usr/local/nvidia/lib64 " + "--volume /run/tcpx:/run/tcpx " + "--entrypoint /tcpgpudmarxd/build/app/tcpgpudmarxd " + "us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/tcpgpudmarxd " + '--gpu_nic_preset a3vm --gpu_shmem_type fd --uds_path "/run/tcpx" --setup_param "--verbose 128 2 0"' + ), + "sudo iptables -I INPUT -p tcp -m tcp -j ACCEPT", + "docker run --rm -v /var/lib:/var/lib us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpx/nccl-plugin-gpudirecttcpx install --install-nccl", + "sudo mount --bind /var/lib/tcpx /var/lib/tcpx", + "sudo mount -o remount,exec /var/lib/tcpx", + ] diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py index 1e8e275b7..62e4f7cfc 100644 --- a/src/dstack/_internal/server/background/tasks/process_instances.py +++ b/src/dstack/_internal/server/background/tasks/process_instances.py @@ -17,11 +17,11 @@ BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT, ) from dstack._internal.core.backends.base.compute import ( - DSTACK_RUNNER_BINARY_PATH, - DSTACK_SHIM_BINARY_PATH, - DSTACK_WORKING_DIR, ComputeWithCreateInstanceSupport, ComputeWithPlacementGroupSupport, + get_dstack_runner_binary_path, + get_dstack_shim_binary_path, + get_dstack_working_dir, get_shim_env, get_shim_pre_start_commands, ) @@ -411,23 +411,26 @@ def _deploy_instance( except ValueError as e: raise ProvisioningError(f"Invalid Env: {e}") from e shim_envs.update(fleet_configuration_envs) - upload_envs(client, DSTACK_WORKING_DIR, shim_envs) + dstack_working_dir = get_dstack_working_dir() + dstack_shim_binary_path = get_dstack_shim_binary_path() + dstack_runner_binary_path = get_dstack_runner_binary_path() 
diff --git a/src/dstack/_internal/server/background/tasks/process_instances.py b/src/dstack/_internal/server/background/tasks/process_instances.py
index 1e8e275b7..62e4f7cfc 100644
--- a/src/dstack/_internal/server/background/tasks/process_instances.py
+++ b/src/dstack/_internal/server/background/tasks/process_instances.py
@@ -17,11 +17,11 @@
     BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT,
 )
 from dstack._internal.core.backends.base.compute import (
-    DSTACK_RUNNER_BINARY_PATH,
-    DSTACK_SHIM_BINARY_PATH,
-    DSTACK_WORKING_DIR,
     ComputeWithCreateInstanceSupport,
     ComputeWithPlacementGroupSupport,
+    get_dstack_runner_binary_path,
+    get_dstack_shim_binary_path,
+    get_dstack_working_dir,
     get_shim_env,
     get_shim_pre_start_commands,
 )
@@ -411,23 +411,26 @@ def _deploy_instance(
     except ValueError as e:
         raise ProvisioningError(f"Invalid Env: {e}") from e
     shim_envs.update(fleet_configuration_envs)
-    upload_envs(client, DSTACK_WORKING_DIR, shim_envs)
+    dstack_working_dir = get_dstack_working_dir()
+    dstack_shim_binary_path = get_dstack_shim_binary_path()
+    dstack_runner_binary_path = get_dstack_runner_binary_path()
+    upload_envs(client, dstack_working_dir, shim_envs)
     logger.debug("The dstack-shim environment variables have been installed")

     # Ensure we have fresh versions of host info.json and dstack-runner
-    remove_host_info_if_exists(client, DSTACK_WORKING_DIR)
-    remove_dstack_runner_if_exists(client, DSTACK_RUNNER_BINARY_PATH)
+    remove_host_info_if_exists(client, dstack_working_dir)
+    remove_dstack_runner_if_exists(client, dstack_runner_binary_path)

     # Run dstack-shim as a systemd service
     run_shim_as_systemd_service(
         client=client,
-        binary_path=DSTACK_SHIM_BINARY_PATH,
-        working_dir=DSTACK_WORKING_DIR,
+        binary_path=dstack_shim_binary_path,
+        working_dir=dstack_working_dir,
         dev=settings.DSTACK_VERSION is None,
     )

     # Get host info
-    host_info = get_host_info(client, DSTACK_WORKING_DIR)
+    host_info = get_host_info(client, dstack_working_dir)
     logger.debug("Received a host_info %s", host_info)

     raw_health = get_shim_healthcheck(client)
diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py
index 56282332e..51ad15cc6 100644
--- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py
@@ -40,7 +40,7 @@
     RepoModel,
     RunModel,
 )
-from dstack._internal.server.schemas.runner import TaskStatus
+from dstack._internal.server.schemas.runner import GPUDevice, TaskStatus
 from dstack._internal.server.services import logs as logs_services
 from dstack._internal.server.services import services
 from dstack._internal.server.services.instances import get_instance_ssh_private_keys
@@ -438,6 +438,10 @@ def _process_provisioning_with_shim(
         job_provisioning_data.backend, job_provisioning_data.instance_type.name
     )

+    gpu_devices = _get_instance_specific_gpu_devices(
+        job_provisioning_data.backend, job_provisioning_data.instance_type.name
+    )
+
     container_user = "root"

     job_runtime_data = get_job_runtime_data(job_model)
@@ -471,6 +475,7 @@
         volumes=volumes,
         volume_mounts=volume_mounts,
         instance_mounts=instance_mounts,
+        gpu_devices=gpu_devices,
         host_ssh_user=ssh_user,
         host_ssh_keys=[ssh_key] if ssh_key else [],
         container_ssh_keys=public_keys,
@@ -834,14 +839,60 @@ def _submit_job_to_runner(
 def _get_instance_specific_mounts(
     backend_type: BackendType, instance_type_name: str
 ) -> List[InstanceMountPoint]:
-    if backend_type == BackendType.GCP and instance_type_name == "a3-megagpu-8g":
-        return [
-            InstanceMountPoint(
-                instance_path="/dev/aperture_devices", path="/dev/aperture_devices"
-            ),
-            InstanceMountPoint(instance_path="/var/lib/tcpxo/lib64", path="/var/lib/tcpxo/lib64"),
-            InstanceMountPoint(
-                instance_path="/var/lib/fastrak/lib64", path="/var/lib/fastrak/lib64"
-            ),
-        ]
+    if backend_type == BackendType.GCP:
+        if instance_type_name == "a3-megagpu-8g":
+            return [
+                InstanceMountPoint(
+                    instance_path="/dev/aperture_devices",
+                    path="/dev/aperture_devices",
+                ),
+                InstanceMountPoint(
+                    instance_path="/var/lib/tcpxo/lib64",
+                    path="/var/lib/tcpxo/lib64",
+                ),
+                InstanceMountPoint(
+                    instance_path="/var/lib/fastrak/lib64",
+                    path="/var/lib/fastrak/lib64",
+                ),
+            ]
+        if instance_type_name in ["a3-edgegpu-8g", "a3-highgpu-8g"]:
+            return [
+                InstanceMountPoint(
+                    instance_path="/var/lib/nvidia/lib64",
+                    path="/usr/local/nvidia/lib64",
+                ),
+                InstanceMountPoint(
+                    instance_path="/var/lib/nvidia/bin",
+                    path="/usr/local/nvidia/bin",
+                ),
+                InstanceMountPoint(
+                    instance_path="/var/lib/tcpx/lib64",
+                    path="/usr/local/tcpx/lib64",
+                ),
+                InstanceMountPoint(
+                    instance_path="/run/tcpx",
+                    path="/run/tcpx",
+                ),
+            ]
     return []
+
+
+def _get_instance_specific_gpu_devices(
+    backend_type: BackendType, instance_type_name: str
+) -> List[GPUDevice]:
+    gpu_devices = []
+    if backend_type == BackendType.GCP and instance_type_name in [
+        "a3-edgegpu-8g",
+        "a3-highgpu-8g",
+    ]:
+        for i in range(8):
+            gpu_devices.append(
+                GPUDevice(path_on_host=f"/dev/nvidia{i}", path_in_container=f"/dev/nvidia{i}")
+            )
+        gpu_devices.append(
+            GPUDevice(path_on_host="/dev/nvidia-uvm", path_in_container="/dev/nvidia-uvm")
+        )
+        gpu_devices.append(
+            GPUDevice(path_on_host="/dev/nvidiactl", path_in_container="/dev/nvidiactl")
+        )
+    return gpu_devices
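The server-side device list mirrors the `--device` flags passed to the receive-datapath-manager container above: eight per-GPU nodes plus the shared UVM and control nodes. For illustration:

```python
devices = _get_instance_specific_gpu_devices(BackendType.GCP, "a3-highgpu-8g")
assert [d.path_on_host for d in devices] == [
    *(f"/dev/nvidia{i}" for i in range(8)),
    "/dev/nvidia-uvm",
    "/dev/nvidiactl",
]
# Other backends and instance types fall back to the runner's default GPU logic
assert _get_instance_specific_gpu_devices(BackendType.AWS, "a3-highgpu-8g") == []
```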
path="/usr/local/tcpx/lib64", + ), + InstanceMountPoint( + instance_path="/run/tcpx", + path="/run/tcpx", + ), + ] return [] + + +def _get_instance_specific_gpu_devices( + backend_type: BackendType, instance_type_name: str +) -> List[GPUDevice]: + gpu_devices = [] + if backend_type == BackendType.GCP and instance_type_name in [ + "a3-edgegpu-8g", + "a3-highgpu-8g", + ]: + for i in range(8): + gpu_devices.append( + GPUDevice(path_on_host=f"/dev/nvidia{i}", path_in_container=f"/dev/nvidia{i}") + ) + gpu_devices.append( + GPUDevice(path_on_host="/dev/nvidia-uvm", path_in_container="/dev/nvidia-uvm") + ) + gpu_devices.append( + GPUDevice(path_on_host="/dev/nvidiactl", path_in_container="/dev/nvidiactl") + ) + return gpu_devices diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index e0c69c750..6f8f91b96 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -114,6 +114,11 @@ class TaskStatus(str, Enum): TERMINATED = "terminated" +class GPUDevice(CoreModel): + path_on_host: str + path_in_container: str + + class TaskInfoResponse(CoreModel): id: str status: TaskStatus @@ -139,6 +144,7 @@ class TaskSubmitRequest(CoreModel): volumes: list[ShimVolumeInfo] volume_mounts: list[VolumeMountPoint] instance_mounts: list[InstanceMountPoint] + gpu_devices: list[GPUDevice] host_ssh_user: str host_ssh_keys: list[str] container_ssh_keys: list[str] diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index bc53dd55d..ac9641622 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -15,6 +15,7 @@ from dstack._internal.core.models.runs import ClusterInfo, JobSpec, RunSpec from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint from dstack._internal.server.schemas.runner import ( + GPUDevice, HealthcheckResponse, LegacyPullResponse, LegacyStopBody, @@ -233,6 +234,7 @@ def submit_task( volumes: list[Volume], volume_mounts: list[VolumeMountPoint], instance_mounts: list[InstanceMountPoint], + gpu_devices: list[GPUDevice], host_ssh_user: str, host_ssh_keys: list[str], container_ssh_keys: list[str], @@ -256,6 +258,7 @@ def submit_task( volumes=[_volume_to_shim_volume_info(v, instance_id) for v in volumes], volume_mounts=volume_mounts, instance_mounts=instance_mounts, + gpu_devices=gpu_devices, host_ssh_user=host_ssh_user, host_ssh_keys=host_ssh_keys, container_ssh_keys=container_ssh_keys, diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index 3f58015e1..52afca62a 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -340,6 +340,7 @@ async def test_provisioning_shim_with_volumes( volumes=[volume_model_to_volume(volume)], volume_mounts=[VolumeMountPoint(name="my-vol", path="/volume")], instance_mounts=[InstanceMountPoint(instance_path="/root/.cache", path="/cache")], + gpu_devices=[], host_ssh_user="ubuntu", host_ssh_keys=["user_ssh_key"], container_ssh_keys=[project_ssh_pub_key, "user_ssh_key"], diff --git a/src/tests/_internal/server/services/runner/test_client.py b/src/tests/_internal/server/services/runner/test_client.py index b7e1fde45..e68a007cf 100644 --- 
+++ b/src/tests/_internal/server/services/runner/test_client.py
@@ -331,6 +331,7 @@ def test_submit_task(self, client: ShimClient, adapter: requests_mock.Adapter):
         volumes=[volume],
         volume_mounts=[VolumeMountPoint(name="vol", path="/vol")],
         instance_mounts=[InstanceMountPoint(instance_path="/mnt/nfs/home", path="/home")],
+        gpu_devices=[],
         host_ssh_user="dstack",
         host_ssh_keys=["host_key"],
         container_ssh_keys=["project_key", "user_key"],
@@ -365,6 +366,7 @@
         "instance_mounts": [
             {"instance_path": "/mnt/nfs/home", "path": "/home", "optional": False}
         ],
+        "gpu_devices": [],
         "host_ssh_user": "dstack",
         "host_ssh_keys": ["host_key"],
         "container_ssh_keys": ["project_key", "user_key"],