Skip to content

Commit a1576d3

Browse files
authored
Fix outdated test name and description in BatchSensor (#33407)
1 parent e57d0c9 commit a1576d3

File tree

1 file changed

+95
-76
lines changed

1 file changed

+95
-76
lines changed

tests/providers/amazon/aws/sensors/test_batch.py

Lines changed: 95 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -32,92 +32,128 @@
3232
TASK_ID = "batch_job_sensor"
3333
JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
3434
AWS_REGION = "eu-west-1"
35+
ENVIRONMENT_NAME = "environment_name"
36+
JOB_QUEUE = "job_queue"
3537

3638

37-
class TestBatchSensor:
38-
def setup_method(self):
39-
self.batch_sensor = BatchSensor(
40-
task_id="batch_job_sensor",
41-
job_id=JOB_ID,
42-
)
39+
@pytest.fixture(scope="module")
40+
def batch_sensor() -> BatchSensor:
41+
return BatchSensor(
42+
task_id="batch_job_sensor",
43+
job_id=JOB_ID,
44+
)
45+
46+
47+
@pytest.fixture(scope="module")
48+
def deferrable_batch_sensor() -> BatchSensor:
49+
return BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True)
50+
4351

52+
class TestBatchSensor:
4453
@mock.patch.object(BatchClientHook, "get_job_description")
45-
def test_poke_on_success_state(self, mock_get_job_description):
54+
def test_poke_on_success_state(self, mock_get_job_description, batch_sensor: BatchSensor):
4655
mock_get_job_description.return_value = {"status": "SUCCEEDED"}
47-
assert self.batch_sensor.poke({}) is True
56+
assert batch_sensor.poke({}) is True
4857
mock_get_job_description.assert_called_once_with(JOB_ID)
4958

5059
@mock.patch.object(BatchClientHook, "get_job_description")
51-
def test_poke_on_failure_state(self, mock_get_job_description):
60+
def test_poke_on_failure_state(self, mock_get_job_description, batch_sensor: BatchSensor):
5261
mock_get_job_description.return_value = {"status": "FAILED"}
5362
with pytest.raises(AirflowException, match="Batch sensor failed. AWS Batch job status: FAILED"):
54-
self.batch_sensor.poke({})
63+
batch_sensor.poke({})
5564

5665
mock_get_job_description.assert_called_once_with(JOB_ID)
5766

5867
@mock.patch.object(BatchClientHook, "get_job_description")
59-
def test_poke_on_invalid_state(self, mock_get_job_description):
68+
def test_poke_on_invalid_state(self, mock_get_job_description, batch_sensor: BatchSensor):
6069
mock_get_job_description.return_value = {"status": "INVALID"}
6170
with pytest.raises(
6271
AirflowException, match="Batch sensor failed. Unknown AWS Batch job status: INVALID"
6372
):
64-
self.batch_sensor.poke({})
73+
batch_sensor.poke({})
6574

6675
mock_get_job_description.assert_called_once_with(JOB_ID)
6776

6877
@pytest.mark.parametrize("job_status", ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING", "RUNNING"])
6978
@mock.patch.object(BatchClientHook, "get_job_description")
70-
def test_poke_on_intermediate_state(self, mock_get_job_description, job_status):
79+
def test_poke_on_intermediate_state(
80+
self, mock_get_job_description, job_status, batch_sensor: BatchSensor
81+
):
7182
print(job_status)
7283
mock_get_job_description.return_value = {"status": job_status}
73-
assert self.batch_sensor.poke({}) is False
84+
assert batch_sensor.poke({}) is False
7485
mock_get_job_description.assert_called_once_with(JOB_ID)
7586

87+
def test_execute_in_deferrable_mode(self, deferrable_batch_sensor: BatchSensor):
88+
"""
89+
Asserts that a task is deferred and a BatchSensorTrigger will be fired
90+
when the BatchSensor is executed in deferrable mode.
91+
"""
7692

77-
class TestBatchComputeEnvironmentSensor:
78-
def setup_method(self):
79-
self.environment_name = "environment_name"
80-
self.sensor = BatchComputeEnvironmentSensor(
81-
task_id="test_batch_compute_environment_sensor",
82-
compute_environment=self.environment_name,
83-
)
93+
with pytest.raises(TaskDeferred) as exc:
94+
deferrable_batch_sensor.execute({})
95+
assert isinstance(exc.value.trigger, BatchJobTrigger), "Trigger is not a BatchJobTrigger"
96+
97+
def test_execute_failure_in_deferrable_mode(self, deferrable_batch_sensor: BatchSensor):
98+
"""Tests that an AirflowException is raised in case of error event"""
99+
100+
with pytest.raises(AirflowException):
101+
deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})
102+
103+
104+
@pytest.fixture(scope="module")
105+
def batch_compute_environment_sensor() -> BatchComputeEnvironmentSensor:
106+
return BatchComputeEnvironmentSensor(
107+
task_id="test_batch_compute_environment_sensor",
108+
compute_environment=ENVIRONMENT_NAME,
109+
)
84110

