diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py
index 74b348381e8ea..353db3b812bf8 100644
--- a/tests/providers/amazon/aws/sensors/test_batch.py
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -32,78 +32,112 @@
 TASK_ID = "batch_job_sensor"
 JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
 AWS_REGION = "eu-west-1"
+ENVIRONMENT_NAME = "environment_name"
+JOB_QUEUE = "job_queue"
 
 
-class TestBatchSensor:
-    def setup_method(self):
-        self.batch_sensor = BatchSensor(
-            task_id="batch_job_sensor",
-            job_id=JOB_ID,
-        )
+@pytest.fixture(scope="module")
+def batch_sensor() -> BatchSensor:
+    return BatchSensor(
+        task_id="batch_job_sensor",
+        job_id=JOB_ID,
+    )
+
+
+@pytest.fixture(scope="module")
+def deferrable_batch_sensor() -> BatchSensor:
+    return BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True)
+
 
+class TestBatchSensor:
     @mock.patch.object(BatchClientHook, "get_job_description")
-    def test_poke_on_success_state(self, mock_get_job_description):
+    def test_poke_on_success_state(self, mock_get_job_description, batch_sensor: BatchSensor):
         mock_get_job_description.return_value = {"status": "SUCCEEDED"}
-        assert self.batch_sensor.poke({}) is True
+        assert batch_sensor.poke({}) is True
         mock_get_job_description.assert_called_once_with(JOB_ID)
 
     @mock.patch.object(BatchClientHook, "get_job_description")
-    def test_poke_on_failure_state(self, mock_get_job_description):
+    def test_poke_on_failure_state(self, mock_get_job_description, batch_sensor: BatchSensor):
         mock_get_job_description.return_value = {"status": "FAILED"}
         with pytest.raises(AirflowException, match="Batch sensor failed. AWS Batch job status: FAILED"):
-            self.batch_sensor.poke({})
+            batch_sensor.poke({})
 
         mock_get_job_description.assert_called_once_with(JOB_ID)
 
     @mock.patch.object(BatchClientHook, "get_job_description")
-    def test_poke_on_invalid_state(self, mock_get_job_description):
+    def test_poke_on_invalid_state(self, mock_get_job_description, batch_sensor: BatchSensor):
         mock_get_job_description.return_value = {"status": "INVALID"}
         with pytest.raises(
             AirflowException, match="Batch sensor failed. Unknown AWS Batch job status: INVALID"
         ):
-            self.batch_sensor.poke({})
+            batch_sensor.poke({})
 
         mock_get_job_description.assert_called_once_with(JOB_ID)
 
     @pytest.mark.parametrize("job_status", ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"])
     @mock.patch.object(BatchClientHook, "get_job_description")
-    def test_poke_on_intermediate_state(self, mock_get_job_description, job_status):
+    def test_poke_on_intermediate_state(
+        self, mock_get_job_description, job_status, batch_sensor: BatchSensor
+    ):
         print(job_status)
         mock_get_job_description.return_value = {"status": job_status}
-        assert self.batch_sensor.poke({}) is False
+        assert batch_sensor.poke({}) is False
         mock_get_job_description.assert_called_once_with(JOB_ID)
 
+    def test_execute_in_deferrable_mode(self, deferrable_batch_sensor: BatchSensor):
+        """
+        Asserts that a task is deferred and a BatchJobTrigger will be fired
+        when the BatchSensor is executed in deferrable mode.
+        """
 
-class TestBatchComputeEnvironmentSensor:
-    def setup_method(self):
-        self.environment_name = "environment_name"
-        self.sensor = BatchComputeEnvironmentSensor(
-            task_id="test_batch_compute_environment_sensor",
-            compute_environment=self.environment_name,
-        )
+        with pytest.raises(TaskDeferred) as exc:
+            deferrable_batch_sensor.execute({})
+        assert isinstance(exc.value.trigger, BatchJobTrigger), "Trigger is not a BatchJobTrigger"
+
+    def test_execute_failure_in_deferrable_mode(self, deferrable_batch_sensor: BatchSensor):
+        """Tests that an AirflowException is raised in case of error event"""
+
+        with pytest.raises(AirflowException):
+            deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})
+
+
+@pytest.fixture(scope="module")
+def batch_compute_environment_sensor() -> BatchComputeEnvironmentSensor:
+    return BatchComputeEnvironmentSensor(
+        task_id="test_batch_compute_environment_sensor",
+        compute_environment=ENVIRONMENT_NAME,
+    )
+
 
+class TestBatchComputeEnvironmentSensor:
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_no_environment(self, mock_batch_client):
+    def test_poke_no_environment(
+        self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
+    ):
         mock_batch_client.describe_compute_environments.return_value = {"computeEnvironments": []}
         with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke({})
+            batch_compute_environment_sensor.poke({})
         mock_batch_client.describe_compute_environments.assert_called_once_with(
-            computeEnvironments=[self.environment_name],
+            computeEnvironments=[ENVIRONMENT_NAME],
         )
         assert "not found" in str(ctx.value)
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_valid(self, mock_batch_client):
+    def test_poke_valid(
+        self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
+    ):
         mock_batch_client.describe_compute_environments.return_value = {
             "computeEnvironments": [{"status": "VALID"}]
         }
-        assert self.sensor.poke({}) is True
+        assert batch_compute_environment_sensor.poke({}) is True
         mock_batch_client.describe_compute_environments.assert_called_once_with(
-            computeEnvironments=[self.environment_name],
+            computeEnvironments=[ENVIRONMENT_NAME],
         )
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_running(self, mock_batch_client):
+    def test_poke_running(
+        self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
+    ):
         mock_batch_client.describe_compute_environments.return_value = {
             "computeEnvironments": [
                 {
@@ -111,13 +145,15 @@ def test_poke_running(self, mock_batch_client):
                 }
             ]
         }
-        assert self.sensor.poke({}) is False
+        assert batch_compute_environment_sensor.poke({}) is False
         mock_batch_client.describe_compute_environments.assert_called_once_with(
-            computeEnvironments=[self.environment_name],
+            computeEnvironments=[ENVIRONMENT_NAME],
         )
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_invalid(self, mock_batch_client):
+    def test_poke_invalid(
+        self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
+    ):
         mock_batch_client.describe_compute_environments.return_value = {
             "computeEnvironments": [
                 {
@@ -126,50 +162,53 @@ def test_poke_invalid(self, mock_batch_client):
             ]
         }
         with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke({})
+            batch_compute_environment_sensor.poke({})
         mock_batch_client.describe_compute_environments.assert_called_once_with(
-            computeEnvironments=[self.environment_name],
+            computeEnvironments=[ENVIRONMENT_NAME],
         )
         assert "AWS Batch compute environment failed" in str(ctx.value)
 
 
-class TestBatchJobQueueSensor:
-    def setup_method(self):
-        self.job_queue = "job_queue"
-        self.sensor = BatchJobQueueSensor(
-            task_id="test_batch_job_queue_sensor",
-            job_queue=self.job_queue,
-        )
+@pytest.fixture(scope="module")
+def batch_job_queue_sensor() -> BatchJobQueueSensor:
+    return BatchJobQueueSensor(
+        task_id="test_batch_job_queue_sensor",
+        job_queue=JOB_QUEUE,
+    )
+
 
+class TestBatchJobQueueSensor:
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_no_queue(self, mock_batch_client):
+    def test_poke_no_queue(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
         mock_batch_client.describe_job_queues.return_value = {"jobQueues": []}
         with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke({})
+            batch_job_queue_sensor.poke({})
         mock_batch_client.describe_job_queues.assert_called_once_with(
-            jobQueues=[self.job_queue],
+            jobQueues=[JOB_QUEUE],
         )
         assert "not found" in str(ctx.value)
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_no_queue_with_treat_non_existing_as_deleted(self, mock_batch_client):
-        self.sensor.treat_non_existing_as_deleted = True
+    def test_poke_no_queue_with_treat_non_existing_as_deleted(
+        self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor
+    ):
+        batch_job_queue_sensor.treat_non_existing_as_deleted = True
         mock_batch_client.describe_job_queues.return_value = {"jobQueues": []}
-        assert self.sensor.poke({}) is True
+        assert batch_job_queue_sensor.poke({}) is True
         mock_batch_client.describe_job_queues.assert_called_once_with(
-            jobQueues=[self.job_queue],
+            jobQueues=[JOB_QUEUE],
         )
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_valid(self, mock_batch_client):
+    def test_poke_valid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
         mock_batch_client.describe_job_queues.return_value = {"jobQueues": [{"status": "VALID"}]}
-        assert self.sensor.poke({}) is True
+        assert batch_job_queue_sensor.poke({}) is True
         mock_batch_client.describe_job_queues.assert_called_once_with(
-            jobQueues=[self.job_queue],
+            jobQueues=[JOB_QUEUE],
         )
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_running(self, mock_batch_client):
+    def test_poke_running(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
         mock_batch_client.describe_job_queues.return_value = {
             "jobQueues": [
                 {
@@ -177,13 +216,13 @@ def test_poke_running(self, mock_batch_client):
                 }
             ]
         }
-        assert self.sensor.poke({}) is False
+        assert batch_job_queue_sensor.poke({}) is False
         mock_batch_client.describe_job_queues.assert_called_once_with(
-            jobQueues=[self.job_queue],
+            jobQueues=[JOB_QUEUE],
        )
 
     @mock.patch.object(BatchClientHook, "client")
-    def test_poke_invalid(self, mock_batch_client):
+    def test_poke_invalid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
         mock_batch_client.describe_job_queues.return_value = {
             "jobQueues": [
                 {
@@ -192,28 +231,8 @@ def test_poke_invalid(self, mock_batch_client):
             ]
         }
         with pytest.raises(AirflowException) as ctx:
-            self.sensor.poke({})
+            batch_job_queue_sensor.poke({})
         mock_batch_client.describe_job_queues.assert_called_once_with(
-            jobQueues=[self.job_queue],
+            jobQueues=[JOB_QUEUE],
         )
         assert "AWS Batch job queue failed" in str(ctx.value)
-
-
-class TestBatchAsyncSensor:
-    TASK = BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True)
-
-    def test_batch_sensor_async(self):
-        """
-        Asserts that a task is deferred and a BatchSensorTrigger will be fired
-        when the BatchSensorAsync is executed.
-        """
-
-        with pytest.raises(TaskDeferred) as exc:
-            self.TASK.execute({})
-        assert isinstance(exc.value.trigger, BatchJobTrigger), "Trigger is not a BatchJobTrigger"
-
-    def test_batch_sensor_async_execute_failure(self):
-        """Tests that an AirflowException is raised in case of error event"""
-
-        with pytest.raises(AirflowException):
-            self.TASK.execute_complete(context={}, event={"status": "failure"})
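Note for reviewers: the patch replaces per-test `setup_method` state with module-scoped pytest fixtures that are injected by argument name. A minimal, self-contained sketch of that pattern follows; `FakeSensor`, `fake_sensor`, and the test names are illustrative stand-ins, not part of this patch or of the Airflow API.

```python
import pytest


class FakeSensor:
    """Illustrative stand-in for BatchSensor (not the real Airflow class)."""

    def __init__(self, job_id: str) -> None:
        self.job_id = job_id
        self.treat_non_existing_as_deleted = False


@pytest.fixture(scope="module")
def fake_sensor() -> FakeSensor:
    # Built once per test module and shared by every test that requests it,
    # unlike setup_method, which rebuilds the object before each test.
    return FakeSensor(job_id="8222a1c2-b246-4e19-b1b8-0039bb4407c0")


class TestFakeSensor:
    def test_job_id_is_set(self, fake_sensor: FakeSensor):
        # pytest injects the fixture by matching the parameter name.
        assert fake_sensor.job_id.startswith("8222a1c2")

    def test_flag_defaults_to_false(self, fake_sensor: FakeSensor):
        assert fake_sensor.treat_non_existing_as_deleted is False
```

One trade-off of `scope="module"`: a test that mutates the shared object (for example, `test_poke_no_queue_with_treat_non_existing_as_deleted` setting `treat_non_existing_as_deleted = True`) changes what every later test in the module sees, so test ordering can start to matter.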