Skip to content

Commit ccd7a4c

Browse files
committed
refine
Signed-off-by: forsaken628 <[email protected]>
1 parent bf9f8cd commit ccd7a4c

File tree

2 files changed

+29
-19
lines changed

2 files changed

+29
-19
lines changed

pkg/earlystopping/v1beta1/medianstop/service.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import logging
16-
from typing import Iterable
16+
from typing import Iterable, Optional
1717
from kubernetes import client, config
1818
import multiprocessing
1919
from datetime import datetime
@@ -65,7 +65,7 @@ def __init__(self):
6565

6666
self.api_instance = client.CustomObjectsApi()
6767

68-
def ValidateEarlyStoppingSettings(self, request, context):
68+
def ValidateEarlyStoppingSettings(self, request: api_pb2.ValidateEarlyStoppingSettingsRequest, context: grpc.ServicerContext) -> api_pb2.ValidateEarlyStoppingSettingsReply:
6969
is_valid, message = self.validate_early_stopping_spec(request.early_stopping)
7070
if not is_valid:
7171
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
@@ -98,7 +98,7 @@ def validate_medianstop_setting(early_stopping_settings):
9898

9999
return True, ""
100100

101-
def GetEarlyStoppingRules(self, request: api_pb2.GetEarlyStoppingRulesRequest, context):
101+
def GetEarlyStoppingRules(self, request: api_pb2.GetEarlyStoppingRulesRequest, context: grpc.ServicerContext) -> api_pb2.GetSuggestionsReply:
102102
logger.info("Get new early stopping rules")
103103

104104
# Get required values for the first call.
@@ -145,17 +145,25 @@ def get_early_stopping_settings(self, early_stopping_settings: Iterable[api_pb2.
145145
elif setting.name == "start_step":
146146
self.start_step = int(setting.value)
147147

148-
def get_median_value(self, trials: Iterable[api_pb2.Trial]):
148+
def get_median_value(self, trials: Iterable[api_pb2.Trial]) -> Optional[float]:
149149
for trial in trials:
150150
# Get metrics only for the new succeeded Trials.
151-
if trial.name not in self.trials_avg_history and trial.status.condition == SUCCEEDED_TRIAL:
152-
with grpc.beta.implementations.insecure_channel(
153-
self.db_manager_address[0], int(self.db_manager_address[1])) as channel:
151+
if (
152+
trial.name not in self.trials_avg_history
153+
and trial.status.condition == SUCCEEDED_TRIAL
154+
):
155+
with grpc.insecure_channel(
156+
f"{self.db_manager_address[0]}:{self.db_manager_address[1]}"
157+
) as channel:
154158
stub = api_pb2_grpc.DBManagerStub(channel)
155-
get_log_response = stub.GetObservationLog(api_pb2.GetObservationLogRequest(
156-
trial_name=trial.name,
157-
metric_name=self.objective_metric
158-
), timeout=APISERVER_TIMEOUT)
159+
get_log_response: api_pb2.GetObservationLogReply = (
160+
stub.GetObservationLog(
161+
api_pb2.GetObservationLogRequest(
162+
trial_name=trial.name, metric_name=self.objective_metric
163+
),
164+
timeout=APISERVER_TIMEOUT,
165+
)
166+
)
159167

160168
# Get only first start_step metrics.
161169
# Since metrics are collected consistently and ordered by time, we slice top start_step metrics.
@@ -182,7 +190,7 @@ def get_median_value(self, trials: Iterable[api_pb2.Trial]):
182190
))
183191
return None
184192

185-
def SetTrialStatus(self, request, context):
193+
def SetTrialStatus(self, request: api_pb2.SetTrialStatusRequest, context: grpc.ServicerContext) -> api_pb2.SetTrialStatusReply:
186194
trial_name = request.trial_name
187195

188196
logger.info("Update status for Trial: {}".format(trial_name))

test/conftest.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import os
22
from sys import path
33

4-
root = os.path.join(os.path.dirname(__file__),'..')
5-
path.extend([
6-
os.path.join(root,'pkg/apis/manager/v1beta1/python'),
7-
os.path.join(root,'pkg/apis/manager/health/python'),
8-
os.path.join(root,'pkg/metricscollector/v1beta1/common'),
9-
os.path.join(root,'pkg/metricscollector/v1beta1/tfevent-metricscollector')
10-
])
4+
root = os.path.join(os.path.dirname(__file__), "..")
5+
path.extend(
6+
[
7+
os.path.join(root, "pkg/apis/manager/v1beta1/python"),
8+
os.path.join(root, "pkg/apis/manager/health/python"),
9+
os.path.join(root, "pkg/metricscollector/v1beta1/common"),
10+
os.path.join(root, "pkg/metricscollector/v1beta1/tfevent-metricscollector"),
11+
]
12+
)

0 commit comments

Comments
 (0)