1
- import shutil
2
1
import os
3
-
4
2
import pytest
3
+ import shutil
4
+ from unittest import mock
5
5
6
6
from ramp_utils import read_config
7
7
from ramp_utils .testing import database_config_template
@@ -228,7 +228,6 @@ def test_dispatcher_worker_retry(session_toy):
228
228
229
229
while not dispatcher ._processing_worker_queue .empty ():
230
230
dispatcher .collect_result (session_toy )
231
-
232
231
submissions = get_submissions (session_toy , 'iris_test' , 'new' )
233
232
assert submission_name in [sub [1 ] for sub in submissions ]
234
233
@@ -253,7 +252,82 @@ def test_dispatcher_aws_not_launching(session_toy_aws, caplog):
253
252
assert 'training' not in caplog .text
254
253
num_running_workers = dispatcher ._processing_worker_queue .qsize ()
255
254
assert num_running_workers == 0
256
-
257
255
submissions2 = get_submissions (session_toy_aws , 'iris_aws_test' , 'new' )
258
256
# assert that all the submissions are still in the 'new' state
259
257
assert len (submissions ) == len (submissions2 )
258
+
259
+
260
+ @mock .patch ('ramp_engine.aws.api.download_log' )
261
+ @mock .patch ('ramp_engine.aws.api.check_instance_status' )
262
+ @mock .patch ('ramp_engine.aws.api._get_log_content' )
263
+ @mock .patch ('ramp_engine.aws.api._training_successful' )
264
+ @mock .patch ('ramp_engine.aws.api._training_finished' )
265
+ @mock .patch ('ramp_engine.aws.api.is_spot_terminated' )
266
+ @mock .patch ('ramp_engine.aws.api.launch_train' )
267
+ @mock .patch ('ramp_engine.aws.api.upload_submission' )
268
+ @mock .patch ('ramp_engine.aws.api.launch_ec2_instances' )
269
+ def test_info_on_training_error (test_launch_ec2_instances , upload_submission ,
270
+ launch_train ,
271
+ is_spot_terminated , training_finished ,
272
+ training_successful ,
273
+ get_log_content , check_instance_status ,
274
+ download_log ,
275
+ session_toy_aws ,
276
+ caplog ):
277
+ # make sure that the Python error from the solution is passed to the
278
+ # dispatcher
279
+ # everything shoud be mocked as correct output from AWS instances
280
+ # on setting up the instance and loading the submission
281
+ # mock dummy AWS instance
282
+ class DummyInstance :
283
+ id = 1
284
+ test_launch_ec2_instances .return_value = (DummyInstance (),), 0
285
+ upload_submission .return_value = 0
286
+ launch_train .return_value = 0
287
+ is_spot_terminated .return_value = 0
288
+ training_finished .return_value = False
289
+ download_log .return_value = 0
290
+
291
+ config = read_config (database_config_template ())
292
+ event_config = read_config (ramp_aws_config_template ())
293
+
294
+ dispatcher = Dispatcher (config = config ,
295
+ event_config = event_config ,
296
+ worker = AWSWorker , n_workers = 10 ,
297
+ hunger_policy = 'exit' )
298
+ dispatcher .fetch_from_db (session_toy_aws )
299
+ dispatcher .launch_workers (session_toy_aws )
300
+ num_running_workers = dispatcher ._processing_worker_queue .qsize ()
301
+ # worker, (submission_id, submission_name) = \
302
+ # dispatcher._processing_worker_queue.get()
303
+ # assert worker.status == 'running'
304
+ submissions = get_submissions (session_toy_aws ,
305
+ 'iris_aws_test' ,
306
+ 'training' )
307
+ ids = [submissions [idx ][0 ] for idx in range (len (submissions ))]
308
+ assert len (submissions ) > 1
309
+ assert num_running_workers == len (ids )
310
+
311
+ dispatcher .time_between_collection = 0
312
+ training_successful .return_value = False
313
+
314
+ # now we will end the submission with training error
315
+ training_finished .return_value = True
316
+ training_error_msg = 'Python error here'
317
+ get_log_content .return_value = training_error_msg
318
+ check_instance_status .return_value = 'finished'
319
+
320
+ dispatcher .collect_result (session_toy_aws )
321
+
322
+ # the worker which we were using should have been teared down
323
+ num_running_workers = dispatcher ._processing_worker_queue .qsize ()
324
+
325
+ assert num_running_workers == 0
326
+
327
+ submissions = get_submissions (session_toy_aws ,
328
+ 'iris_aws_test' ,
329
+ 'training_error' )
330
+ assert len (submissions ) == len (ids )
331
+
332
+ submission = get_submission_by_id (session_toy_aws , submissions [0 ][0 ])
333
+ assert training_error_msg in submission .error_msg
0 commit comments