32
32
TASK_ID = "batch_job_sensor"
33
33
JOB_ID = "8222a1c2-b246-4e19-b1b8-0039bb4407c0"
34
34
AWS_REGION = "eu-west-1"
35
+ ENVIRONMENT_NAME = "environment_name"
36
+ JOB_QUEUE = "job_queue"
35
37
36
38
37
- class TestBatchSensor :
38
- def setup_method (self ):
39
- self .batch_sensor = BatchSensor (
40
- task_id = "batch_job_sensor" ,
41
- job_id = JOB_ID ,
42
- )
39
+ @pytest .fixture (scope = "module" )
40
+ def batch_sensor () -> BatchSensor :
41
+ return BatchSensor (
42
+ task_id = "batch_job_sensor" ,
43
+ job_id = JOB_ID ,
44
+ )
45
+
46
+
47
+ @pytest .fixture (scope = "module" )
48
+ def deferrable_batch_sensor () -> BatchSensor :
49
+ return BatchSensor (task_id = "task" , job_id = JOB_ID , region_name = AWS_REGION , deferrable = True )
50
+
43
51
52
+ class TestBatchSensor :
44
53
@mock .patch .object (BatchClientHook , "get_job_description" )
45
- def test_poke_on_success_state (self , mock_get_job_description ):
54
+ def test_poke_on_success_state (self , mock_get_job_description , batch_sensor : BatchSensor ):
46
55
mock_get_job_description .return_value = {"status" : "SUCCEEDED" }
47
- assert self . batch_sensor .poke ({}) is True
56
+ assert batch_sensor .poke ({}) is True
48
57
mock_get_job_description .assert_called_once_with (JOB_ID )
49
58
50
59
@mock .patch .object (BatchClientHook , "get_job_description" )
51
- def test_poke_on_failure_state (self , mock_get_job_description ):
60
+ def test_poke_on_failure_state (self , mock_get_job_description , batch_sensor : BatchSensor ):
52
61
mock_get_job_description .return_value = {"status" : "FAILED" }
53
62
with pytest .raises (AirflowException , match = "Batch sensor failed. AWS Batch job status: FAILED" ):
54
- self . batch_sensor .poke ({})
63
+ batch_sensor .poke ({})
55
64
56
65
mock_get_job_description .assert_called_once_with (JOB_ID )
57
66
58
67
@mock .patch .object (BatchClientHook , "get_job_description" )
59
- def test_poke_on_invalid_state (self , mock_get_job_description ):
68
+ def test_poke_on_invalid_state (self , mock_get_job_description , batch_sensor : BatchSensor ):
60
69
mock_get_job_description .return_value = {"status" : "INVALID" }
61
70
with pytest .raises (
62
71
AirflowException , match = "Batch sensor failed. Unknown AWS Batch job status: INVALID"
63
72
):
64
- self . batch_sensor .poke ({})
73
+ batch_sensor .poke ({})
65
74
66
75
mock_get_job_description .assert_called_once_with (JOB_ID )
67
76
68
77
@pytest .mark .parametrize ("job_status" , ["SUBMITTED" , "PENDING" , "RUNNABLE" , "STARTING" , "RUNNING" ])
69
78
@mock .patch .object (BatchClientHook , "get_job_description" )
70
- def test_poke_on_intermediate_state (self , mock_get_job_description , job_status ):
79
+ def test_poke_on_intermediate_state (
80
+ self , mock_get_job_description , job_status , batch_sensor : BatchSensor
81
+ ):
71
82
print (job_status )
72
83
mock_get_job_description .return_value = {"status" : job_status }
73
- assert self . batch_sensor .poke ({}) is False
84
+ assert batch_sensor .poke ({}) is False
74
85
mock_get_job_description .assert_called_once_with (JOB_ID )
75
86
87
+ def test_execute_in_deferrable_mode (self , deferrable_batch_sensor : BatchSensor ):
88
+ """
89
+ Asserts that a task is deferred and a BatchSensorTrigger will be fired
90
+ when the BatchSensor is executed in deferrable mode.
91
+ """
76
92
77
- class TestBatchComputeEnvironmentSensor :
78
- def setup_method (self ):
79
- self .environment_name = "environment_name"
80
- self .sensor = BatchComputeEnvironmentSensor (
81
- task_id = "test_batch_compute_environment_sensor" ,
82
- compute_environment = self .environment_name ,
83
- )
93
+ with pytest .raises (TaskDeferred ) as exc :
94
+ deferrable_batch_sensor .execute ({})
95
+ assert isinstance (exc .value .trigger , BatchJobTrigger ), "Trigger is not a BatchJobTrigger"
96
+
97
+ def test_execute_failure_in_deferrable_mode (self , deferrable_batch_sensor : BatchSensor ):
98
+ """Tests that an AirflowException is raised in case of error event"""
99
+
100
+ with pytest .raises (AirflowException ):
101
+ deferrable_batch_sensor .execute_complete (context = {}, event = {"status" : "failure" })
102
+
103
+
104
+ @pytest .fixture (scope = "module" )
105
+ def batch_compute_environment_sensor () -> BatchComputeEnvironmentSensor :
106
+ return BatchComputeEnvironmentSensor (
107
+ task_id = "test_batch_compute_environment_sensor" ,
108
+ compute_environment = ENVIRONMENT_NAME ,
109
+ )
84
110
111
+
112
+ class TestBatchComputeEnvironmentSensor :
85
113
@mock .patch .object (BatchClientHook , "client" )
86
- def test_poke_no_environment (self , mock_batch_client ):
114
+ def test_poke_no_environment (
115
+ self , mock_batch_client , batch_compute_environment_sensor : BatchComputeEnvironmentSensor
116
+ ):
87
117
mock_batch_client .describe_compute_environments .return_value = {"computeEnvironments" : []}
88
118
with pytest .raises (AirflowException ) as ctx :
89
- self . sensor .poke ({})
119
+ batch_compute_environment_sensor .poke ({})
90
120
mock_batch_client .describe_compute_environments .assert_called_once_with (
91
- computeEnvironments = [self . environment_name ],
121
+ computeEnvironments = [ENVIRONMENT_NAME ],
92
122
)
93
123
assert "not found" in str (ctx .value )
94
124
95
125
@mock .patch .object (BatchClientHook , "client" )
96
- def test_poke_valid (self , mock_batch_client ):
126
+ def test_poke_valid (
127
+ self , mock_batch_client , batch_compute_environment_sensor : BatchComputeEnvironmentSensor
128
+ ):
97
129
mock_batch_client .describe_compute_environments .return_value = {
98
130
"computeEnvironments" : [{"status" : "VALID" }]
99
131
}
100
- assert self . sensor .poke ({}) is True
132
+ assert batch_compute_environment_sensor .poke ({}) is True
101
133
mock_batch_client .describe_compute_environments .assert_called_once_with (
102
- computeEnvironments = [self . environment_name ],
134
+ computeEnvironments = [ENVIRONMENT_NAME ],
103
135
)
104
136
105
137
@mock .patch .object (BatchClientHook , "client" )
106
- def test_poke_running (self , mock_batch_client ):
138
+ def test_poke_running (
139
+ self , mock_batch_client , batch_compute_environment_sensor : BatchComputeEnvironmentSensor
140
+ ):
107
141
mock_batch_client .describe_compute_environments .return_value = {
108
142
"computeEnvironments" : [
109
143
{
110
144
"status" : "CREATING" ,
111
145
}
112
146
]
113
147
}
114
- assert self . sensor .poke ({}) is False
148
+ assert batch_compute_environment_sensor .poke ({}) is False
115
149
mock_batch_client .describe_compute_environments .assert_called_once_with (
116
- computeEnvironments = [self . environment_name ],
150
+ computeEnvironments = [ENVIRONMENT_NAME ],
117
151
)
118
152
119
153
@mock .patch .object (BatchClientHook , "client" )
120
- def test_poke_invalid (self , mock_batch_client ):
154
+ def test_poke_invalid (
155
+ self , mock_batch_client , batch_compute_environment_sensor : BatchComputeEnvironmentSensor
156
+ ):
121
157
mock_batch_client .describe_compute_environments .return_value = {
122
158
"computeEnvironments" : [
123
159
{
@@ -126,64 +162,67 @@ def test_poke_invalid(self, mock_batch_client):
126
162
]
127
163
}
128
164
with pytest .raises (AirflowException ) as ctx :
129
- self . sensor .poke ({})
165
+ batch_compute_environment_sensor .poke ({})
130
166
mock_batch_client .describe_compute_environments .assert_called_once_with (
131
- computeEnvironments = [self . environment_name ],
167
+ computeEnvironments = [ENVIRONMENT_NAME ],
132
168
)
133
169
assert "AWS Batch compute environment failed" in str (ctx .value )
134
170
135
171
136
- class TestBatchJobQueueSensor :
137
- def setup_method ( self ) :
138
- self . job_queue = "job_queue"
139
- self . sensor = BatchJobQueueSensor (
140
- task_id = "test_batch_job_queue_sensor" ,
141
- job_queue = self . job_queue ,
142
- )
172
+ @ pytest . fixture ( scope = "module" )
173
+ def batch_job_queue_sensor () -> BatchJobQueueSensor :
174
+ return BatchJobQueueSensor (
175
+ task_id = "test_batch_job_queue_sensor" ,
176
+ job_queue = JOB_QUEUE ,
177
+ )
178
+
143
179
180
+ class TestBatchJobQueueSensor :
144
181
@mock .patch .object (BatchClientHook , "client" )
145
- def test_poke_no_queue (self , mock_batch_client ):
182
+ def test_poke_no_queue (self , mock_batch_client , batch_job_queue_sensor : BatchJobQueueSensor ):
146
183
mock_batch_client .describe_job_queues .return_value = {"jobQueues" : []}
147
184
with pytest .raises (AirflowException ) as ctx :
148
- self . sensor .poke ({})
185
+ batch_job_queue_sensor .poke ({})
149
186
mock_batch_client .describe_job_queues .assert_called_once_with (
150
- jobQueues = [self . job_queue ],
187
+ jobQueues = [JOB_QUEUE ],
151
188
)
152
189
assert "not found" in str (ctx .value )
153
190
154
191
@mock .patch .object (BatchClientHook , "client" )
155
- def test_poke_no_queue_with_treat_non_existing_as_deleted (self , mock_batch_client ):
156
- self .sensor .treat_non_existing_as_deleted = True
192
+ def test_poke_no_queue_with_treat_non_existing_as_deleted (
193
+ self , mock_batch_client , batch_job_queue_sensor : BatchJobQueueSensor
194
+ ):
195
+ batch_job_queue_sensor .treat_non_existing_as_deleted = True
157
196
mock_batch_client .describe_job_queues .return_value = {"jobQueues" : []}
158
- assert self . sensor .poke ({}) is True
197
+ assert batch_job_queue_sensor .poke ({}) is True
159
198
mock_batch_client .describe_job_queues .assert_called_once_with (
160
- jobQueues = [self . job_queue ],
199
+ jobQueues = [JOB_QUEUE ],
161
200
)
162
201
163
202
@mock .patch .object (BatchClientHook , "client" )
164
- def test_poke_valid (self , mock_batch_client ):
203
+ def test_poke_valid (self , mock_batch_client , batch_job_queue_sensor : BatchJobQueueSensor ):
165
204
mock_batch_client .describe_job_queues .return_value = {"jobQueues" : [{"status" : "VALID" }]}
166
- assert self . sensor .poke ({}) is True
205
+ assert batch_job_queue_sensor .poke ({}) is True
167
206
mock_batch_client .describe_job_queues .assert_called_once_with (
168
- jobQueues = [self . job_queue ],
207
+ jobQueues = [JOB_QUEUE ],
169
208
)
170
209
171
210
@mock .patch .object (BatchClientHook , "client" )
172
- def test_poke_running (self , mock_batch_client ):
211
+ def test_poke_running (self , mock_batch_client , batch_job_queue_sensor : BatchJobQueueSensor ):
173
212
mock_batch_client .describe_job_queues .return_value = {
174
213
"jobQueues" : [
175
214
{
176
215
"status" : "CREATING" ,
177
216
}
178
217
]
179
218
}
180
- assert self . sensor .poke ({}) is False
219
+ assert batch_job_queue_sensor .poke ({}) is False
181
220
mock_batch_client .describe_job_queues .assert_called_once_with (
182
- jobQueues = [self . job_queue ],
221
+ jobQueues = [JOB_QUEUE ],
183
222
)
184
223
185
224
@mock .patch .object (BatchClientHook , "client" )
186
- def test_poke_invalid (self , mock_batch_client ):
225
+ def test_poke_invalid (self , mock_batch_client , batch_job_queue_sensor : BatchJobQueueSensor ):
187
226
mock_batch_client .describe_job_queues .return_value = {
188
227
"jobQueues" : [
189
228
{
@@ -192,28 +231,8 @@ def test_poke_invalid(self, mock_batch_client):
192
231
]
193
232
}
194
233
with pytest .raises (AirflowException ) as ctx :
195
- self . sensor .poke ({})
234
+ batch_job_queue_sensor .poke ({})
196
235
mock_batch_client .describe_job_queues .assert_called_once_with (
197
- jobQueues = [self . job_queue ],
236
+ jobQueues = [JOB_QUEUE ],
198
237
)
199
238
assert "AWS Batch job queue failed" in str (ctx .value )
200
-
201
-
202
- class TestBatchAsyncSensor :
203
- TASK = BatchSensor (task_id = "task" , job_id = JOB_ID , region_name = AWS_REGION , deferrable = True )
204
-
205
- def test_batch_sensor_async (self ):
206
- """
207
- Asserts that a task is deferred and a BatchSensorTrigger will be fired
208
- when the BatchSensorAsync is executed.
209
- """
210
-
211
- with pytest .raises (TaskDeferred ) as exc :
212
- self .TASK .execute ({})
213
- assert isinstance (exc .value .trigger , BatchJobTrigger ), "Trigger is not a BatchJobTrigger"
214
-
215
- def test_batch_sensor_async_execute_failure (self ):
216
- """Tests that an AirflowException is raised in case of error event"""
217
-
218
- with pytest .raises (AirflowException ):
219
- self .TASK .execute_complete (context = {}, event = {"status" : "failure" })
0 commit comments