|
1 | 1 | # -*- coding: utf-8 -*-
|
2 | 2 |
|
3 |
| -# Copyright 2023 Google LLC |
| 3 | +# Copyright 2024 Google LLC |
4 | 4 | #
|
5 | 5 | # Licensed under the Apache License, Version 2.0 (the "License");
|
6 | 6 | # you may not use this file except in compliance with the License.
|
|
16 | 16 | #
|
17 | 17 | import grpc
|
18 | 18 | import logging
|
| 19 | +import ray |
| 20 | + |
19 | 21 | from typing import Dict
|
20 | 22 | from typing import Optional
|
21 | 23 | from google.cloud import aiplatform
|
@@ -45,16 +47,30 @@ def __init__(
|
45 | 47 | persistent_resource_id,
|
46 | 48 | " failed to start Head node properly.",
|
47 | 49 | )
|
48 |
| - |
49 |
| - super().__init__( |
50 |
| - dashboard_url=dashboard_uri, |
51 |
| - python_version=ray_client_context.python_version, |
52 |
| - ray_version=ray_client_context.ray_version, |
53 |
| - ray_commit=ray_client_context.ray_commit, |
54 |
| - protocol_version=ray_client_context.protocol_version, |
55 |
| - _num_clients=ray_client_context._num_clients, |
56 |
| - _context_to_restore=ray_client_context._context_to_restore, |
57 |
| - ) |
| 50 | + if ray.__version__ == "2.33.0": |
| 51 | + super().__init__( |
| 52 | + dashboard_url=dashboard_uri, |
| 53 | + python_version=ray_client_context.python_version, |
| 54 | + ray_version=ray_client_context.ray_version, |
| 55 | + ray_commit=ray_client_context.ray_commit, |
| 56 | + _num_clients=ray_client_context._num_clients, |
| 57 | + _context_to_restore=ray_client_context._context_to_restore, |
| 58 | + ) |
| 59 | + elif ray.__version__ == "2.9.3": |
| 60 | + super().__init__( |
| 61 | + dashboard_url=dashboard_uri, |
| 62 | + python_version=ray_client_context.python_version, |
| 63 | + ray_version=ray_client_context.ray_version, |
| 64 | + ray_commit=ray_client_context.ray_commit, |
| 65 | + protocol_version=ray_client_context.protocol_version, |
| 66 | + _num_clients=ray_client_context._num_clients, |
| 67 | + _context_to_restore=ray_client_context._context_to_restore, |
| 68 | + ) |
| 69 | + else: |
| 70 | + raise ImportError( |
| 71 | + f"[Ray on Vertex AI]: Unsupported version {ray.__version__}." |
| 72 | + + "Only 2.33.0 and 2.9.3 are supported." |
| 73 | + ) |
58 | 74 | self.persistent_resource_id = persistent_resource_id
|
59 | 75 | self.vertex_sdk_version = str(VERTEX_SDK_VERSION)
|
60 | 76 | self.shell_uri = ray_head_uris.get("RAY_HEAD_NODE_INTERACTIVE_SHELL_URI")
|
|
0 commit comments