Skip to content

Commit e9a3c1d

Browse files
authored
Show training error msg (#495)
* adding test and correcting the training error msg * update the tests * cleanup
1 parent f7211c4 commit e9a3c1d

File tree

4 files changed

+89
-10
lines changed

4 files changed

+89
-10
lines changed

ramp-engine/ramp_engine/aws/worker.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def collect_results(self):
166166
except Exception as e:
167167
logger.error("Error occurred when downloading the logs"
168168
f" from the submission: {e}")
169-
exit_status = 1
169+
exit_status = 2
170170
error_msg = str(e)
171171
self.status = 'error'
172172
if exit_status == 0:
@@ -189,7 +189,7 @@ def collect_results(self):
189189
error_msg = _get_traceback(
190190
aws._get_log_content(self.config, self.submission))
191191
self.status = 'collected'
192-
exit_status, error_msg = 1, ""
192+
exit_status = 1
193193
logger.info(repr(self))
194194
return exit_status, error_msg
195195

ramp-engine/ramp_engine/dispatcher.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def collect_result(self, session):
212212
for worker, (submission_id, submission_name) in zip(workers,
213213
submissions):
214214
dt = worker.time_since_last_status_check()
215-
if dt is not None and dt < self.time_between_collection:
215+
if (dt is not None) and (dt < self.time_between_collection):
216216
self._processing_worker_queue.put_nowait(
217217
(worker, (submission_id, submission_name)))
218218
time.sleep(0)
@@ -231,20 +231,24 @@ def collect_result(self, session):
231231
else:
232232
self._logger.info(f'Collecting results from worker {worker}')
233233
returncode, stderr = worker.collect_results()
234+
234235
if returncode:
235236
if returncode == 124:
236237
self._logger.info(
237238
f'Worker {worker} killed due to timeout.'
238239
)
240+
submission_status = 'checking_error'
241+
elif returncode == 2:
242+
# Error occurred when downloading the logs
243+
submission_status = 'checking_error'
239244
else:
240245
self._logger.info(
241246
f'Worker {worker} killed due to an error '
242247
f'during training: {stderr}'
243248
)
244-
submission_status = 'training_error'
249+
submission_status = 'training_error'
245250
else:
246251
submission_status = 'tested'
247-
248252
set_submission_state(
249253
session, submission_id, submission_status
250254
)

ramp-engine/ramp_engine/tests/test_aws.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,9 @@ class DummyInstance:
105105
exit_status, error_msg = worker.collect_results()
106106
assert 'Error occurred when downloading the logs' in caplog.text
107107
assert 'Trying to download the log once again' in caplog.text
108-
assert exit_status == 1
108+
assert exit_status == 2
109109
assert 'test' in error_msg
110+
assert worker.status == 'error'
110111

111112

112113
@mock.patch('ramp_engine.aws.api._rsync')

ramp-engine/ramp_engine/tests/test_dispatcher.py

+78-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import shutil
21
import os
3-
42
import pytest
3+
import shutil
4+
from unittest import mock
55

66
from ramp_utils import read_config
77
from ramp_utils.testing import database_config_template
@@ -228,7 +228,6 @@ def test_dispatcher_worker_retry(session_toy):
228228

229229
while not dispatcher._processing_worker_queue.empty():
230230
dispatcher.collect_result(session_toy)
231-
232231
submissions = get_submissions(session_toy, 'iris_test', 'new')
233232
assert submission_name in [sub[1] for sub in submissions]
234233

@@ -253,7 +252,82 @@ def test_dispatcher_aws_not_launching(session_toy_aws, caplog):
253252
assert 'training' not in caplog.text
254253
num_running_workers = dispatcher._processing_worker_queue.qsize()
255254
assert num_running_workers == 0
256-
257255
submissions2 = get_submissions(session_toy_aws, 'iris_aws_test', 'new')
258256
# assert that all the submissions are still in the 'new' state
259257
assert len(submissions) == len(submissions2)
258+
259+
260+
@mock.patch('ramp_engine.aws.api.download_log')
261+
@mock.patch('ramp_engine.aws.api.check_instance_status')
262+
@mock.patch('ramp_engine.aws.api._get_log_content')
263+
@mock.patch('ramp_engine.aws.api._training_successful')
264+
@mock.patch('ramp_engine.aws.api._training_finished')
265+
@mock.patch('ramp_engine.aws.api.is_spot_terminated')
266+
@mock.patch('ramp_engine.aws.api.launch_train')
267+
@mock.patch('ramp_engine.aws.api.upload_submission')
268+
@mock.patch('ramp_engine.aws.api.launch_ec2_instances')
269+
def test_info_on_training_error(test_launch_ec2_instances, upload_submission,
270+
launch_train,
271+
is_spot_terminated, training_finished,
272+
training_successful,
273+
get_log_content, check_instance_status,
274+
download_log,
275+
session_toy_aws,
276+
caplog):
277+
# make sure that the Python error from the solution is passed to the
278+
# dispatcher
279+
# everything shoud be mocked as correct output from AWS instances
280+
# on setting up the instance and loading the submission
281+
# mock dummy AWS instance
282+
class DummyInstance:
283+
id = 1
284+
test_launch_ec2_instances.return_value = (DummyInstance(),), 0
285+
upload_submission.return_value = 0
286+
launch_train.return_value = 0
287+
is_spot_terminated.return_value = 0
288+
training_finished.return_value = False
289+
download_log.return_value = 0
290+
291+
config = read_config(database_config_template())
292+
event_config = read_config(ramp_aws_config_template())
293+
294+
dispatcher = Dispatcher(config=config,
295+
event_config=event_config,
296+
worker=AWSWorker, n_workers=10,
297+
hunger_policy='exit')
298+
dispatcher.fetch_from_db(session_toy_aws)
299+
dispatcher.launch_workers(session_toy_aws)
300+
num_running_workers = dispatcher._processing_worker_queue.qsize()
301+
# worker, (submission_id, submission_name) = \
302+
# dispatcher._processing_worker_queue.get()
303+
# assert worker.status == 'running'
304+
submissions = get_submissions(session_toy_aws,
305+
'iris_aws_test',
306+
'training')
307+
ids = [submissions[idx][0] for idx in range(len(submissions))]
308+
assert len(submissions) > 1
309+
assert num_running_workers == len(ids)
310+
311+
dispatcher.time_between_collection = 0
312+
training_successful.return_value = False
313+
314+
# now we will end the submission with training error
315+
training_finished.return_value = True
316+
training_error_msg = 'Python error here'
317+
get_log_content.return_value = training_error_msg
318+
check_instance_status.return_value = 'finished'
319+
320+
dispatcher.collect_result(session_toy_aws)
321+
322+
# the worker which we were using should have been teared down
323+
num_running_workers = dispatcher._processing_worker_queue.qsize()
324+
325+
assert num_running_workers == 0
326+
327+
submissions = get_submissions(session_toy_aws,
328+
'iris_aws_test',
329+
'training_error')
330+
assert len(submissions) == len(ids)
331+
332+
submission = get_submission_by_id(session_toy_aws, submissions[0][0])
333+
assert training_error_msg in submission.error_msg

0 commit comments

Comments
 (0)