@@ -24,6 +24,7 @@
 from unittest.mock import MagicMock
 
 import pytest
+from more_itertools import flatten
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models.connection import Connection
@@ -34,7 +35,7 @@
 from airflow.utils import timezone
 
 from tests_common.test_utils.compat import GenericTransfer
-from tests_common.test_utils.operators.run_deferrable import execute_operator
+from tests_common.test_utils.operators.run_deferrable import execute_operator, mock_context
 from tests_common.test_utils.providers import get_provider_min_airflow_version
 
 pytestmark = pytest.mark.db_test
@@ -43,6 +44,12 @@
 DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
 DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
 TEST_DAG_ID = "unit_test_dag"
+INSERT_ARGS = {
+    "commit_every": 1000,  # Number of rows inserted in each batch
+    "executemany": True,  # Enable batch inserts
+    "fast_executemany": True,  # Boost performance for MSSQL inserts
+    "replace": True,  # Used for upserts/merges if needed
+}
 counter = 0
 
 
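A side note on the new `more_itertools` dependency: `flatten` removes exactly one level of nesting, which is why the fixture below can collapse its list of result pages into a single list of rows without destroying the rows themselves. A quick standalone check (illustrative only, not part of the diff):

```python
from more_itertools import flatten

# Two "pages" of rows, as returned by successive paginated reads.
pages = [
    [[1, 2], [11, 12], [3, 4], [13, 14]],
    [[3, 4], [13, 14]],
]

# flatten() strips one level of nesting, so the pages merge into a
# single list of rows while the inner row lists stay intact.
assert list(flatten(pages)) == [[1, 2], [11, 12], [3, 4], [13, 14], [3, 4], [13, 14]]
```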
@@ -175,6 +182,44 @@ def test_postgres_to_postgres_replace(self, mock_insert, dag_maker):
 
 
 class TestGenericTransfer:
+    mocked_source_hook = MagicMock(conn_name_attr="my_source_conn_id", spec=DbApiHook)
+    mocked_destination_hook = MagicMock(conn_name_attr="my_destination_conn_id", spec=DbApiHook)
+    mocked_hooks = {
+        "my_source_conn_id": mocked_source_hook,
+        "my_destination_conn_id": mocked_destination_hook,
+    }
+
+    @classmethod
+    def get_hook(cls, conn_id: str, hook_params: dict | None = None):
+        return cls.mocked_hooks[conn_id]
+
+    @classmethod
+    def get_connection(cls, conn_id: str):
+        mocked_hook = cls.get_hook(conn_id=conn_id)
+        mocked_conn = MagicMock(conn_id=conn_id, spec=Connection)
+        mocked_conn.get_hook.return_value = mocked_hook
+        return mocked_conn
+
+    def setup_method(self):
+        # Reset mock states before each test
+        self.mocked_source_hook.reset_mock()
+        self.mocked_destination_hook.reset_mock()
+
+        # Set up the side effect for paginated read
+        records = [
+            [[1, 2], [11, 12], [3, 4], [13, 14]],
+            [[3, 4], [13, 14]],
+        ]
+
+        def get_records_side_effect(sql: str):
+            if records:
+                if "LIMIT" not in sql:
+                    return list(flatten(records))
+                return records.pop(0)
+            return []
+
+        self.mocked_source_hook.get_records.side_effect = get_records_side_effect
+
     def test_templated_fields(self):
         dag = DAG(
             "test_dag",
@@ -209,53 +254,45 @@ def test_templated_fields(self):
         assert operator.preoperator == "my_preoperator"
         assert operator.insert_args == {"commit_every": 5000, "executemany": True, "replace": True}
 
+    def test_non_paginated_read(self):
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_connection):
+            with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=self.get_hook):
+                operator = GenericTransfer(
+                    task_id="transfer_table",
+                    source_conn_id="my_source_conn_id",
+                    destination_conn_id="my_destination_conn_id",
+                    sql="SELECT * FROM HR.EMPLOYEES",
+                    destination_table="NEW_HR.EMPLOYEES",
+                    insert_args=INSERT_ARGS,
+                    execution_timeout=timedelta(hours=1),
+                )
+
+                operator.execute(context=mock_context(task=operator))
+
+        assert self.mocked_source_hook.get_records.call_count == 1
+        assert self.mocked_source_hook.get_records.call_args_list[0].args[0] == "SELECT * FROM HR.EMPLOYEES"
+        assert self.mocked_destination_hook.insert_rows.call_count == 1
+        assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == {
+            **INSERT_ARGS,
+            **{"rows": [[1, 2], [11, 12], [3, 4], [13, 14], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
+        }
+
     def test_paginated_read(self):
         """
         This unit test is based on the example described in the medium article:
         https://medium.com/apache-airflow/transfering-data-from-sap-hana-to-mssql-using-the-airflow-generictransfer-d29f147a9f1f
         """
 
-        def create_get_records_side_effect():
-            records = [
-                [[1, 2], [11, 12], [3, 4], [13, 14]],
-                [[3, 4], [13, 14]],
-            ]
-
-            def side_effect(sql: str):
-                if records:
-                    return records.pop(0)
-                return []
-
-            return side_effect
-
-        get_records_side_effect = create_get_records_side_effect()
-
-        def get_hook(conn_id: str, hook_params: dict | None = None):
-            mocked_hook = MagicMock(conn_name_attr=conn_id, spec=DbApiHook)
-            mocked_hook.get_records.side_effect = get_records_side_effect
-            return mocked_hook
-
-        def get_connection(conn_id: str):
-            mocked_hook = get_hook(conn_id=conn_id)
-            mocked_conn = MagicMock(conn_id=conn_id, spec=Connection)
-            mocked_conn.get_hook.return_value = mocked_hook
-            return mocked_conn
-
-        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_connection):
-            with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=get_hook):
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_connection):
+            with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=self.get_hook):
                 operator = GenericTransfer(
                     task_id="transfer_table",
                     source_conn_id="my_source_conn_id",
                     destination_conn_id="my_destination_conn_id",
                     sql="SELECT * FROM HR.EMPLOYEES",
                     destination_table="NEW_HR.EMPLOYEES",
                     page_size=1000,  # Fetch data in chunks of 1000 rows for pagination
-                    insert_args={
-                        "commit_every": 1000,  # Number of rows inserted in each batch
-                        "executemany": True,  # Enable batch inserts
-                        "fast_executemany": True,  # Boost performance for MSSQL inserts
-                        "replace": True,  # Used for upserts/merges if needed
-                    },
+                    insert_args=INSERT_ARGS,
                     execution_timeout=timedelta(hours=1),
                 )
 
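The next hunk's new assertions pin down the query shape that `page_size` produces: the first read is `SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 0`, and reads continue until a page comes back empty. A rough sketch of that paging loop, under the same stop-on-empty-page assumption the mocks encode (not the operator's actual implementation):

```python
def paginate(sql: str, page_size: int, get_records, insert_rows):
    """Sketch of LIMIT/OFFSET paging: read a page, insert it, stop when empty."""
    offset = 0
    while True:
        rows = get_records(f"{sql} LIMIT {page_size} OFFSET {offset}")
        if not rows:
            break  # an empty page ends the transfer
        insert_rows(rows)
        offset += page_size
```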
@@ -267,6 +304,21 @@ def get_connection(conn_id: str):
                 assert events[1].payload["results"] == [[3, 4], [13, 14]]
                 assert not events[2].payload["results"]
 
+        assert self.mocked_source_hook.get_records.call_count == 3
+        assert (
+            self.mocked_source_hook.get_records.call_args_list[0].args[0]
+            == "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 0"
+        )
+        assert self.mocked_destination_hook.insert_rows.call_count == 2
+        assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == {
+            **INSERT_ARGS,
+            **{"rows": [[1, 2], [11, 12], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
+        }
+        assert self.mocked_destination_hook.insert_rows.call_args_list[1].kwargs == {
+            **INSERT_ARGS,
+            **{"rows": [[3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
+        }
+
     def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_method(self):
         """
         Once this test starts failing due to the fact that the minimum Airflow version is now 3.0.0 or higher