15
15
# limitations under the License.
16
16
#
17
17
import dataclasses
18
- from typing import List , Optional , Dict , Any
18
+ from typing import List , Optional
19
+ from vertexai .preview ._workflow .serialization_engine import (
20
+ serializers_base ,
21
+ )
19
22
20
23
21
24
@dataclasses .dataclass
@@ -72,16 +75,33 @@ class RemoteConfig(_BaseConfig):
72
75
]
73
76
74
77
# Specify the extra parameters needed for serializing objects.
75
- model.train.vertex.remote_config.serializer_args = {
76
- model: {
77
- "extra_serializer_param1_for_model": param1_value,
78
- "extra_serializer_param2_for_model": param2_value,
78
+ from vertexai.preview.developer import SerializerArgs
79
+
80
+ # You can put all the hashable objects with their arguments in the
81
+ # SerializerArgs all at once in a dict. Here we assume "model" is
82
+ # hashable.
83
+ model.train.vertex.remote_config.serializer_args = SerializerArgs({
84
+ model: {
85
+ "extra_serializer_param1_for_model": param1_value,
86
+ "extra_serializer_param2_for_model": param2_value,
87
+ },
88
+ hashable_obj2: {
89
+ "extra_serializer_param1_for_hashable2": param1_value,
90
+ "extra_serializer_param2_for_hashable2": param2_value,
91
+ },
92
+ })
93
+ # Or if the object to be serialized is unhashable, put them into the
94
+ # serializer_args one by one. If this is the only use case, there is
95
+ # no need to import `SerializerArgs`. Here we assume "X_train" and
96
+ # "y_train" is not hashable.
97
+ model.train.vertex.remote_config.serializer_args[X_train] = {
98
+ "extra_serializer_param1_for_X_train": param1_value,
99
+ "extra_serializer_param2_for_X_train": param2_value,
79
100
},
80
- X_train: {
81
- "extra_serializer_param1 ": param1_value,
82
- "extra_serializer_param2 ": param2_value,
101
+ model.train.vertex.remote_config.serializer_args[y_train] = {
102
+ "extra_serializer_param1_for_y_train ": param1_value,
103
+ "extra_serializer_param2_for_y_train ": param2_value,
83
104
}
84
- }
85
105
86
106
# Train the model as usual
87
107
model.train(X_train, y_train)
@@ -132,7 +152,7 @@ class RemoteConfig(_BaseConfig):
132
152
custom_commands (List[str]):
133
153
List of custom commands to be run in the remote job environment.
134
154
These commands will be run before the requirements are installed.
135
- serializer_args (Dict[Any, Dict[str, Any]]) :
155
+ serializer_args: serializers_base.SerializerArgs :
136
156
Map from object to extra arguments when serializing the object. The extra
137
157
arguments is a dictionary from the argument names to the argument values.
138
158
"""
@@ -143,7 +163,9 @@ class RemoteConfig(_BaseConfig):
143
163
service_account : Optional [str ] = None
144
164
requirements : List [str ] = dataclasses .field (default_factory = list )
145
165
custom_commands : List [str ] = dataclasses .field (default_factory = list )
146
- serializer_args : Dict [Any , Dict [str , Any ]] = dataclasses .field (default_factory = dict )
166
+ serializer_args : serializers_base .SerializerArgs = dataclasses .field (
167
+ default_factory = serializers_base .SerializerArgs
168
+ )
147
169
148
170
149
171
@dataclasses .dataclass
0 commit comments