Skip to content

Commit 9df6384

Browse files
Authored by Valentyn Druzhynin (valentinDruzhinin)
and
Valentyn Druzhynin
Add support for RequestPayer=requester option in Amazon S3's Operators, Sensors and Triggers (apache#51098)
Co-authored-by: Valentyn Druzhynin <[email protected]>
1 parent 12eed61 commit 9df6384

File tree

2 files changed

+268
-50
lines changed
  • providers/amazon
    • src/airflow/providers/amazon/aws/hooks
    • tests/unit/amazon/aws/hooks

2 files changed

+268
-50
lines changed

providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py

Lines changed: 101 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def __init__(
193193
) -> None:
194194
kwargs["client_type"] = "s3"
195195
kwargs["aws_conn_id"] = aws_conn_id
196+
self._requester_pays = kwargs.pop("requester_pays", False)
196197

197198
if transfer_config_args and not isinstance(transfer_config_args, dict):
198199
raise TypeError(f"transfer_config_args expected dict, got {type(transfer_config_args).__name__}.")
@@ -409,12 +410,15 @@ def list_prefixes(
409410
}
410411

411412
paginator = self.get_conn().get_paginator("list_objects_v2")
412-
response = paginator.paginate(
413-
Bucket=bucket_name,
414-
Prefix=prefix,
415-
Delimiter=delimiter,
416-
PaginationConfig=config,
417-
)
413+
params = {
414+
"Bucket": bucket_name,
415+
"Prefix": prefix,
416+
"Delimiter": delimiter,
417+
"PaginationConfig": config,
418+
}
419+
if self._requester_pays:
420+
params["RequestPayer"] = "requester"
421+
response = paginator.paginate(**params)
418422

419423
prefixes: list[str] = []
420424
for page in response:
@@ -437,7 +441,13 @@ async def get_head_object_async(
437441
"""
438442
head_object_val: dict[str, Any] | None = None
439443
try:
440-
head_object_val = await client.head_object(Bucket=bucket_name, Key=key)
444+
params = {
445+
"Bucket": bucket_name,
446+
"Key": key,
447+
}
448+
if self._requester_pays:
449+
params["RequestPayer"] = "requester"
450+
head_object_val = await client.head_object(**params)
441451
return head_object_val
442452
except ClientError as e:
443453
if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
@@ -472,12 +482,15 @@ async def list_prefixes_async(
472482
}
473483

474484
paginator = client.get_paginator("list_objects_v2")
475-
response = paginator.paginate(
476-
Bucket=bucket_name,
477-
Prefix=prefix,
478-
Delimiter=delimiter,
479-
PaginationConfig=config,
480-
)
485+
params = {
486+
"Bucket": bucket_name,
487+
"Prefix": prefix,
488+
"Delimiter": delimiter,
489+
"PaginationConfig": config,
490+
}
491+
if self._requester_pays:
492+
params["RequestPayer"] = "requester"
493+
response = paginator.paginate(**params)
481494

482495
prefixes = []
483496
async for page in response:
@@ -501,7 +514,14 @@ async def get_file_metadata_async(
501514
prefix = re.split(r"[\[\*\?]", key, 1)[0] if key else ""
502515
delimiter = ""
503516
paginator = client.get_paginator("list_objects_v2")
504-
response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
517+
params = {
518+
"Bucket": bucket_name,
519+
"Prefix": prefix,
520+
"Delimiter": delimiter,
521+
}
522+
if self._requester_pays:
523+
params["RequestPayer"] = "requester"
524+
response = paginator.paginate(**params)
505525
async for page in response:
506526
if "Contents" in page:
507527
for row in page["Contents"]:
@@ -622,14 +642,21 @@ async def get_files_async(
622642
prefix = re.split(r"[\[*?]", key, 1)[0]
623643

624644
paginator = client.get_paginator("list_objects_v2")
625-
response = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter=delimiter)
645+
params = {
646+
"Bucket": bucket,
647+
"Prefix": prefix,
648+
"Delimiter": delimiter,
649+
}
650+
if self._requester_pays:
651+
params["RequestPayer"] = "requester"
652+
response = paginator.paginate(**params)
626653
async for page in response:
627654
if "Contents" in page:
628655
keys.extend(k for k in page["Contents"] if isinstance(k.get("Size"), (int, float)))
629656
return keys
630657

631-
@staticmethod
632658
async def _list_keys_async(
659+
self,
633660
client: AioBaseClient,
634661
bucket_name: str | None = None,
635662
prefix: str | None = None,
@@ -655,12 +682,15 @@ async def _list_keys_async(
655682
}
656683

657684
paginator = client.get_paginator("list_objects_v2")
658-
response = paginator.paginate(
659-
Bucket=bucket_name,
660-
Prefix=prefix,
661-
Delimiter=delimiter,
662-
PaginationConfig=config,
663-
)
685+
params = {
686+
"Bucket": bucket_name,
687+
"Prefix": prefix,
688+
"Delimiter": delimiter,
689+
"PaginationConfig": config,
690+
}
691+
if self._requester_pays:
692+
params["RequestPayer"] = "requester"
693+
response = paginator.paginate(**params)
664694

665695
keys = []
666696
async for page in response:
@@ -863,13 +893,16 @@ def _is_in_period(input_date: datetime) -> bool:
863893
}
864894

865895
paginator = self.get_conn().get_paginator("list_objects_v2")
866-
response = paginator.paginate(
867-
Bucket=bucket_name,
868-
Prefix=_prefix,
869-
Delimiter=delimiter,
870-
PaginationConfig=config,
871-
StartAfter=start_after_key,
872-
)
896+
params = {
897+
"Bucket": bucket_name,
898+
"Prefix": _prefix,
899+
"Delimiter": delimiter,
900+
"PaginationConfig": config,
901+
"StartAfter": start_after_key,
902+
}
903+
if self._requester_pays:
904+
params["RequestPayer"] = "requester"
905+
response = paginator.paginate(**params)
873906

874907
keys: list[str] = []
875908
for page in response:
@@ -909,7 +942,14 @@ def get_file_metadata(
909942
}
910943

911944
paginator = self.get_conn().get_paginator("list_objects_v2")
912-
response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, PaginationConfig=config)
945+
params = {
946+
"Bucket": bucket_name,
947+
"Prefix": prefix,
948+
"PaginationConfig": config,
949+
}
950+
if self._requester_pays:
951+
params["RequestPayer"] = "requester"
952+
response = paginator.paginate(**params)
913953

914954
files = []
915955
for page in response:
@@ -931,7 +971,13 @@ def head_object(self, key: str, bucket_name: str | None = None) -> dict | None:
931971
:return: metadata of an object
932972
"""
933973
try:
934-
return self.get_conn().head_object(Bucket=bucket_name, Key=key)
974+
params = {
975+
"Bucket": bucket_name,
976+
"Key": key,
977+
}
978+
if self._requester_pays:
979+
params["RequestPayer"] = "requester"
980+
return self.get_conn().head_object(**params)
935981
except ClientError as e:
936982
if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
937983
return None
@@ -975,8 +1021,11 @@ def sanitize_extra_args() -> dict[str, str]:
9751021
if arg_name in S3Transfer.ALLOWED_DOWNLOAD_ARGS
9761022
}
9771023

