13
13
# limitations under the License.
14
14
15
15
import logging
16
- from typing import Iterable
16
+ from typing import Iterable , Optional
17
17
from kubernetes import client , config
18
18
import multiprocessing
19
19
from datetime import datetime
@@ -65,7 +65,7 @@ def __init__(self):
65
65
66
66
self .api_instance = client .CustomObjectsApi ()
67
67
68
- def ValidateEarlyStoppingSettings (self , request , context ) :
68
+ def ValidateEarlyStoppingSettings (self , request : api_pb2 . ValidateEarlyStoppingSettingsRequest , context : grpc . ServicerContext ) -> api_pb2 . ValidateEarlyStoppingSettingsReply :
69
69
is_valid , message = self .validate_early_stopping_spec (request .early_stopping )
70
70
if not is_valid :
71
71
context .set_code (grpc .StatusCode .INVALID_ARGUMENT )
@@ -98,7 +98,7 @@ def validate_medianstop_setting(early_stopping_settings):
98
98
99
99
return True , ""
100
100
101
- def GetEarlyStoppingRules (self , request : api_pb2 .GetEarlyStoppingRulesRequest , context ) :
101
+ def GetEarlyStoppingRules (self , request : api_pb2 .GetEarlyStoppingRulesRequest , context : grpc . ServicerContext ) -> api_pb2 . GetSuggestionsReply :
102
102
logger .info ("Get new early stopping rules" )
103
103
104
104
# Get required values for the first call.
@@ -145,17 +145,25 @@ def get_early_stopping_settings(self, early_stopping_settings: Iterable[api_pb2.
145
145
elif setting .name == "start_step" :
146
146
self .start_step = int (setting .value )
147
147
148
- def get_median_value (self , trials : Iterable [api_pb2 .Trial ]):
148
+ def get_median_value (self , trials : Iterable [api_pb2 .Trial ]) -> Optional [ float ] :
149
149
for trial in trials :
150
150
# Get metrics only for the new succeeded Trials.
151
- if trial .name not in self .trials_avg_history and trial .status .condition == SUCCEEDED_TRIAL :
152
- with grpc .beta .implementations .insecure_channel (
153
- self .db_manager_address [0 ], int (self .db_manager_address [1 ])) as channel :
151
+ if (
152
+ trial .name not in self .trials_avg_history
153
+ and trial .status .condition == SUCCEEDED_TRIAL
154
+ ):
155
+ with grpc .insecure_channel (
156
+ f"{ self .db_manager_address [0 ]} :{ self .db_manager_address [1 ]} "
157
+ ) as channel :
154
158
stub = api_pb2_grpc .DBManagerStub (channel )
155
- get_log_response = stub .GetObservationLog (api_pb2 .GetObservationLogRequest (
156
- trial_name = trial .name ,
157
- metric_name = self .objective_metric
158
- ), timeout = APISERVER_TIMEOUT )
159
+ get_log_response : api_pb2 .GetObservationLogReply = (
160
+ stub .GetObservationLog (
161
+ api_pb2 .GetObservationLogRequest (
162
+ trial_name = trial .name , metric_name = self .objective_metric
163
+ ),
164
+ timeout = APISERVER_TIMEOUT ,
165
+ )
166
+ )
159
167
160
168
# Get only first start_step metrics.
161
169
# Since metrics are collected consistently and ordered by time, we slice top start_step metrics.
@@ -182,7 +190,7 @@ def get_median_value(self, trials: Iterable[api_pb2.Trial]):
182
190
))
183
191
return None
184
192
185
- def SetTrialStatus (self , request , context ) :
193
+ def SetTrialStatus (self , request : api_pb2 . SetTrialStatusRequest , context : grpc . ServicerContext ) -> api_pb2 . SetTrialStatusReply :
186
194
trial_name = request .trial_name
187
195
188
196
logger .info ("Update status for Trial: {}" .format (trial_name ))
0 commit comments