Commit f3f71ff

Add sql_hook_params parameter to SqlToS3Operator
Adding a `sql_hook_params` parameter to `SqlToS3Operator`. This allows you to pass extra config params to the underlying SQL hook. It uses the same `sql_hook_params` parameter name already used in `SqlToSlackOperator`.
1 parent: 489ca14

File tree

2 files changed (+23, -1 lines)

airflow/providers/amazon/aws/transfers/sql_to_s3.py

Lines changed: 5 additions & 1 deletion

@@ -65,6 +65,8 @@ class SqlToS3Operator(BaseOperator):
     :param s3_key: desired key for the file. It includes the name of the file. (templated)
     :param replace: whether or not to replace the file in S3 if it previously existed
     :param sql_conn_id: reference to a specific database.
+    :param sql_hook_params: Extra config params to be passed to the underlying hook.
+        Should match the desired hook constructor params.
     :param parameters: (optional) the parameters to render the SQL query with.
     :param aws_conn_id: reference to a specific S3 connection
     :param verify: Whether or not to verify SSL certificates for S3 connection.
@@ -100,6 +102,7 @@ def __init__(
         s3_bucket: str,
         s3_key: str,
         sql_conn_id: str,
+        sql_hook_params: dict | None = None,
         parameters: None | Mapping | Iterable = None,
         replace: bool = False,
         aws_conn_id: str = "aws_default",
@@ -120,6 +123,7 @@ def __init__(
         self.pd_kwargs = pd_kwargs or {}
         self.parameters = parameters
         self.groupby_kwargs = groupby_kwargs or {}
+        self.sql_hook_params = sql_hook_params

         if "path_or_buf" in self.pd_kwargs:
             raise AirflowException("The argument path_or_buf is not allowed, please remove it")
@@ -200,7 +204,7 @@ def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]
     def _get_hook(self) -> DbApiHook:
         self.log.debug("Get connection for %s", self.sql_conn_id)
         conn = BaseHook.get_connection(self.sql_conn_id)
-        hook = conn.get_hook()
+        hook = conn.get_hook(hook_params=self.sql_hook_params)
         if not callable(getattr(hook, "get_pandas_df", None)):
             raise AirflowException(
                 "This hook is not supported. The hook class must have get_pandas_df method."

tests/providers/amazon/aws/transfers/test_sql_to_s3.py

Lines changed: 18 additions & 0 deletions

@@ -24,6 +24,7 @@
 import pytest

 from airflow.exceptions import AirflowException
+from airflow.models import Connection
 from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator


@@ -269,3 +270,20 @@ def test_without_groupby_kwarg(self):
                 }
             )
         )
+
+    @mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
+    def test_hook_params(self, mock_get_conn):
+        mock_get_conn.return_value = Connection(conn_id="postgres_test", conn_type="postgres")
+        op = SqlToS3Operator(
+            query="query",
+            s3_bucket="bucket",
+            s3_key="key",
+            sql_conn_id="postgres_test",
+            task_id="task_id",
+            sql_hook_params={
+                "log_sql": False,
+            },
+            dag=None,
+        )
+        hook = op._get_hook()
+        assert hook.log_sql == op.sql_hook_params["log_sql"]
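
For intuition about what this test asserts: `Connection.get_hook(hook_params=...)` resolves the hook class from the connection type and passes the dict entries as constructor kwargs. A rough sketch of the resulting hook, assuming the Postgres provider is installed (the conn ID is the test's placeholder):

from airflow.providers.postgres.hooks.postgres import PostgresHook

# Roughly what conn.get_hook(hook_params={"log_sql": False}) yields for a
# "postgres" connection: each dict entry becomes a hook constructor kwarg.
hook = PostgresHook(postgres_conn_id="postgres_test", log_sql=False)
assert hook.log_sql is False  # mirrors the new test's assertion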