1024+
params = sanitize_extra_args()
1025+
if self._requester_pays:
1026+
params["RequestPayer"] = "requester"
9781027
obj = self.resource.Object(bucket_name, key)
979-
obj.load(**sanitize_extra_args())
1028+
obj.load(**params)
9801029
return obj
9811030

9821031
@unify_bucket_name_and_key
@@ -1022,11 +1071,14 @@ def select_key(
10221071
"""
10231072
expression = expression or "SELECT * FROM S3Object"
10241073
expression_type = expression_type or "SQL"
1074+
extra_args = {}
10251075

10261076
if input_serialization is None:
10271077
input_serialization = {"CSV": {}}
10281078
if output_serialization is None:
10291079
output_serialization = {"CSV": {}}
1080+
if self._requester_pays:
1081+
extra_args["RequestPayer"] = "requester"
10301082

10311083
response = self.get_conn().select_object_content(
10321084
Bucket=bucket_name,
@@ -1035,6 +1087,7 @@ def select_key(
10351087
ExpressionType=expression_type,
10361088
InputSerialization=input_serialization,
10371089
OutputSerialization=output_serialization,
1090+
ExtraArgs=extra_args,
10381091
)
10391092

10401093
return b"".join(
@@ -1124,6 +1177,8 @@ def load_file(
11241177
filename = filename_gz
11251178
if acl_policy:
11261179
extra_args["ACL"] = acl_policy
1180+
if self._requester_pays:
1181+
extra_args["RequestPayer"] = "requester"
11271182

11281183
client = self.get_conn()
11291184
client.upload_file(
@@ -1270,6 +1325,8 @@ def _upload_file_obj(
12701325
extra_args["ServerSideEncryption"] = "AES256"
12711326
if acl_policy:
12721327
extra_args["ACL"] = acl_policy
1328+
if self._requester_pays:
1329+
extra_args["RequestPayer"] = "requester"
12731330

12741331
client = self.get_conn()
12751332
client.upload_fileobj(
@@ -1330,6 +1387,8 @@ def copy_object(
13301387
kwargs["ACL"] = acl_policy
13311388
if meta_data_directive:
13321389
kwargs["MetadataDirective"] = meta_data_directive
1390+
if self._requester_pays:
1391+
kwargs["RequestPayer"] = "requester"
13331392

13341393
dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
13351394
dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
@@ -1412,12 +1471,17 @@ def delete_objects(self, bucket: str, keys: str | list) -> None:
14121471
keys = [keys]
14131472

14141473
s3 = self.get_conn()
1474+
extra_kwargs = {}
1475+
if self._requester_pays:
1476+
extra_kwargs["RequestPayer"] = "requester"
14151477

14161478
# We can only send a maximum of 1000 keys per request.
14171479
# For details see:
14181480
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.delete_objects
14191481
for chunk in chunks(keys, chunk_size=1000):
1420-
response = s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]})
1482+
response = s3.delete_objects(
1483+
Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}, **extra_kwargs
1484+
)
14211485
deleted_keys = [x["Key"] for x in response.get("Deleted", [])]
14221486
self.log.info("Deleted: %s", deleted_keys)
14231487
if "Errors" in response:
@@ -1496,9 +1560,12 @@ def download_file(
14961560
file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False) # type: ignore
14971561

14981562
with file:
1563+
extra_args = {**self.extra_args}
1564+
if self._requester_pays:
1565+
extra_args["RequestPayer"] = "requester"
14991566
s3_obj.download_fileobj(
15001567
file,
1501-
ExtraArgs=self.extra_args,
1568+
ExtraArgs=extra_args,
15021569
Config=self.transfer_config,
15031570
)
15041571
get_hook_lineage_collector().add_input_asset(

0 commit comments

Comments
 (0)