Skip to content

Add parameter sftp_prefetch to SFTPToGCSOperator #33274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion airflow/providers/google/cloud/transfers/sftp_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class SFTPToGCSOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param sftp_prefetch: Whether to enable SFTP prefetch. Defaults to True.
"""

template_fields: Sequence[str] = (
Expand All @@ -90,6 +91,7 @@ def __init__(
gzip: bool = False,
move_object: bool = False,
impersonation_chain: str | Sequence[str] | None = None,
sftp_prefetch: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -103,6 +105,7 @@ def __init__(
self.sftp_conn_id = sftp_conn_id
self.move_object = move_object
self.impersonation_chain = impersonation_chain
self.sftp_prefetch = sftp_prefetch

def execute(self, context: Context):
gcs_hook = GCSHook(
Expand Down Expand Up @@ -151,7 +154,7 @@ def _copy_single_object(
)

with NamedTemporaryFile("w") as tmp:
sftp_hook.retrieve_file(source_path, tmp.name)
sftp_hook.retrieve_file(source_path, tmp.name, prefetch=self.sftp_prefetch)

gcs_hook.upload(
bucket_name=self.destination_bucket,
Expand Down
5 changes: 3 additions & 2 deletions airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,17 +223,18 @@ def delete_directory(self, path: str) -> None:
conn = self.get_conn()
conn.rmdir(path)

def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
def retrieve_file(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None:
"""Transfer the remote file to a local location.

If local_full_path is a string path, the file will be put
at that location.

:param remote_full_path: full path to the remote file
:param local_full_path: full path to the local file
:param prefetch: controls whether prefetch is performed (default: True)
"""
conn = self.get_conn()
conn.get(remote_full_path, local_full_path)
conn.get(remote_full_path, local_full_path, prefetch=prefetch)

def store_file(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None:
"""Transfer a local file to the remote location.
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/sftp/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ versions:
dependencies:
- apache-airflow>=2.4.0
- apache-airflow-providers-ssh>=2.1.0
- paramiko>=2.8.0

integrations:
- integration-name: SSH File Transfer Protocol (SFTP)
Expand Down
3 changes: 2 additions & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,8 @@
"sftp": {
"deps": [
"apache-airflow-providers-ssh>=2.1.0",
"apache-airflow>=2.4.0"
"apache-airflow>=2.4.0",
"paramiko>=2.8.0"
],
"cross-providers-deps": [
"openlineage",
Expand Down
12 changes: 7 additions & 5 deletions tests/providers/google/cloud/transfers/test_sftp_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_execute_copy_single_file(self, sftp_hook, gcs_hook):
sftp_hook.assert_called_once_with(SFTP_CONN_ID)

sftp_hook.return_value.retrieve_file.assert_called_once_with(
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY, prefetch=True
)

gcs_hook.return_value.upload.assert_called_once_with(
Expand All @@ -99,6 +99,7 @@ def test_execute_copy_single_file_with_compression(self, sftp_hook, gcs_hook):
sftp_conn_id=SFTP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
gzip=True,
sftp_prefetch=False,
)
task.execute(None)
gcs_hook.assert_called_once_with(
Expand All @@ -108,7 +109,7 @@ def test_execute_copy_single_file_with_compression(self, sftp_hook, gcs_hook):
sftp_hook.assert_called_once_with(SFTP_CONN_ID)

sftp_hook.return_value.retrieve_file.assert_called_once_with(
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY, prefetch=False
)

gcs_hook.return_value.upload.assert_called_once_with(
Expand All @@ -133,6 +134,7 @@ def test_execute_move_single_file(self, sftp_hook, gcs_hook):
gcp_conn_id=GCP_CONN_ID,
sftp_conn_id=SFTP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
sftp_prefetch=True,
)
task.execute(None)
gcs_hook.assert_called_once_with(
Expand All @@ -142,7 +144,7 @@ def test_execute_move_single_file(self, sftp_hook, gcs_hook):
sftp_hook.assert_called_once_with(SFTP_CONN_ID)

sftp_hook.return_value.retrieve_file.assert_called_once_with(
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY
os.path.join(SOURCE_OBJECT_NO_WILDCARD), mock.ANY, prefetch=True
)

gcs_hook.return_value.upload.assert_called_once_with(
Expand Down Expand Up @@ -181,8 +183,8 @@ def test_execute_copy_with_wildcard(self, sftp_hook, gcs_hook):

sftp_hook.return_value.retrieve_file.assert_has_calls(
[
mock.call("main_dir/test_object3.json", mock.ANY),
mock.call("main_dir/sub_dir/test_object3.json", mock.ANY),
mock.call("main_dir/test_object3.json", mock.ANY, prefetch=True),
mock.call("main_dir/sub_dir/test_object3.json", mock.ANY, prefetch=True),
]
)

Expand Down