Fix outdated test name and description in BatchSensor #33407

Merged
171 changes: 95 additions & 76 deletions tests/providers/amazon/aws/sensors/test_batch.py
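In short, the diff below replaces the per-class setup_method state in TestBatchSensor, TestBatchComputeEnvironmentSensor, and TestBatchJobQueueSensor with module-scoped pytest fixtures that are injected into each test by argument name, hoists the shared environment_name/job_queue strings into module constants, and folds the removed TestBatchAsyncSensor tests into TestBatchSensor as deferrable-mode tests. A minimal sketch of the fixture pattern follows, not the exact diff; the import path and the job_id attribute check are assumptions for illustration only:

import pytest

from airflow.providers.amazon.aws.sensors.batch import BatchSensor  # assumed import path

JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"


# Before: state was attached to the test class in setup_method.
# class TestBatchSensor:
#     def setup_method(self):
#         self.batch_sensor = BatchSensor(task_id="batch_job_sensor", job_id=JOB_ID)


# After: a module-scoped fixture, constructed once per module and injected by name.
@pytest.fixture(scope="module")
def batch_sensor() -> BatchSensor:
    return BatchSensor(task_id="batch_job_sensor", job_id=JOB_ID)


class TestBatchSensor:
    def test_uses_injected_fixture(self, batch_sensor: BatchSensor):
        # pytest resolves the batch_sensor argument from the fixture above.
        assert batch_sensor.job_id == JOB_ID  # illustrative attribute access

Scoping the fixture to the module keeps construction cheap while the tests themselves stay stateless, which is the trade-off the diff makes for the poke-based tests.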
@@ -32,92 +32,128 @@
TASK_ID = "batch_job_sensor"
JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
AWS_REGION = "eu-west-1"
ENVIRONMENT_NAME = "environment_name"
JOB_QUEUE = "job_queue"


class TestBatchSensor:
def setup_method(self):
self.batch_sensor = BatchSensor(
task_id="batch_job_sensor",
job_id=JOB_ID,
)
@pytest.fixture(scope="module")
def batch_sensor() -> BatchSensor:
return BatchSensor(
task_id="batch_job_sensor",
job_id=JOB_ID,
)


@pytest.fixture(scope="module")
def deferrable_batch_sensor() -> BatchSensor:
return BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True)


class TestBatchSensor:
@mock.patch.object(BatchClientHook, "get_job_description")
def test_poke_on_success_state(self, mock_get_job_description):
def test_poke_on_success_state(self, mock_get_job_description, batch_sensor: BatchSensor):
mock_get_job_description.return_value = {"status": "SUCCEEDED"}
assert self.batch_sensor.poke({}) is True
assert batch_sensor.poke({}) is True
mock_get_job_description.assert_called_once_with(JOB_ID)

@mock.patch.object(BatchClientHook, "get_job_description")
def test_poke_on_failure_state(self, mock_get_job_description):
def test_poke_on_failure_state(self, mock_get_job_description, batch_sensor: BatchSensor):
mock_get_job_description.return_value = {"status": "FAILED"}
with pytest.raises(AirflowException, match="Batch sensor failed. AWS Batch job status: FAILED"):
self.batch_sensor.poke({})
batch_sensor.poke({})

mock_get_job_description.assert_called_once_with(JOB_ID)

@mock.patch.object(BatchClientHook, "get_job_description")
def test_poke_on_invalid_state(self, mock_get_job_description):
def test_poke_on_invalid_state(self, mock_get_job_description, batch_sensor: BatchSensor):
mock_get_job_description.return_value = {"status": "INVALID"}
with pytest.raises(
AirflowException, match="Batch sensor failed. Unknown AWS Batch job status: INVALID"
):
self.batch_sensor.poke({})
batch_sensor.poke({})

mock_get_job_description.assert_called_once_with(JOB_ID)

@pytest.mark.parametrize("job_status", ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"])
@mock.patch.object(BatchClientHook, "get_job_description")
def test_poke_on_intermediate_state(self, mock_get_job_description, job_status):
def test_poke_on_intermediate_state(
self, mock_get_job_description, job_status, batch_sensor: BatchSensor
):
print(job_status)
mock_get_job_description.return_value = {"status": job_status}
assert self.batch_sensor.poke({}) is False
assert batch_sensor.poke({}) is False
mock_get_job_description.assert_called_once_with(JOB_ID)

def test_execute_in_deferrable_mode(self, deferrable_batch_sensor: BatchSensor):
"""
Asserts that a task is deferred and a BatchJobTrigger will be fired
when the BatchSensor is executed in deferrable mode.
"""

class TestBatchComputeEnvironmentSensor:
def setup_method(self):
self.environment_name = "environment_name"
self.sensor = BatchComputeEnvironmentSensor(
task_id="test_batch_compute_environment_sensor",
compute_environment=self.environment_name,
)
with pytest.raises(TaskDeferred) as exc:
deferrable_batch_sensor.execute({})
assert isinstance(exc.value.trigger, BatchJobTrigger), "Trigger is not a BatchJobTrigger"

def test_execute_failure_in_deferrable_mode(self, deferrable_batch_sensor: BatchSensor):
"""Tests that an AirflowException is raised in case of error event"""

with pytest.raises(AirflowException):
deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})


@pytest.fixture(scope="module")
def batch_compute_environment_sensor() -> BatchComputeEnvironmentSensor:
return BatchComputeEnvironmentSensor(
task_id="test_batch_compute_environment_sensor",
compute_environment=ENVIRONMENT_NAME,
)


class TestBatchComputeEnvironmentSensor:
@mock.patch.object(BatchClientHook, "client")
def test_poke_no_environment(self, mock_batch_client):
def test_poke_no_environment(
self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
):
mock_batch_client.describe_compute_environments.return_value = {"computeEnvironments": []}
with pytest.raises(AirflowException) as ctx:
self.sensor.poke({})
batch_compute_environment_sensor.poke({})
mock_batch_client.describe_compute_environments.assert_called_once_with(
computeEnvironments=[self.environment_name],
computeEnvironments=[ENVIRONMENT_NAME],
)
assert "not found" in str(ctx.value)

@mock.patch.object(BatchClientHook, "client")
def test_poke_valid(self, mock_batch_client):
def test_poke_valid(
self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
):
mock_batch_client.describe_compute_environments.return_value = {
"computeEnvironments": [{"status": "VALID"}]
}
assert self.sensor.poke({}) is True
assert batch_compute_environment_sensor.poke({}) is True
mock_batch_client.describe_compute_environments.assert_called_once_with(
computeEnvironments=[self.environment_name],
computeEnvironments=[ENVIRONMENT_NAME],
)

@mock.patch.object(BatchClientHook, "client")
def test_poke_running(self, mock_batch_client):
def test_poke_running(
self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
):
mock_batch_client.describe_compute_environments.return_value = {
"computeEnvironments": [
{
"status": "CREATING",
}
]
}
assert self.sensor.poke({}) is False
assert batch_compute_environment_sensor.poke({}) is False
mock_batch_client.describe_compute_environments.assert_called_once_with(
computeEnvironments=[self.environment_name],
computeEnvironments=[ENVIRONMENT_NAME],
)

@mock.patch.object(BatchClientHook, "client")
def test_poke_invalid(self, mock_batch_client):
def test_poke_invalid(
self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
):
mock_batch_client.describe_compute_environments.return_value = {
"computeEnvironments": [
{
@@ -126,64 +162,67 @@ def test_poke_invalid(self, mock_batch_client):
]
}
with pytest.raises(AirflowException) as ctx:
self.sensor.poke({})
batch_compute_environment_sensor.poke({})
mock_batch_client.describe_compute_environments.assert_called_once_with(
computeEnvironments=[self.environment_name],
computeEnvironments=[ENVIRONMENT_NAME],
)
assert "AWS Batch compute environment failed" in str(ctx.value)


class TestBatchJobQueueSensor:
def setup_method(self):
self.job_queue = "job_queue"
self.sensor = BatchJobQueueSensor(
task_id="test_batch_job_queue_sensor",
job_queue=self.job_queue,
)
@pytest.fixture(scope="module")
def batch_job_queue_sensor() -> BatchJobQueueSensor:
return BatchJobQueueSensor(
task_id="test_batch_job_queue_sensor",
job_queue=JOB_QUEUE,
)


class TestBatchJobQueueSensor:
@mock.patch.object(BatchClientHook, "client")
def test_poke_no_queue(self, mock_batch_client):
def test_poke_no_queue(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
mock_batch_client.describe_job_queues.return_value = {"jobQueues": []}
with pytest.raises(AirflowException) as ctx:
self.sensor.poke({})
batch_job_queue_sensor.poke({})
mock_batch_client.describe_job_queues.assert_called_once_with(
jobQueues=[self.job_queue],
jobQueues=[JOB_QUEUE],
)
assert "not found" in str(ctx.value)

@mock.patch.object(BatchClientHook, "client")
def test_poke_no_queue_with_treat_non_existing_as_deleted(self, mock_batch_client):
self.sensor.treat_non_existing_as_deleted = True
def test_poke_no_queue_with_treat_non_existing_as_deleted(
self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor
):
batch_job_queue_sensor.treat_non_existing_as_deleted = True
mock_batch_client.describe_job_queues.return_value = {"jobQueues": []}
assert self.sensor.poke({}) is True
assert batch_job_queue_sensor.poke({}) is True
mock_batch_client.describe_job_queues.assert_called_once_with(
jobQueues=[self.job_queue],
jobQueues=[JOB_QUEUE],
)

@mock.patch.object(BatchClientHook, "client")
def test_poke_valid(self, mock_batch_client):
def test_poke_valid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
mock_batch_client.describe_job_queues.return_value = {"jobQueues": [{"status": "VALID"}]}
assert self.sensor.poke({}) is True
assert batch_job_queue_sensor.poke({}) is True
mock_batch_client.describe_job_queues.assert_called_once_with(
jobQueues=[self.job_queue],
jobQueues=[JOB_QUEUE],
)

@mock.patch.object(BatchClientHook, "client")
def test_poke_running(self, mock_batch_client):
def test_poke_running(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
mock_batch_client.describe_job_queues.return_value = {
"jobQueues": [
{
"status": "CREATING",
}
]
}
assert self.sensor.poke({}) is False
assert batch_job_queue_sensor.poke({}) is False
mock_batch_client.describe_job_queues.assert_called_once_with(
jobQueues=[self.job_queue],
jobQueues=[JOB_QUEUE],
)

@mock.patch.object(BatchClientHook, "client")
def test_poke_invalid(self, mock_batch_client):
def test_poke_invalid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
mock_batch_client.describe_job_queues.return_value = {
"jobQueues": [
{
@@ -192,28 +231,8 @@ def test_poke_invalid(self, mock_batch_client):
]
}
with pytest.raises(AirflowException) as ctx:
self.sensor.poke({})
batch_job_queue_sensor.poke({})
mock_batch_client.describe_job_queues.assert_called_once_with(
jobQueues=[self.job_queue],
jobQueues=[JOB_QUEUE],
)
assert "AWS Batch job queue failed" in str(ctx.value)


class TestBatchAsyncSensor:
TASK = BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True)

def test_batch_sensor_async(self):
"""
Asserts that a task is deferred and a BatchSensorTrigger will be fired
when the BatchSensorAsync is executed.
"""

with pytest.raises(TaskDeferred) as exc:
self.TASK.execute({})
assert isinstance(exc.value.trigger, BatchJobTrigger), "Trigger is not a BatchJobTrigger"

def test_batch_sensor_async_execute_failure(self):
"""Tests that an AirflowException is raised in case of error event"""

with pytest.raises(AirflowException):
self.TASK.execute_complete(context={}, event={"status": "failure"})
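
For reference outside the diff markup, the deferrable-mode pattern the new tests exercise looks roughly like this; the import paths are assumptions, since the file's import block is not part of the hunk shown above:

import pytest

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.sensors.batch import BatchSensor  # assumed import path
from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger  # assumed import path

JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
AWS_REGION = "eu-west-1"


@pytest.fixture(scope="module")
def deferrable_batch_sensor() -> BatchSensor:
    return BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True)


def test_defers_with_batch_job_trigger(deferrable_batch_sensor: BatchSensor):
    # execute() raises TaskDeferred instead of blocking; the exception carries
    # the trigger that the triggerer process would later run.
    with pytest.raises(TaskDeferred) as exc:
        deferrable_batch_sensor.execute({})
    assert isinstance(exc.value.trigger, BatchJobTrigger)


def test_failure_event_raises(deferrable_batch_sensor: BatchSensor):
    # A failure event passed back to execute_complete() should raise.
    with pytest.raises(AirflowException):
        deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})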