Skip to content

Commit 77a741e

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Also support unhashable objects to be serialized with extra args
PiperOrigin-RevId: 577998940
1 parent 1e4a4ec commit 77a741e

File tree

7 files changed

+180
-21
lines changed

7 files changed

+180
-21
lines changed

tests/unit/vertexai/test_remote_training.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,9 @@ def test_remote_training_sklearn_with_remote_configs(
972972
_TEST_TRAINING_CONFIG_CONTAINER_URI
973973
)
974974
model.fit.vertex.remote_config.machine_type = _TEST_TRAINING_CONFIG_MACHINE_TYPE
975-
model.fit.vertex.remote_config.serializer_args = {model: {"extra_params": 1}}
975+
model.fit.vertex.remote_config.serializer_args[model] = {"extra_params": 1}
976+
# X_TRAIN is a numpy array that is not hashable.
977+
model.fit.vertex.remote_config.serializer_args[_X_TRAIN] = {"extra_params": 2}
976978

977979
model.fit(_X_TRAIN, _Y_TRAIN)
978980

@@ -991,7 +993,7 @@ def test_remote_training_sklearn_with_remote_configs(
991993
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
992994
to_serialize=_X_TRAIN,
993995
gcs_path=os.path.join(remote_job_base_path, "input/X"),
994-
**{},
996+
**{"extra_params": 2},
995997
)
996998
mock_any_serializer_sklearn.return_value.serialize.assert_any_call(
997999
to_serialize=_Y_TRAIN,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2023 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
from vertexai.preview._workflow.serialization_engine import (
18+
serializers_base,
19+
)
20+
21+
22+
class TestSerializerArgs:
    """Checks how ``serializers_base.SerializerArgs`` stores and looks up keys."""

    def test_object_id_is_saved(self):
        """A key passed to the constructor is stored under its id only."""

        class _Dummy:
            pass

        key = _Dummy()
        args = serializers_base.SerializerArgs({key: {"a": 1, "b": 2}})

        assert id(key) in args
        assert key not in args

    def test_getitem_support_original_object(self):
        """Subscription with the original hashable object returns its value."""

        class _Dummy:
            pass

        key = _Dummy()
        args = serializers_base.SerializerArgs({key: {"a": 1, "b": 2}})

        assert args[key] == {"a": 1, "b": 2}

    def test_get_support_original_object(self):
        """``get()`` with the original hashable object returns its value."""

        class _Dummy:
            pass

        key = _Dummy()
        args = serializers_base.SerializerArgs({key: {"a": 1, "b": 2}})

        assert args.get(key) == {"a": 1, "b": 2}

    def test_unhashable_obj_saved_successfully(self):
        """An unhashable key assigned via ``[]`` is stored under its id."""
        key = [1, 2, 3]
        args = serializers_base.SerializerArgs()
        args[key] = {"a": 1, "b": 2}

        assert id(key) in args

    def test_getitem_support_original_unhashable(self):
        """Subscription with the original unhashable object returns its value."""
        key = [1, 2, 3]
        args = serializers_base.SerializerArgs()
        args[key] = {"a": 1, "b": 2}

        assert args[key] == {"a": 1, "b": 2}

    def test_get_support_original_unhashable(self):
        """``get()`` with the original unhashable object returns its value."""
        key = [1, 2, 3]
        args = serializers_base.SerializerArgs()
        args[key] = {"a": 1, "b": 2}

        assert args.get(key) == {"a": 1, "b": 2}

vertexai/preview/_workflow/executor/training.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import re
2222
import sys
2323
import time
24-
from typing import Any, Dict, List, Optional, Set, Tuple, Union, Hashable
24+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
2525
import warnings
2626

2727
from google.api_core import exceptions as api_exceptions
@@ -495,6 +495,8 @@ def remote_training(invokable: shared._Invokable, rewrapper: Any):
495495
bound_args = invokable.bound_arguments
496496
config = invokable.vertex_config.remote_config
497497
serializer_args = invokable.vertex_config.remote_config.serializer_args
498+
if not isinstance(serializer_args, serializers_base.SerializerArgs):
499+
raise ValueError("serializer_args must be an instance of SerializerArgs.")
498500

499501
autolog = vertexai.preview.global_config.autolog
500502
service_account = _get_service_account(config, autolog=autolog)
@@ -609,17 +611,13 @@ def remote_training(invokable: shared._Invokable, rewrapper: Any):
609611
to_serialize=arg_value,
610612
gcs_path=os.path.join(remote_job_input_path, f"{arg_name}"),
611613
framework=detected_framework,
612-
**serializer_args.get(arg_value, {})
613-
if isinstance(arg_value, Hashable)
614-
else {},
614+
**serializer_args.get(arg_value, {}),
615615
)
616616
else:
617617
serialization_metadata = serializer.serialize(
618618
to_serialize=arg_value,
619619
gcs_path=os.path.join(remote_job_input_path, f"{arg_name}"),
620-
**serializer_args.get(arg_value, {})
621-
if isinstance(arg_value, Hashable)
622-
else {},
620+
**serializer_args.get(arg_value, {}),
623621
)
624622
# serializer.get_dependencies() must be run after serializer.serialize()
625623
requirements += serialization_metadata[

vertexai/preview/_workflow/serialization_engine/serializers_base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
2626

2727
from google.cloud.aiplatform.utils import gcs_utils
28-
28+
from vertexai.preview._workflow.shared import data_structures
2929

3030
T = TypeVar("T")
3131
SERIALIZATION_METADATA_FILENAME = "serialization_metadata"
@@ -34,6 +34,9 @@
3434
SERIALIZATION_METADATA_CUSTOM_COMMANDS_KEY = "custom_commands"
3535

3636

37+
SerializerArgs = data_structures.IdAsKeyDict
38+
39+
3740
@dataclasses.dataclass
3841
class SerializationMetadata:
3942
"""Metadata of Serializer classes.

vertexai/preview/_workflow/shared/configs.py

+33-11
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
# limitations under the License.
1616
#
1717
import dataclasses
18-
from typing import List, Optional, Dict, Any
18+
from typing import List, Optional
19+
from vertexai.preview._workflow.serialization_engine import (
20+
serializers_base,
21+
)
1922

2023

2124
@dataclasses.dataclass
@@ -72,16 +75,33 @@ class RemoteConfig(_BaseConfig):
7275
]
7376
7477
# Specify the extra parameters needed for serializing objects.
75-
model.train.vertex.remote_config.serializer_args = {
76-
model: {
77-
"extra_serializer_param1_for_model": param1_value,
78-
"extra_serializer_param2_for_model": param2_value,
78+
from vertexai.preview.developer import SerializerArgs
79+
80+
# You can put all the hashable objects with their arguments in the
81+
# SerializerArgs all at once in a dict. Here we assume "model" is
82+
# hashable.
83+
model.train.vertex.remote_config.serializer_args = SerializerArgs({
84+
model: {
85+
"extra_serializer_param1_for_model": param1_value,
86+
"extra_serializer_param2_for_model": param2_value,
87+
},
88+
hashable_obj2: {
89+
"extra_serializer_param1_for_hashable2": param1_value,
90+
"extra_serializer_param2_for_hashable2": param2_value,
91+
},
92+
})
93+
# Or if the object to be serialized is unhashable, put them into the
94+
# serializer_args one by one. If this is the only use case, there is
95+
# no need to import `SerializerArgs`. Here we assume "X_train" and
96+
# "y_train" are not hashable.
97+
model.train.vertex.remote_config.serializer_args[X_train] = {
98+
"extra_serializer_param1_for_X_train": param1_value,
99+
"extra_serializer_param2_for_X_train": param2_value,
79100
}
80-
X_train: {
81-
"extra_serializer_param1": param1_value,
82-
"extra_serializer_param2": param2_value,
101+
model.train.vertex.remote_config.serializer_args[y_train] = {
102+
"extra_serializer_param1_for_y_train": param1_value,
103+
"extra_serializer_param2_for_y_train": param2_value,
83104
}
84-
}
85105
86106
# Train the model as usual
87107
model.train(X_train, y_train)
@@ -132,7 +152,7 @@ class RemoteConfig(_BaseConfig):
132152
custom_commands (List[str]):
133153
List of custom commands to be run in the remote job environment.
134154
These commands will be run before the requirements are installed.
135-
serializer_args (Dict[Any, Dict[str, Any]]):
155+
serializer_args (serializers_base.SerializerArgs):
136156
Map from object to extra arguments when serializing the object. The extra
137157
arguments is a dictionary from the argument names to the argument values.
138158
"""
@@ -143,7 +163,9 @@ class RemoteConfig(_BaseConfig):
143163
service_account: Optional[str] = None
144164
requirements: List[str] = dataclasses.field(default_factory=list)
145165
custom_commands: List[str] = dataclasses.field(default_factory=list)
146-
serializer_args: Dict[Any, Dict[str, Any]] = dataclasses.field(default_factory=dict)
166+
serializer_args: serializers_base.SerializerArgs = dataclasses.field(
167+
default_factory=serializers_base.SerializerArgs
168+
)
147169

148170

149171
@dataclasses.dataclass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2023 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
19+
class IdAsKeyDict(dict):
    """Customized dict that maps each key to its id before storing the data.

    This subclass of dict still allows one to use the original key during
    subscription ([] operator) or via `get()` method. But under the hood, the
    keys are the ids of the original keys. Only the constructor, `[]`
    (get/set) and `get()` translate keys; every other inherited dict method
    (`in`, `del`, `pop`, `update`, ...) operates on the raw integer ids.

    A strong reference to every original key is kept alongside the data.
    `id()` is only unique among objects that are alive at the same time, so
    without that reference a garbage-collected key could have its id reused
    by an unrelated object, which would then alias the stale entry.

    Example:
        # add some hashable objects (key1 and key2) to the dict
        id_as_key_dict = IdAsKeyDict({key1: value1, key2: value2})
        # add an unhashable object (key3) to the dict
        id_as_key_dict[key3] = value3
        # unhashable keys can also be given up front as (key, value) pairs
        other = IdAsKeyDict([(key3, value3)])

        # can access the value via subscription using the original key
        assert id_as_key_dict[key1] == value1
        assert id_as_key_dict[key2] == value2
        assert id_as_key_dict[key3] == value3
        # can access the value via get method using the original key
        assert id_as_key_dict.get(key1) == value1
        assert id_as_key_dict.get(key2) == value2
        assert id_as_key_dict.get(key3) == value3
        # but the original keys are not in the dict - the ids are
        assert id(key1) in id_as_key_dict
        assert id(key2) in id_as_key_dict
        assert id(key3) in id_as_key_dict
        assert key1 not in id_as_key_dict
        assert key2 not in id_as_key_dict
        assert key3 not in id_as_key_dict
    """

    def __init__(self, *args, **kwargs):
        """Builds the dict from mappings, (key, value) iterables and kwargs.

        Args:
            *args: Any number of sources. Each one is either an object with
                an `items()` method (e.g. a dict) or an iterable of
                (key, value) pairs. Another `IdAsKeyDict` is copied as is,
                since its keys are already ids.
            **kwargs: Extra entries; the keyword name strings are the keys.
        """
        super().__init__()
        for arg in args:
            if isinstance(arg, IdAsKeyDict):
                # The keys are ids already; re-running id() on those integers
                # would make the copied entries unreachable.
                for internal_key, value in dict.items(arg):
                    dict.__setitem__(self, internal_key, value)
                self._refs().update(arg._refs())
                continue
            pairs = arg.items() if hasattr(arg, "items") else arg
            for key, value in pairs:
                self[key] = value
        for key, value in kwargs.items():
            self[key] = value

    def _refs(self):
        """Returns the id -> original key map, creating it on first use.

        The map is created lazily so that instances built without running
        `__init__` (e.g. by copy/pickle machinery) still work.
        """
        return self.__dict__.setdefault("_key_refs", {})

    def __getitem__(self, _key):
        """Returns the value stored for the object `_key` (looked up by id).

        Raises:
            KeyError: If nothing was stored for this exact object.
        """
        internal_key = id(_key)
        return super().__getitem__(internal_key)

    def __setitem__(self, _key, _value):
        """Stores `_value` under `id(_key)` and keeps `_key` alive."""
        internal_key = id(_key)
        # Holding the key prevents its id from being recycled while the
        # entry exists.
        self._refs()[internal_key] = _key
        return super().__setitem__(internal_key, _value)

    def get(self, key, default=None):
        """Returns the value stored for the object `key`, else `default`."""
        internal_key = id(key)
        return super().get(internal_key, default)

vertexai/preview/developer/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
PersistentResourceConfig = configs.PersistentResourceConfig
3030
Serializer = serializers_base.Serializer
3131
SerializationMetadata = serializers_base.SerializationMetadata
32+
SerializerArgs = serializers_base.SerializerArgs
3233
RemoteConfig = configs.RemoteConfig
3334
WorkerPoolSpec = remote_specs.WorkerPoolSpec
3435
WorkerPoolSepcs = remote_specs.WorkerPoolSpecs
@@ -41,6 +42,7 @@
4142
"PersistentResourceConfig",
4243
"register_serializer",
4344
"Serializer",
45+
"SerializerArgs",
4446
"SerializationMetadata",
4547
"RemoteConfig",
4648
"WorkerPoolSpec",

0 commit comments

Comments
 (0)