Skip to content

Commit 3b78c6a

Browse files
fix: Use source hook instead of destination hook when reading records in non-paginated mode in GenericTransfer (apache#50598)
* fix: Use source hook when reading records in non-paginated mode in GenericTransfer
* refactor: Moved mocked hooks as class variable in test GenericTransfer

---------

Co-authored-by: David Blain <[email protected]>
1 parent f80a2a5 commit 3b78c6a

File tree

2 files changed

+88
-36
lines changed

2 files changed

+88
-36
lines changed

providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -162,7 +162,7 @@ def execute(self, context: Context):
162162
self.log.info("Extracting data from %s", self.source_conn_id)
163163
self.log.info("Executing: \n %s", self.sql)
164164

165-
results = self.destination_hook.get_records(self.sql)
165+
results = self.source_hook.get_records(self.sql)
166166

167167
self.log.info("Inserting rows into %s", self.destination_conn_id)
168168
self.destination_hook.insert_rows(table=self.destination_table, rows=results, **self.insert_args)

providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py

Lines changed: 87 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from unittest.mock import MagicMock
2525

2626
import pytest
27+
from more_itertools import flatten
2728

2829
from airflow.exceptions import AirflowProviderDeprecationWarning
2930
from airflow.models.connection import Connection
@@ -34,7 +35,7 @@
3435
from airflow.utils import timezone
3536

3637
from tests_common.test_utils.compat import GenericTransfer
37-
from tests_common.test_utils.operators.run_deferrable import execute_operator
38+
from tests_common.test_utils.operators.run_deferrable import execute_operator, mock_context
3839
from tests_common.test_utils.providers import get_provider_min_airflow_version
3940

4041
pytestmark = pytest.mark.db_test
@@ -43,6 +44,12 @@
4344
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
4445
DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
4546
TEST_DAG_ID = "unit_test_dag"
47+
INSERT_ARGS = {
48+
"commit_every": 1000, # Number of rows inserted in each batch
49+
"executemany": True, # Enable batch inserts
50+
"fast_executemany": True, # Boost performance for MSSQL inserts
51+
"replace": True, # Used for upserts/merges if needed
52+
}
4653
counter = 0
4754

4855

@@ -175,6 +182,44 @@ def test_postgres_to_postgres_replace(self, mock_insert, dag_maker):
175182

176183

177184
class TestGenericTransfer:
185+
mocked_source_hook = MagicMock(conn_name_attr="my_source_conn_id", spec=DbApiHook)
186+
mocked_destination_hook = MagicMock(conn_name_attr="my_destination_conn_id", spec=DbApiHook)
187+
mocked_hooks = {
188+
"my_source_conn_id": mocked_source_hook,
189+
"my_destination_conn_id": mocked_destination_hook,
190+
}
191+
192+
@classmethod
193+
def get_hook(cls, conn_id: str, hook_params: dict | None = None):
194+
return cls.mocked_hooks[conn_id]
195+
196+
@classmethod
197+
def get_connection(cls, conn_id: str):
198+
mocked_hook = cls.get_hook(conn_id=conn_id)
199+
mocked_conn = MagicMock(conn_id=conn_id, spec=Connection)
200+
mocked_conn.get_hook.return_value = mocked_hook
201+
return mocked_conn
202+
203+
def setup_method(self):
204+
# Reset mock states before each test
205+
self.mocked_source_hook.reset_mock()
206+
self.mocked_destination_hook.reset_mock()
207+
208+
# Set up the side effect for paginated read
209+
records = [
210+
[[1, 2], [11, 12], [3, 4], [13, 14]],
211+
[[3, 4], [13, 14]],
212+
]
213+
214+
def get_records_side_effect(sql: str):
215+
if records:
216+
if "LIMIT" not in sql:
217+
return list(flatten(records))
218+
return records.pop(0)
219+
return []
220+
221+
self.mocked_source_hook.get_records.side_effect = get_records_side_effect
222+
178223
def test_templated_fields(self):
179224
dag = DAG(
180225
"test_dag",
@@ -209,53 +254,45 @@ def test_templated_fields(self):
209254
assert operator.preoperator == "my_preoperator"
210255
assert operator.insert_args == {"commit_every": 5000, "executemany": True, "replace": True}
211256

257+
def test_non_paginated_read(self):
258+
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_connection):
259+
with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=self.get_hook):
260+
operator = GenericTransfer(
261+
task_id="transfer_table",
262+
source_conn_id="my_source_conn_id",
263+
destination_conn_id="my_destination_conn_id",
264+
sql="SELECT * FROM HR.EMPLOYEES",
265+
destination_table="NEW_HR.EMPLOYEES",
266+
insert_args=INSERT_ARGS,
267+
execution_timeout=timedelta(hours=1),
268+
)
269+
270+
operator.execute(context=mock_context(task=operator))
271+
272+
assert self.mocked_source_hook.get_records.call_count == 1
273+
assert self.mocked_source_hook.get_records.call_args_list[0].args[0] == "SELECT * FROM HR.EMPLOYEES"
274+
assert self.mocked_destination_hook.insert_rows.call_count == 1
275+
assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == {
276+
**INSERT_ARGS,
277+
**{"rows": [[1, 2], [11, 12], [3, 4], [13, 14], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
278+
}
279+
212280
def test_paginated_read(self):
213281
"""
214282
This unit test is based on the example described in the medium article:
215283
https://medium.com/apache-airflow/transfering-data-from-sap-hana-to-mssql-using-the-airflow-generictransfer-d29f147a9f1f
216284
"""
217285

218-
def create_get_records_side_effect():
219-
records = [
220-
[[1, 2], [11, 12], [3, 4], [13, 14]],
221-
[[3, 4], [13, 14]],
222-
]
223-
224-
def side_effect(sql: str):
225-
if records:
226-
return records.pop(0)
227-
return []
228-
229-
return side_effect
230-
231-
get_records_side_effect = create_get_records_side_effect()
232-
233-
def get_hook(conn_id: str, hook_params: dict | None = None):
234-
mocked_hook = MagicMock(conn_name_attr=conn_id, spec=DbApiHook)
235-
mocked_hook.get_records.side_effect = get_records_side_effect
236-
return mocked_hook
237-
238-
def get_connection(conn_id: str):
239-
mocked_hook = get_hook(conn_id=conn_id)
240-
mocked_conn = MagicMock(conn_id=conn_id, spec=Connection)
241-
mocked_conn.get_hook.return_value = mocked_hook
242-
return mocked_conn
243-
244-
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_connection):
245-
with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=get_hook):
286+
with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=self.get_connection):
287+
with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=self.get_hook):
246288
operator = GenericTransfer(
247289
task_id="transfer_table",
248290
source_conn_id="my_source_conn_id",
249291
destination_conn_id="my_destination_conn_id",
250292
sql="SELECT * FROM HR.EMPLOYEES",
251293
destination_table="NEW_HR.EMPLOYEES",
252294
page_size=1000, # Fetch data in chunks of 1000 rows for pagination
253-
insert_args={
254-
"commit_every": 1000, # Number of rows inserted in each batch
255-
"executemany": True, # Enable batch inserts
256-
"fast_executemany": True, # Boost performance for MSSQL inserts
257-
"replace": True, # Used for upserts/merges if needed
258-
},
295+
insert_args=INSERT_ARGS,
259296
execution_timeout=timedelta(hours=1),
260297
)
261298

@@ -267,6 +304,21 @@ def get_connection(conn_id: str):
267304
assert events[1].payload["results"] == [[3, 4], [13, 14]]
268305
assert not events[2].payload["results"]
269306

307+
assert self.mocked_source_hook.get_records.call_count == 3
308+
assert (
309+
self.mocked_source_hook.get_records.call_args_list[0].args[0]
310+
== "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 0"
311+
)
312+
assert self.mocked_destination_hook.insert_rows.call_count == 2
313+
assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == {
314+
**INSERT_ARGS,
315+
**{"rows": [[1, 2], [11, 12], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
316+
}
317+
assert self.mocked_destination_hook.insert_rows.call_args_list[1].kwargs == {
318+
**INSERT_ARGS,
319+
**{"rows": [[3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
320+
}
321+
270322
def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_method(self):
271323
"""
272324
Once this test starts failing due to the fact that the minimum Airflow version is now 3.0.0 or higher

0 commit comments

Comments
 (0)