@@ -9,9 +9,12 @@
 import json
 from urllib.parse import urlparse
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union, IO, Callable
 from hashlib import sha256
+from functools import wraps
 
+import boto3
+from botocore.exceptions import ClientError
 import requests
 
 from allennlp.common.tqdm import Tqdm
@@ -78,7 +81,7 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str
     parsed = urlparse(url_or_filename)
 
-    if parsed.scheme in ('http', 'https'):
+    if parsed.scheme in ('http', 'https', 's3'):
         # URL, so get it from the cache (downloading if necessary)
         return get_from_cache(url_or_filename, cache_dir)
     elif os.path.exists(url_or_filename):
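With `'s3'` added to the scheme check, `cached_path` routes S3 URLs through the same cache machinery as HTTP(S) URLs. A minimal usage sketch, assuming the module path `allennlp.common.file_utils` (the bucket and key are hypothetical; boto3 resolves credentials through its usual chain, e.g. environment variables or `~/.aws/credentials`):

```python
from allennlp.common.file_utils import cached_path

# Remote URLs (http, https, s3) are downloaded once and cached locally;
# existing local paths are returned unchanged.
model_path = cached_path("s3://my-bucket/models/model.tar.gz")  # hypothetical object
```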
@@ -92,6 +95,67 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str
         raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
 
 
+def split_s3_path(url: str) -> Tuple[str, str]:
+    """Split a full s3 path into the bucket name and path."""
+    parsed = urlparse(url)
+    if not parsed.netloc or not parsed.path:
+        raise ValueError("bad s3 path {}".format(url))
+    bucket_name = parsed.netloc
+    s3_path = parsed.path
+    # Remove '/' at beginning of path.
+    if s3_path.startswith("/"):
+        s3_path = s3_path[1:]
+    return bucket_name, s3_path
+
+
+def s3_request(func: Callable):
+    """
+    Wrapper function for s3 requests in order to create more helpful error
+    messages.
+    """
+
+    @wraps(func)
+    def wrapper(url: str, *args, **kwargs):
+        try:
+            return func(url, *args, **kwargs)
+        except ClientError as exc:
+            if int(exc.response["Error"]["Code"]) == 404:
+                raise FileNotFoundError("file {} not found".format(url))
+            else:
+                raise
+
+    return wrapper
+
+
+@s3_request
+def s3_etag(url: str) -> Optional[str]:
+    """Check ETag on S3 object."""
+    s3_resource = boto3.resource("s3")
+    bucket_name, s3_path = split_s3_path(url)
+    s3_object = s3_resource.Object(bucket_name, s3_path)
+    return s3_object.e_tag
+
+
+@s3_request
+def s3_get(url: str, temp_file: IO) -> None:
+    """Pull a file directly from S3."""
+    s3_resource = boto3.resource("s3")
+    bucket_name, s3_path = split_s3_path(url)
+    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
+
+
+def http_get(url: str, temp_file: IO) -> None:
+    req = requests.get(url, stream=True)
+    content_length = req.headers.get('Content-Length')
+    total = int(content_length) if content_length is not None else None
+    progress = Tqdm.tqdm(unit="B", total=total)
+    for chunk in req.iter_content(chunk_size=1024):
+        if chunk:  # filter out keep-alive new chunks
+            progress.update(len(chunk))
+            temp_file.write(chunk)
+    progress.close()
+
+
 # TODO(joelgrus): do we want to do checksums or anything like that?
 def get_from_cache(url: str, cache_dir: str = None) -> str:
     """
@@ -103,13 +167,16 @@ def get_from_cache(url: str, cache_dir: str = None) -> str:
 
     os.makedirs(cache_dir, exist_ok=True)
 
-    # make HEAD request to check ETag
-    response = requests.head(url, allow_redirects=True)
-    if response.status_code != 200:
-        raise IOError("HEAD request failed for url {}".format(url))
+    # Get eTag to add to filename, if it exists.
+    if url.startswith("s3://"):
+        etag = s3_etag(url)
+    else:
+        response = requests.head(url, allow_redirects=True)
+        if response.status_code != 200:
+            raise IOError("HEAD request failed for url {} with status code {}"
+                          .format(url, response.status_code))
+        etag = response.headers.get("ETag")
 
-    # add ETag to filename if it exists
-    etag = response.headers.get("ETag")
     filename = url_to_filename(url, etag)
 
     # get cache path to put the file
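`url_to_filename` is defined earlier in this file; given the `sha256` import, the cache filename is presumably derived by hashing the URL and the ETag. A sketch of that scheme, as an assumption rather than the file's exact implementation (note that boto3 reports the ETag with its surrounding quotes, which is harmless when the value is only hashed):

```python
from hashlib import sha256

def url_to_filename_sketch(url: str, etag: str = None) -> str:
    # Hash the URL, and mix in the ETag when one exists, so a changed
    # remote file gets a fresh cache entry rather than a stale hit.
    filename = sha256(url.encode('utf-8')).hexdigest()
    if etag is not None:
        filename += '.' + sha256(etag.encode('utf-8')).hexdigest()
    return filename
```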
@@ -122,15 +189,10 @@ def get_from_cache(url: str, cache_dir: str = None) -> str:
             logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
 
             # GET file object
-            req = requests.get(url, stream=True)
-            content_length = req.headers.get('Content-Length')
-            total = int(content_length) if content_length is not None else None
-            progress = Tqdm.tqdm(unit="B", total=total)
-            for chunk in req.iter_content(chunk_size=1024):
-                if chunk:  # filter out keep-alive new chunks
-                    progress.update(len(chunk))
-                    temp_file.write(chunk)
-            progress.close()
+            if url.startswith("s3://"):
+                s3_get(url, temp_file)
+            else:
+                http_get(url, temp_file)
 
             # we are copying the file before closing it, so flush to avoid truncation
             temp_file.flush()
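The download body now reduces to a scheme dispatch, and the temp-file-then-copy pattern is unchanged: writing to a `NamedTemporaryFile` and flushing before the copy means an interrupted download can never leave a truncated file in the cache. A standalone sketch of the pattern (the destination handling here approximates, rather than reproduces, the rest of `get_from_cache`):

```python
import shutil
import tempfile

def download_to(url: str, destination: str) -> None:
    # Download to a temporary file first; only a complete download
    # ever gets copied to the destination.
    with tempfile.NamedTemporaryFile() as temp_file:
        if url.startswith("s3://"):
            s3_get(url, temp_file)
        else:
            http_get(url, temp_file)
        temp_file.flush()
        temp_file.seek(0)
        with open(destination, 'wb') as out_file:
            shutil.copyfileobj(temp_file, out_file)
```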