Skip to content

Commit c01d8a9

Browse files
authored
Don't wait till training data processing gets finished (#7)
* Release lock as soon as data processing is registered
* Fix typo
* Remove timeout test case for training data upload
1 parent 2331ca6 commit c01d8a9

File tree

3 files changed: 29 additions (+29) and 44 deletions (-44).

aws/traffic_shadowing/src/errors.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,3 @@ class ModelRegistrationFailed(Exception):
2222

2323
class DataUploadFailed(Exception):
2424
pass
25-
26-
27-
class TimeOut(Exception):
28-
pass

aws/traffic_shadowing/src/model_pool.py

Lines changed: 25 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -111,14 +111,13 @@ def create_model(
111111
metadata: dict = None
112112
) -> Model:
113113
"""
114-
Register an external model in the Hydrosphere platform
115-
and uploads training data.
114+
Register an external model and send training data.
116115
"""
117-
registration_response = self._register_model(name, schema, metadata)
116+
response = self._register_model(name, schema, metadata)
118117
model = Model(
119-
name=registration_response["model"]["name"],
120-
version=registration_response["modelVersion"],
121-
model_version_id=registration_response["id"]
118+
name=response["model"]["name"],
119+
version=response["modelVersion"],
120+
model_version_id=response["id"]
122121
)
123122
self._upload_training_data(model.model_version_id, training_file)
124123
self._wait_for_data_processing(model.model_version_id)
@@ -127,39 +126,41 @@ def create_model(
127126
def _wait_for_data_processing(
128127
self,
129128
model_version_id,
130-
timeout: int = 120,
131129
retry: int = 3
132130
) -> requests.Response:
133-
"""Wait till the data gets processed."""
131+
"""Wait till the data gets registered or finishes the processing."""
132+
133+
def tick(sleep: int = 5) -> bool:
134+
nonlocal retry
135+
if retry == 0:
136+
return False
137+
retry -= 1
138+
time.sleep(sleep)
139+
return True
140+
134141
url = urllib.parse.urljoin(
135142
self.endpoint, f"/monitoring/profiles/batch/{model_version_id}/status")
136-
137143
result = None
138144
while True:
139145
result = requests.get(url)
140146
if result.status_code != 200:
141-
if retry > 0:
142-
retry -= 1
143-
time.sleep(5)
147+
if tick():
144148
continue
145149
else:
146150
raise errors.ApiNotAvailable(
147151
"Could not fetch the status of the data processing task.")
148152

149153
body = result.json()
150154
status = DataProfileStatus[body["kind"]]
151-
if status == DataProfileStatus.Processing:
152-
if timeout > 0:
153-
seconds = min(10, timeout)
154-
timeout -= seconds
155-
time.sleep(seconds)
155+
156+
if status in (DataProfileStatus.Success, DataProfileStatus.Processing):
157+
# Don't wait until all data gets processed, release immediately
158+
break
159+
elif status == DataProfileStatus.NotRegistered:
160+
# If profile hasn't been registered yet, wait for a little more
161+
if tick():
156162
continue
157-
else:
158-
raise errors.TimeOut("Data processing timed out.")
159-
elif status == DataProfileStatus.Success:
160-
break
161-
else:
162-
raise errors.DataUploadFailed(f"Failed to upload the data: {status}")
163+
raise errors.DataUploadFailed(f"Failed to upload the data: {status}")
163164
return result
164165

165166
def _upload_training_data(self, model_version_id: int, training_file: str) -> requests.Response:
@@ -189,8 +190,7 @@ def _register_model(self, name: str, schema: SchemaDescription, metadata: dict)
189190
response = result.json()
190191
self.logger.info("Registered a new model: %s", response["model"]["name"])
191192
else:
192-
raise errors.ModelRegistrationFailed(
193-
f"Could not register a model: {result.content}")
193+
raise errors.ModelRegistrationFailed(f"Could not register a model: {result.content}")
194194
return response
195195

196196
def _create_feature(self, column: ColumnDescription) -> dict:

aws/traffic_shadowing/tests/test_model_pool.py

Lines changed: 4 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import requests_mock
44
from src.model_pool import ModelPool
55
from src.errors import (
6-
ModelNotFound, DataUploadFailed, TimeOut, ApiNotAvailable
6+
ModelNotFound, DataUploadFailed, ApiNotAvailable
77
)
88
from tests.stubs.http.hydrosphere import (
99
ListModelsStub, ListModelVersionsStub, RegisterExternalModelStub,
@@ -106,17 +106,6 @@ def test_wait_training_data_processed():
106106
assert result.status_code == 200
107107

108108

109-
def test_wait_training_data_processed_timeout():
110-
with requests_mock.mock(real_http=False) as mock:
111-
mock.get(**WaitTrainingDataProcessingStub(
112-
kind="Processing",
113-
model_version_id=MODEL_VERSION_ID,
114-
).generate_response())
115-
with pytest.raises(TimeOut):
116-
pool = ModelPool(HYDROSPHERE_ENDPOINT)
117-
pool._wait_for_data_processing(MODEL_VERSION_ID, timeout=1, retry=0)
118-
119-
120109
def test_wait_training_data_processed_fail():
121110
with requests_mock.mock(real_http=False) as mock:
122111
mock.get(**WaitTrainingDataProcessingStub(
@@ -125,7 +114,7 @@ def test_wait_training_data_processed_fail():
125114
).generate_response())
126115
with pytest.raises(DataUploadFailed):
127116
pool = ModelPool(HYDROSPHERE_ENDPOINT)
128-
pool._wait_for_data_processing(MODEL_VERSION_ID, timeout=1, retry=0)
117+
pool._wait_for_data_processing(MODEL_VERSION_ID, retry=0)
129118

130119

131120
def test_wait_training_data_processed_not_registered():
@@ -136,7 +125,7 @@ def test_wait_training_data_processed_not_registered():
136125
).generate_response())
137126
with pytest.raises(DataUploadFailed):
138127
pool = ModelPool(HYDROSPHERE_ENDPOINT)
139-
pool._wait_for_data_processing(MODEL_VERSION_ID, timeout=1, retry=0)
128+
pool._wait_for_data_processing(MODEL_VERSION_ID, retry=0)
140129

141130

142131
def test_wait_training_data_processed_unavailable():
@@ -147,7 +136,7 @@ def test_wait_training_data_processed_unavailable():
147136
).generate_client_error())
148137
with pytest.raises(ApiNotAvailable):
149138
pool = ModelPool(HYDROSPHERE_ENDPOINT)
150-
pool._wait_for_data_processing(MODEL_VERSION_ID, timeout=1, retry=0)
139+
pool._wait_for_data_processing(MODEL_VERSION_ID, retry=0)
151140

152141

153142
def test_create_model():

0 commit comments

Comments
 (0)