111+
112+
class TestBatchComputeEnvironmentSensor:
85113
@mock.patch.object(BatchClientHook, "client")
86-
def test_poke_no_environment(self, mock_batch_client):
114+
def test_poke_no_environment(
115+
self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
116+
):
87117
mock_batch_client.describe_compute_environments.return_value = {"computeEnvironments": []}
88118
with pytest.raises(AirflowException) as ctx:
89-
self.sensor.poke({})
119+
batch_compute_environment_sensor.poke({})
90120
mock_batch_client.describe_compute_environments.assert_called_once_with(
91-
computeEnvironments=[self.environment_name],
121+
computeEnvironments=[ENVIRONMENT_NAME],
92122
)
93123
assert "not found" in str(ctx.value)
94124

95125
@mock.patch.object(BatchClientHook, "client")
96-
def test_poke_valid(self, mock_batch_client):
126+
def test_poke_valid(
127+
self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
128+
):
97129
mock_batch_client.describe_compute_environments.return_value = {
98130
"computeEnvironments": [{"status": "VALID"}]
99131
}
100-
assert self.sensor.poke({}) is True
132+
assert batch_compute_environment_sensor.poke({}) is True
101133
mock_batch_client.describe_compute_environments.assert_called_once_with(
102-
computeEnvironments=[self.environment_name],
134+
computeEnvironments=[ENVIRONMENT_NAME],
103135
)
104136

105137
@mock.patch.object(BatchClientHook, "client")
106-
def test_poke_running(self, mock_batch_client):
138+
def test_poke_running(
139+
self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
140+
):
107141
mock_batch_client.describe_compute_environments.return_value = {
108142
"computeEnvironments": [
109143
{
110144
"status": "CREATING",
111145
}
112146
]
113147
}
114-
assert self.sensor.poke({}) is False
148+
assert batch_compute_environment_sensor.poke({}) is False
115149
mock_batch_client.describe_compute_environments.assert_called_once_with(
116-
computeEnvironments=[self.environment_name],
150+
computeEnvironments=[ENVIRONMENT_NAME],
117151
)
118152

