@@ -193,6 +193,7 @@ def __init__(
     ) -> None:
         kwargs["client_type"] = "s3"
         kwargs["aws_conn_id"] = aws_conn_id
+        self._requester_pays = kwargs.pop("requester_pays", False)
 
         if transfer_config_args and not isinstance(transfer_config_args, dict):
             raise TypeError(f"transfer_config_args expected dict, got {type(transfer_config_args).__name__}.")
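For context, a minimal sketch of how this new flag would be used. The `requester_pays` keyword comes from this diff; the connection id and bucket name are placeholders:

```python
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

# __init__ pops "requester_pays" out of kwargs before they reach the base
# AWS hook, so the flag never leaks into the boto3 client configuration.
hook = S3Hook(aws_conn_id="aws_default", requester_pays=True)

# Calls made through this hook now send RequestPayer="requester",
# acknowledging that the caller, not the bucket owner, pays the request
# and transfer costs for a Requester Pays bucket.
keys = hook.list_keys(bucket_name="some-requester-pays-bucket", prefix="data/")
```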
@@ -409,12 +410,15 @@ def list_prefixes(
         }
 
         paginator = self.get_conn().get_paginator("list_objects_v2")
-        response = paginator.paginate(
-            Bucket=bucket_name,
-            Prefix=prefix,
-            Delimiter=delimiter,
-            PaginationConfig=config,
-        )
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": prefix,
+            "Delimiter": delimiter,
+            "PaginationConfig": config,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)
 
         prefixes: list[str] = []
         for page in response:
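The build-a-dict-then-unpack pattern above recurs in every listing method that follows; a standalone sketch of the idea, with a hypothetical helper name and a placeholder bucket:

```python
import boto3

def build_list_params(bucket: str, prefix: str, requester_pays: bool) -> dict:
    """Assemble list_objects_v2 kwargs, adding RequestPayer only when set."""
    params = {"Bucket": bucket, "Prefix": prefix}
    if requester_pays:
        # The key is added conditionally because boto3 validates parameter
        # names and values; always passing RequestPayer=None would fail.
        params["RequestPayer"] = "requester"
    return params

client = boto3.client("s3")
paginator = client.get_paginator("list_objects_v2")
for page in paginator.paginate(**build_list_params("my-bucket", "logs/", True)):
    for obj in page.get("Contents", []):
        print(obj["Key"])
```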
@@ -437,7 +441,13 @@ async def get_head_object_async(
         """
         head_object_val: dict[str, Any] | None = None
         try:
-            head_object_val = await client.head_object(Bucket=bucket_name, Key=key)
+            params = {
+                "Bucket": bucket_name,
+                "Key": key,
+            }
+            if self._requester_pays:
+                params["RequestPayer"] = "requester"
+            head_object_val = await client.head_object(**params)
             return head_object_val
         except ClientError as e:
             if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
@@ -472,12 +482,15 @@ async def list_prefixes_async(
         }
 
         paginator = client.get_paginator("list_objects_v2")
-        response = paginator.paginate(
-            Bucket=bucket_name,
-            Prefix=prefix,
-            Delimiter=delimiter,
-            PaginationConfig=config,
-        )
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": prefix,
+            "Delimiter": delimiter,
+            "PaginationConfig": config,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)
 
         prefixes = []
         async for page in response:
@@ -501,7 +514,14 @@ async def get_file_metadata_async(
         prefix = re.split(r"[\[\*\?]", key, 1)[0] if key else ""
         delimiter = ""
         paginator = client.get_paginator("list_objects_v2")
-        response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": prefix,
+            "Delimiter": delimiter,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)
         async for page in response:
             if "Contents" in page:
                 for row in page["Contents"]:
@@ -622,14 +642,21 @@ async def get_files_async(
         prefix = re.split(r"[\[*?]", key, 1)[0]
 
         paginator = client.get_paginator("list_objects_v2")
-        response = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter=delimiter)
+        params = {
+            "Bucket": bucket,
+            "Prefix": prefix,
+            "Delimiter": delimiter,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)
         async for page in response:
             if "Contents" in page:
                 keys.extend(k for k in page["Contents"] if isinstance(k.get("Size"), (int, float)))
         return keys
 
-    @staticmethod
     async def _list_keys_async(
+        self,
         client: AioBaseClient,
         bucket_name: str | None = None,
         prefix: str | None = None,
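Dropping `@staticmethod` on `_list_keys_async` is required rather than cosmetic: the method now reads `self._requester_pays`, and a static method has no `self` to read instance state from. A minimal illustration of the constraint (class and method names hypothetical):

```python
class Example:
    def __init__(self, requester_pays: bool = False) -> None:
        self._requester_pays = requester_pays

    @staticmethod
    def as_staticmethod() -> dict:
        # No `self` in scope here; referencing self._requester_pays
        # would raise NameError at call time.
        return {}

    def as_instance_method(self) -> dict:
        # Instance method: per-hook state is available.
        return {"RequestPayer": "requester"} if self._requester_pays else {}
```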
@@ -655,12 +682,15 @@ async def _list_keys_async(
         }
 
         paginator = client.get_paginator("list_objects_v2")
-        response = paginator.paginate(
-            Bucket=bucket_name,
-            Prefix=prefix,
-            Delimiter=delimiter,
-            PaginationConfig=config,
-        )
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": prefix,
+            "Delimiter": delimiter,
+            "PaginationConfig": config,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)
 
         keys = []
         async for page in response:
@@ -863,13 +893,16 @@ def _is_in_period(input_date: datetime) -> bool:
         }
 
         paginator = self.get_conn().get_paginator("list_objects_v2")
-        response = paginator.paginate(
-            Bucket=bucket_name,
-            Prefix=_prefix,
-            Delimiter=delimiter,
-            PaginationConfig=config,
-            StartAfter=start_after_key,
-        )
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": _prefix,
+            "Delimiter": delimiter,
+            "PaginationConfig": config,
+            "StartAfter": start_after_key,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)
 
         keys: list[str] = []
         for page in response:
@@ -909,7 +942,14 @@ def get_file_metadata(
         }
 
         paginator = self.get_conn().get_paginator("list_objects_v2")
-        response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, PaginationConfig=config)
+        params = {
+            "Bucket": bucket_name,
+            "Prefix": prefix,
+            "PaginationConfig": config,
+        }
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
+        response = paginator.paginate(**params)
 
         files = []
         for page in response:
@@ -931,7 +971,13 @@ def head_object(self, key: str, bucket_name: str | None = None) -> dict | None:
         :return: metadata of an object
         """
         try:
-            return self.get_conn().head_object(Bucket=bucket_name, Key=key)
+            params = {
+                "Bucket": bucket_name,
+                "Key": key,
+            }
+            if self._requester_pays:
+                params["RequestPayer"] = "requester"
+            return self.get_conn().head_object(**params)
         except ClientError as e:
             if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                 return None
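The surrounding try/except is the standard boto3 idiom for an existence check: HeadObject reports a missing key as an HTTP 404 wrapped in `ClientError`, not as a sentinel return value. A self-contained sketch with a placeholder bucket and key:

```python
import boto3
from botocore.exceptions import ClientError

def object_exists(client, bucket: str, key: str) -> bool:
    try:
        client.head_object(Bucket=bucket, Key=key)
        return True
    except ClientError as e:
        if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
            return False
        # Permission or throttling errors should not masquerade as "missing".
        raise

client = boto3.client("s3")
print(object_exists(client, "my-bucket", "some/key.txt"))
```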
@@ -975,8 +1021,11 @@ def sanitize_extra_args() -> dict[str, str]:
                 if arg_name in S3Transfer.ALLOWED_DOWNLOAD_ARGS
             }
 
+        params = sanitize_extra_args()
+        if self._requester_pays:
+            params["RequestPayer"] = "requester"
         obj = self.resource.Object(bucket_name, key)
-        obj.load(**sanitize_extra_args())
+        obj.load(**params)
         return obj
 
     @unify_bucket_name_and_key
@@ -1022,11 +1071,14 @@ def select_key(
         """
         expression = expression or "SELECT * FROM S3Object"
         expression_type = expression_type or "SQL"
+        extra_args = {}
 
         if input_serialization is None:
             input_serialization = {"CSV": {}}
         if output_serialization is None:
             output_serialization = {"CSV": {}}
+        if self._requester_pays:
+            extra_args["RequestPayer"] = "requester"
 
         response = self.get_conn().select_object_content(
             Bucket=bucket_name,
@@ -1035,6 +1087,7 @@ def select_key(
             ExpressionType=expression_type,
             InputSerialization=input_serialization,
             OutputSerialization=output_serialization,
+            **extra_args,
         )
 
         return b"".join(
@@ -1124,6 +1177,8 @@ def load_file(
                     filename = filename_gz
         if acl_policy:
             extra_args["ACL"] = acl_policy
+        if self._requester_pays:
+            extra_args["RequestPayer"] = "requester"
 
         client = self.get_conn()
         client.upload_file(
@@ -1270,6 +1325,8 @@ def _upload_file_obj(
             extra_args["ServerSideEncryption"] = "AES256"
         if acl_policy:
             extra_args["ACL"] = acl_policy
+        if self._requester_pays:
+            extra_args["RequestPayer"] = "requester"
 
         client = self.get_conn()
         client.upload_fileobj(
@@ -1330,6 +1387,8 @@ def copy_object(
             kwargs["ACL"] = acl_policy
         if meta_data_directive:
             kwargs["MetadataDirective"] = meta_data_directive
+        if self._requester_pays:
+            kwargs["RequestPayer"] = "requester"
 
         dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
             dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
@@ -1412,12 +1471,17 @@ def delete_objects(self, bucket: str, keys: str | list) -> None:
             keys = [keys]
 
         s3 = self.get_conn()
+        extra_kwargs = {}
+        if self._requester_pays:
+            extra_kwargs["RequestPayer"] = "requester"
 
         # We can only send a maximum of 1000 keys per request.
         # For details see:
         # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.delete_objects
         for chunk in chunks(keys, chunk_size=1000):
-            response = s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]})
+            response = s3.delete_objects(
+                Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}, **extra_kwargs
+            )
             deleted_keys = [x["Key"] for x in response.get("Deleted", [])]
             self.log.info("Deleted: %s", deleted_keys)
             if "Errors" in response:
@@ -1496,9 +1560,12 @@ def download_file(
             file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False)  # type: ignore
 
         with file:
+            extra_args = {**self.extra_args}
+            if self._requester_pays:
+                extra_args["RequestPayer"] = "requester"
             s3_obj.download_fileobj(
                 file,
-                ExtraArgs=self.extra_args,
+                ExtraArgs=extra_args,
                 Config=self.transfer_config,
             )
         get_hook_lineage_collector().add_input_asset(
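Put together, a requester-pays download through the hook would look roughly like this; it assumes the `requester_pays` keyword from this diff, and the bucket, key, and paths are placeholders:

```python
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

hook = S3Hook(aws_conn_id="aws_default", requester_pays=True)

# download_file copies the hook's extra_args and adds
# RequestPayer="requester" before handing them to boto3's
# download_fileobj as ExtraArgs.
local_file = hook.download_file(
    key="shared/dataset.parquet",
    bucket_name="some-requester-pays-bucket",
    local_path="/tmp",
)
print(local_file)
```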