|
| 1 | +import dask |
| 2 | + |
| 3 | +from dask_cloudprovider.generic.vmcluster import ( |
| 4 | + VMCluster, |
| 5 | + VMInterface, |
| 6 | + SchedulerMixin, |
| 7 | + WorkerMixin, |
| 8 | +) |
| 9 | + |
| 10 | +try: |
| 11 | + from nebius.api.nebius.common.v1 import ResourceMetadata |
| 12 | + from nebius.api.nebius.vpc.v1 import SubnetServiceClient, ListSubnetsRequest |
| 13 | + from nebius.sdk import SDK |
| 14 | + from nebius.api.nebius.compute.v1 import ( |
| 15 | + InstanceServiceClient, |
| 16 | + CreateInstanceRequest, |
| 17 | + DiskServiceClient, |
| 18 | + CreateDiskRequest, |
| 19 | + DiskSpec, |
| 20 | + SourceImageFamily, |
| 21 | + InstanceSpec, |
| 22 | + AttachedDiskSpec, |
| 23 | + ExistingDisk, |
| 24 | + ResourcesSpec, |
| 25 | + NetworkInterfaceSpec, |
| 26 | + IPAddress, |
| 27 | + PublicIPAddress, |
| 28 | + GetInstanceRequest, |
| 29 | + DeleteInstanceRequest, |
| 30 | + DeleteDiskRequest, |
| 31 | + ) |
| 32 | +except ImportError as e: |
| 33 | + msg = ( |
| 34 | + "Dask Cloud Provider Nebius requirements are not installed.\n\n" |
| 35 | + "Please pip install as follows:\n\n" |
| 36 | + ' pip install "dask-cloudprovider[nebius]" --upgrade # or python -m pip install' |
| 37 | + ) |
| 38 | + raise ImportError(msg) from e |
| 39 | + |
| 40 | + |
class NebiusInstance(VMInterface):
    """A Nebius VM, together with its boot disk, backing one Dask process.

    Parameters
    ----------
    cluster: NebiusCluster
        Owning cluster object. Used here for logging (``_log``) and for
        rendering the cloud-init payload (``render_process_cloud_init``).
    config: dict
        Nebius configuration mapping; its ``token`` entry is handed to the
        SDK as credentials.
    env_vars: dict
    bootstrap
    extra_bootstrap
    docker_image: str
        Stored on the instance unchanged. They are not read by the methods of
        this class (presumably consumed while cloud-init is rendered --
        confirm against ``VMCluster``).
    image_family: str
        Image family the boot disk is created from.
    project_id: str
        Parent id for the disk and the instance; also used to list subnets.
    server_platform: str
    server_preset: str
        Platform and preset passed to the instance's ``ResourcesSpec``.
    disk_size: int
        Boot disk size, passed as ``size_gibibytes``.
    *args, **kwargs
        Forwarded to ``VMInterface.__init__``.
    """

    def __init__(
        self,
        cluster: "NebiusCluster",
        config,
        env_vars: dict = None,
        bootstrap=None,
        extra_bootstrap=None,
        docker_image: str = None,
        image_family: str = None,
        project_id: str = None,
        server_platform: str = None,
        server_preset: str = None,
        disk_size: int = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.cluster = cluster
        self.config = config
        self.extra_bootstrap = extra_bootstrap
        self.env_vars = env_vars
        self.bootstrap = bootstrap
        self.image_family = image_family
        self.project_id = project_id
        self.docker_image = docker_image
        self.server_platform = server_platform
        self.server_preset = server_preset
        self.sdk = SDK(credentials=self.config.get("token"))
        self.disk_size = disk_size
        # Ids of the cloud resources owned by this object; None until created.
        self.instance_id = None
        self.disk_id = None

    async def _find_subnet_id(self):
        """Return the id of the first subnet listed for the project.

        Raises
        ------
        RuntimeError
            If the project has no subnets.
        """
        service = SubnetServiceClient(self.sdk)
        subnets = await service.list(ListSubnetsRequest(parent_id=self.project_id))
        if not subnets.items:
            raise RuntimeError(
                f"No subnets found in Nebius project {self.project_id}; "
                f"cannot create instance {self.name}"
            )
        return subnets.items[0].metadata.id

    async def _create_boot_disk(self):
        """Create the boot disk from ``image_family`` and record its id."""
        service = DiskServiceClient(self.sdk)
        operation = await service.create(
            CreateDiskRequest(
                metadata=ResourceMetadata(
                    parent_id=self.project_id,
                    name=self.name + "-disk",
                ),
                spec=DiskSpec(
                    source_image_family=SourceImageFamily(
                        image_family=self.image_family
                    ),
                    size_gibibytes=self.disk_size,
                    type=DiskSpec.DiskType.NETWORK_SSD,
                ),
            )
        )
        # Record the id before waiting (as is done for the instance) so that
        # destroy_vm can still clean up if the wait itself fails.
        self.disk_id = operation.resource_id
        await operation.wait()

    async def create_vm(self, user_data=None):
        """Create the boot disk and the instance, then return its addresses.

        Parameters
        ----------
        user_data
            Unused; the cloud-init payload is obtained from
            ``self.cluster.render_process_cloud_init(self)``.

        Returns
        -------
        tuple
            ``(internal_ip, external_ip)`` of the first network interface,
            with anything after a ``/`` in the reported address removed.

        Raises
        ------
        RuntimeError
            If the project has no subnet to attach the instance to.
        """
        # Resolve the subnet first so that a project without subnets fails
        # before any resource has been created.
        subnet_id = await self._find_subnet_id()
        await self._create_boot_disk()

        service = InstanceServiceClient(self.sdk)
        operation = await service.create(
            CreateInstanceRequest(
                metadata=ResourceMetadata(
                    parent_id=self.project_id,
                    name=self.name,
                ),
                spec=InstanceSpec(
                    boot_disk=AttachedDiskSpec(
                        # NOTE(review): 2 is presumably the read-write mode of
                        # the AttachMode enum -- confirm against the SDK.
                        attach_mode=AttachedDiskSpec.AttachMode(2),
                        existing_disk=ExistingDisk(id=self.disk_id),
                    ),
                    cloud_init_user_data=self.cluster.render_process_cloud_init(self),
                    resources=ResourcesSpec(
                        platform=self.server_platform, preset=self.server_preset
                    ),
                    network_interfaces=[
                        NetworkInterfaceSpec(
                            subnet_id=subnet_id,
                            ip_address=IPAddress(),
                            name="network-interface-0",
                            public_ip_address=PublicIPAddress(),
                        )
                    ],
                ),
            )
        )
        self.instance_id = operation.resource_id

        self.cluster._log(f"Creating Nebius instance {self.name}")
        await operation.wait()

        instance = await service.get(
            GetInstanceRequest(
                id=self.instance_id,
            )
        )
        interface = instance.status.network_interfaces[0]
        internal_ip = interface.ip_address.address.split("/")[0]
        external_ip = interface.public_ip_address.address.split("/")[0]
        self.cluster._log(
            f"Created Nebius instance {self.name} with internal IP {internal_ip} and external IP {external_ip}"
        )
        return internal_ip, external_ip

    async def destroy_vm(self):
        """Delete the instance and then its boot disk, waiting for both.

        Either step is skipped when the corresponding id was never recorded.
        Both ids are reset to ``None`` afterwards.
        """
        if self.instance_id:
            service = InstanceServiceClient(self.sdk)
            operation = await service.delete(
                DeleteInstanceRequest(
                    id=self.instance_id,
                )
            )
            await operation.wait()

        if self.disk_id:
            service = DiskServiceClient(self.sdk)
            operation = await service.delete(
                DeleteDiskRequest(
                    id=self.disk_id,
                )
            )
            # Wait for the deletion to finish, as for the instance above, so
            # the disk is really gone before it is reported as deleted.
            await operation.wait()
        self.cluster._log(
            f"Terminated instance {self.name} ({self.instance_id}) and deleted disk {self.disk_id}"
        )
        self.instance_id = None
        self.disk_id = None
| 168 | + |
| 169 | + |
class NebiusScheduler(SchedulerMixin, NebiusInstance):
    """Scheduler running on a Nebius server.

    Combines ``SchedulerMixin`` with ``NebiusInstance``; it defines no
    attributes or methods of its own.
    """
| 172 | + |
| 173 | + |
class NebiusWorker(WorkerMixin, NebiusInstance):
    """Worker running on a Nebius server.

    Combines ``WorkerMixin`` with ``NebiusInstance``; it defines no
    attributes or methods of its own.
    """
| 176 | + |
| 177 | + |
class NebiusCluster(VMCluster):
    """Cluster running on Nebius AI Cloud instances.

    VMs in Nebius AI Cloud are referred to as instances. This cluster manager
    constructs a Dask cluster running on those VMs.

    When configuring your cluster you may find it useful to install the
    ``nebius`` tool for querying the Nebius API for available options.

    https://docs.nebius.com/cli/quickstart

    Every Nebius-specific parameter below falls back to the matching
    ``cloudprovider.nebius.<name>`` entry of the Dask configuration when it is
    not given explicitly.

    Parameters
    ----------
    bootstrap: str
        Value handed to every instance as its ``bootstrap`` option.
    image_family: str
        The image to use for the host OS. This should be an Ubuntu variant.
        The available images are listed at
        https://docs.nebius.com/compute/storage/manage#parameters-boot.
    project_id: str
        The Nebius AI Cloud project id. You can find it in the Nebius AI Cloud console.
    disk_size: int
        Size of the boot disk created for each instance, in GiB.
    server_platform: str
        All platforms and presets are listed at
        https://docs.nebius.com/compute/virtual-machines/types/.
    server_preset: str
        All platforms and presets are listed at
        https://docs.nebius.com/compute/virtual-machines/types/.
    docker_image: str
        Value handed to every instance as its ``docker_image`` option.
    debug: bool
        Stored on the cluster and forwarded to ``VMCluster``. Defaults to ``False``.
    n_workers: int
        Number of workers to initialise the cluster with. Defaults to ``0``.
    worker_module: str
        The Python module to run for the worker. Defaults to ``distributed.cli.dask_worker``
    worker_options: dict
        Params to be passed to the worker class.
        See :class:`distributed.worker.Worker` for default worker class.
        If you set ``worker_module`` then refer to the docstring for the custom worker class.
    scheduler_options: dict
        Params to be passed to the scheduler class.
        See :class:`distributed.scheduler.Scheduler`.
    env_vars: dict
        Environment variables to be passed to the worker.
    extra_bootstrap: list[str] (optional)
        Extra commands to be run during the bootstrap phase.

    Examples
    --------

    >>> from dask_cloudprovider.nebius import NebiusCluster
    >>> cluster = NebiusCluster(n_workers=1)

    >>> from dask.distributed import Client
    >>> client = Client(cluster)

    >>> import dask.array as da
    >>> arr = da.random.random((1000, 1000), chunks=(100, 100))
    >>> arr.mean().compute()

    >>> client.close()
    >>> cluster.close()

    """

    def __init__(
        self,
        bootstrap: str = None,
        image_family: str = None,
        project_id: str = None,
        disk_size: int = None,
        server_platform: str = None,
        server_preset: str = None,
        docker_image: str = None,
        debug: bool = False,
        **kwargs,
    ):
        self.config = dask.config.get("cloudprovider.nebius", {})

        self.scheduler_class = NebiusScheduler
        self.worker_class = NebiusWorker

        # Explicit arguments, listed in the order the settings are resolved.
        explicit_values = {
            "image_family": image_family,
            "docker_image": docker_image,
            "project_id": project_id,
            "server_platform": server_platform,
            "server_preset": server_preset,
            "bootstrap": bootstrap,
            "disk_size": disk_size,
        }
        # An explicit argument wins; otherwise the configured value is used.
        for setting, explicit in explicit_values.items():
            resolved = dask.config.get(
                f"cloudprovider.nebius.{setting}", override_with=explicit
            )
            setattr(self, setting, resolved)
        self.debug = debug

        # Keyword arguments shared by the scheduler and worker instances.
        shared = {"bootstrap": self.bootstrap, "cluster": self, "config": self.config}
        for setting in (
            "docker_image",
            "image_family",
            "project_id",
            "server_platform",
            "server_preset",
            "disk_size",
        ):
            shared[setting] = getattr(self, setting)
        self.options = shared

        # Separate shallow copies so scheduler and worker options can diverge.
        self.scheduler_options = dict(self.options)
        self.worker_options = dict(self.options)
        super().__init__(debug=debug, **kwargs)
0 commit comments