119153
@mock.patch.object(BatchClientHook, "client")
120-
def test_poke_invalid(self, mock_batch_client):
154+
def test_poke_invalid(
155+
self, mock_batch_client, batch_compute_environment_sensor: BatchComputeEnvironmentSensor
156+
):
121157
mock_batch_client.describe_compute_environments.return_value = {
122158
"computeEnvironments": [
123159
{
@@ -126,64 +162,67 @@ def test_poke_invalid(self, mock_batch_client):
126162
]
127163
}
128164
with pytest.raises(AirflowException) as ctx:
129-
self.sensor.poke({})
165+
batch_compute_environment_sensor.poke({})
130166
mock_batch_client.describe_compute_environments.assert_called_once_with(
131-
computeEnvironments=[self.environment_name],
167+
computeEnvironments=[ENVIRONMENT_NAME],
132168
)
133169
assert "AWS Batch compute environment failed" in str(ctx.value)
134170

135171

136-
class TestBatchJobQueueSensor:
137-
def setup_method(self):
138-
self.job_queue = "job_queue"
139-
self.sensor = BatchJobQueueSensor(
140-
task_id="test_batch_job_queue_sensor",
141-
job_queue=self.job_queue,
142-
)
172+
@pytest.fixture(scope="module")
173+
def batch_job_queue_sensor() -> BatchJobQueueSensor:
174+
return BatchJobQueueSensor(
175+
task_id="test_batch_job_queue_sensor",
176+
job_queue=JOB_QUEUE,
177+
)
178+
143179

180+
class TestBatchJobQueueSensor:
144181
@mock.patch.object(BatchClientHook, "client")
145-
def test_poke_no_queue(self, mock_batch_client):
182+
def test_poke_no_queue(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
146183
mock_batch_client.describe_job_queues.return_value = {"jobQueues": []}
147184
with pytest.raises(AirflowException) as ctx:
148-
self.sensor.poke({})
185+
batch_job_queue_sensor.poke({})
149186
mock_batch_client.describe_job_queues.assert_called_once_with(
150-
jobQueues=[self.job_queue],
187+
jobQueues=[JOB_QUEUE],
151188
)
152189
assert "not found" in str(ctx.value)
153190

154191
@mock.patch.object(BatchClientHook, "client")
155-
def test_poke_no_queue_with_treat_non_existing_as_deleted(self, mock_batch_client):
156-
self.sensor.treat_non_existing_as_deleted = True
192+
def test_poke_no_queue_with_treat_non_existing_as_deleted(
193+
self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor
194+
):
195+
batch_job_queue_sensor.treat_non_existing_as_deleted = True
157196
mock_batch_client.describe_job_queues.return_value = {"jobQueues": []}
158-
assert self.sensor.poke({}) is True
197+
assert batch_job_queue_sensor.poke({}) is True
159198
mock_batch_client.describe_job_queues.assert_called_once_with(
160-
jobQueues=[self.job_queue],
199+
jobQueues=[JOB_QUEUE],
161200
)
162201

163202
@mock.patch.object(BatchClientHook, "client")
164-
def test_poke_valid(self, mock_batch_client):
203+
def test_poke_valid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
165204
mock_batch_client.describe_job_queues.return_value = {"jobQueues": [{"status": "VALID"}]}
166-
assert self.sensor.poke({}) is True
205+
assert batch_job_queue_sensor.poke({}) is True
167206
mock_batch_client.describe_job_queues.assert_called_once_with(
168-
jobQueues=[self.job_queue],
207+
jobQueues=[JOB_QUEUE],
169208
)
170209

171210
@mock.patch.object(BatchClientHook, "client")
172-
def test_poke_running(self, mock_batch_client):
211+
def test_poke_running(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
173212
mock_batch_client.describe_job_queues.return_value = {
174213
"jobQueues": [
175214
{
176215
"status": "CREATING",
177216
}
178217
]
179218
}
180-
assert self.sensor.poke({}) is False
219+
assert batch_job_queue_sensor.poke({}) is False
181220
mock_batch_client.describe_job_queues.assert_called_once_with(
182-
jobQueues=[self.job_queue],
221+
jobQueues=[JOB_QUEUE],
183222
)
184223

185224
@mock.patch.object(BatchClientHook, "client")
186-
def test_poke_invalid(self, mock_batch_client):
225+
def test_poke_invalid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQueueSensor):
187226
mock_batch_client.describe_job_queues.return_value = {
188227
"jobQueues": [
189228
{
@@ -192,28 +231,8 @@ def test_poke_invalid(self, mock_batch_client):
192231
]
193232
}
194233
with pytest.raises(AirflowException) as ctx:
195-
self.sensor.poke({})
234+
batch_job_queue_sensor.poke({})
196235
mock_batch_client.describe_job_queues.assert_called_once_with(
197-
jobQueues=[self.job_queue],
236+
jobQueues=[JOB_QUEUE],
198237
)
199238
assert "AWS Batch job queue failed" in str(ctx.value)
200-
201-
202-
class TestBatchAsyncSensor:
203-
TASK = BatchSensor(task_id="task", job_id=JOB_ID, region_name=AWS_REGION, deferrable=True)
204-
205-
def test_batch_sensor_async(self):
206-
"""
207-
Asserts that a task is deferred and a BatchSensorTrigger will be fired
208-
when the BatchSensorAsync is executed.
209-
"""
210-
211-
with pytest.raises(TaskDeferred) as exc:
212-
self.TASK.execute({})
213-
assert isinstance(exc.value.trigger, BatchJobTrigger), "Trigger is not a BatchJobTrigger"
214-
215-
def test_batch_sensor_async_execute_failure(self):
216-
"""Tests that an AirflowException is raised in case of error event"""
217-
218-
with pytest.raises(AirflowException):
219-
self.TASK.execute_complete(context={}, event={"status": "failure"})

0 commit comments

Comments
 (0)