This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 659bf25

epwalsh authored and joelgrus committed
Allow files to be downloaded from S3 (#1620)
* handle s3 files
* pull in to file object
* get s3 etag
* add some tests
* better unit tests
* added test for etag
* add a few more tests
* add status code to error msg
* fix unit tests
* add a few comments
1 parent 76a65a8 commit 659bf25

File tree

4 files changed: +164 −18 lines changed

allennlp/common/file_utils.py (+79 −17)

@@ -9,9 +9,12 @@
 import json
 from urllib.parse import urlparse
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union, IO, Callable
 from hashlib import sha256
+from functools import wraps
 
+import boto3
+from botocore.exceptions import ClientError
 import requests
 
 from allennlp.common.tqdm import Tqdm
@@ -78,7 +81,7 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str
 
     parsed = urlparse(url_or_filename)
 
-    if parsed.scheme in ('http', 'https'):
+    if parsed.scheme in ('http', 'https', 's3'):
         # URL, so get it from the cache (downloading if necessary)
         return get_from_cache(url_or_filename, cache_dir)
     elif os.path.exists(url_or_filename):
@@ -92,6 +95,67 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str
     raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
 
 
+def split_s3_path(url: str) -> Tuple[str, str]:
+    """Split a full s3 path into the bucket name and path."""
+    parsed = urlparse(url)
+    if not parsed.netloc or not parsed.path:
+        raise ValueError("bad s3 path {}".format(url))
+    bucket_name = parsed.netloc
+    s3_path = parsed.path
+    # Remove '/' at beginning of path.
+    if s3_path.startswith("/"):
+        s3_path = s3_path[1:]
+    return bucket_name, s3_path
+
+
+def s3_request(func: Callable):
+    """
+    Wrapper function for s3 requests in order to create more helpful error
+    messages.
+    """
+
+    @wraps(func)
+    def wrapper(url: str, *args, **kwargs):
+        try:
+            return func(url, *args, **kwargs)
+        except ClientError as exc:
+            if int(exc.response["Error"]["Code"]) == 404:
+                raise FileNotFoundError("file {} not found".format(url))
+            else:
+                raise
+
+    return wrapper
+
+
+@s3_request
+def s3_etag(url: str) -> Optional[str]:
+    """Check ETag on S3 object."""
+    s3_resource = boto3.resource("s3")
+    bucket_name, s3_path = split_s3_path(url)
+    s3_object = s3_resource.Object(bucket_name, s3_path)
+    return s3_object.e_tag
+
+
+@s3_request
+def s3_get(url: str, temp_file: IO) -> None:
+    """Pull a file directly from S3."""
+    s3_resource = boto3.resource("s3")
+    bucket_name, s3_path = split_s3_path(url)
+    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
+
+
+def http_get(url: str, temp_file: IO) -> None:
+    req = requests.get(url, stream=True)
+    content_length = req.headers.get('Content-Length')
+    total = int(content_length) if content_length is not None else None
+    progress = Tqdm.tqdm(unit="B", total=total)
+    for chunk in req.iter_content(chunk_size=1024):
+        if chunk: # filter out keep-alive new chunks
+            progress.update(len(chunk))
+            temp_file.write(chunk)
+    progress.close()
+
+
 # TODO(joelgrus): do we want to do checksums or anything like that?
 def get_from_cache(url: str, cache_dir: str = None) -> str:
     """
@@ -103,13 +167,16 @@ def get_from_cache(url: str, cache_dir: str = None) -> str:
 
     os.makedirs(cache_dir, exist_ok=True)
 
-    # make HEAD request to check ETag
-    response = requests.head(url, allow_redirects=True)
-    if response.status_code != 200:
-        raise IOError("HEAD request failed for url {}".format(url))
+    # Get eTag to add to filename, if it exists.
+    if url.startswith("s3://"):
+        etag = s3_etag(url)
+    else:
+        response = requests.head(url, allow_redirects=True)
+        if response.status_code != 200:
+            raise IOError("HEAD request failed for url {} with status code {}"
+                          .format(url, response.status_code))
+        etag = response.headers.get("ETag")
 
-    # add ETag to filename if it exists
-    etag = response.headers.get("ETag")
     filename = url_to_filename(url, etag)
 
     # get cache path to put the file
@@ -122,15 +189,10 @@ def get_from_cache(url: str, cache_dir: str = None) -> str:
             logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
 
             # GET file object
-            req = requests.get(url, stream=True)
-            content_length = req.headers.get('Content-Length')
-            total = int(content_length) if content_length is not None else None
-            progress = Tqdm.tqdm(unit="B", total=total)
-            for chunk in req.iter_content(chunk_size=1024):
-                if chunk: # filter out keep-alive new chunks
-                    progress.update(len(chunk))
-                    temp_file.write(chunk)
-            progress.close()
+            if url.startswith("s3://"):
+                s3_get(url, temp_file)
+            else:
+                http_get(url, temp_file)
 
             # we are copying the file before closing it, so flush to avoid truncation
             temp_file.flush()
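With this change, cached_path accepts s3:// URLs alongside http(s) URLs: the object's ETag is folded into the cache filename, and the object is streamed into the local cache on a miss. A minimal usage sketch follows; the bucket and key are hypothetical, and boto3 is assumed to be able to find AWS credentials (environment variables, ~/.aws/credentials, an IAM role, etc.).

    from allennlp.common.file_utils import cached_path

    # Hypothetical bucket and key, purely for illustration.
    local_path = cached_path("s3://my-bucket/datasets/train.jsonl")

    # The returned path points at the cached copy, keyed on the URL plus the
    # object's ETag, so repeated calls reuse the download until the object
    # changes on S3.
    with open(local_path) as data_file:
        first_line = data_file.readline()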

allennlp/tests/common/file_utils_test.py (+77 −1)

@@ -3,11 +3,17 @@
 import os
 import pathlib
 import json
+import tempfile
+from typing import List, Tuple
 
+import boto3
+from moto import mock_s3
 import pytest
 import responses
 
-from allennlp.common.file_utils import url_to_filename, filename_to_url, get_from_cache, cached_path
+from allennlp.common.file_utils import (
+        url_to_filename, filename_to_url, get_from_cache, cached_path, split_s3_path,
+        s3_request, s3_etag, s3_get)
 from allennlp.common.testing import AllenNlpTestCase
 
 
@@ -47,6 +53,14 @@ def head_callback(_):
     )
 
 
+def set_up_s3_bucket(bucket_name: str = "my-bucket", s3_objects: List[Tuple[str, str]] = None):
+    """Creates a mock s3 bucket optionally with objects uploaded from local files."""
+    s3_client = boto3.client("s3")
+    s3_client.create_bucket(Bucket=bucket_name)
+    for filename, key in s3_objects or []:
+        s3_client.upload_file(Filename=filename, Bucket=bucket_name, Key=key)
+
+
 class TestFileUtils(AllenNlpTestCase):
     def setUp(self):
         super().setUp()
@@ -97,6 +111,68 @@ def test_url_to_filename_with_etags_eliminates_quotes(self):
         assert back_to_url == url
         assert etag == "mytag"
 
+    def test_split_s3_path(self):
+        # Test splitting good urls.
+        assert split_s3_path("s3://my-bucket/subdir/file.txt") == ("my-bucket", "subdir/file.txt")
+        assert split_s3_path("s3://my-bucket/file.txt") == ("my-bucket", "file.txt")
+
+        # Test splitting bad urls.
+        with pytest.raises(ValueError):
+            split_s3_path("s3://")
+            split_s3_path("s3://myfile.txt")
+            split_s3_path("myfile.txt")
+
+    @mock_s3
+    def test_s3_bucket(self):
+        """This just ensures the bucket gets set up correctly."""
+        set_up_s3_bucket()
+        s3_client = boto3.client("s3")
+        buckets = s3_client.list_buckets()["Buckets"]
+        assert len(buckets) == 1
+        assert buckets[0]["Name"] == "my-bucket"
+
+    @mock_s3
+    def test_s3_request_wrapper(self):
+        set_up_s3_bucket(s3_objects=[(str(self.glove_file), "embeddings/glove.txt.gz")])
+        s3_resource = boto3.resource("s3")
+
+        @s3_request
+        def get_file_info(url):
+            bucket_name, s3_path = split_s3_path(url)
+            return s3_resource.Object(bucket_name, s3_path).content_type
+
+        # Good request, should work.
+        assert get_file_info("s3://my-bucket/embeddings/glove.txt.gz") == "text/plain"
+
+        # File missing, should raise FileNotFoundError.
+        with pytest.raises(FileNotFoundError):
+            get_file_info("s3://my-bucket/missing_file.txt")
+
+    @mock_s3
+    def test_s3_etag(self):
+        set_up_s3_bucket(s3_objects=[(str(self.glove_file), "embeddings/glove.txt.gz")])
+        # Ensure we can get the etag for an s3 object and that it looks as expected.
+        etag = s3_etag("s3://my-bucket/embeddings/glove.txt.gz")
+        assert isinstance(etag, str)
+        assert etag.startswith("'") or etag.startswith('"')
+
+        # Should raise FileNotFoundError if the file does not exist on the bucket.
+        with pytest.raises(FileNotFoundError):
+            s3_etag("s3://my-bucket/missing_file.txt")
+
+    @mock_s3
+    def test_s3_get(self):
+        set_up_s3_bucket(s3_objects=[(str(self.glove_file), "embeddings/glove.txt.gz")])
+
+        with tempfile.NamedTemporaryFile() as temp_file:
+            s3_get("s3://my-bucket/embeddings/glove.txt.gz", temp_file)
+            assert os.stat(temp_file.name).st_size != 0
+
+        # Should raise FileNotFoundError if the file does not exist on the bucket.
+        with pytest.raises(FileNotFoundError):
+            with tempfile.NamedTemporaryFile() as temp_file:
+                s3_get("s3://my-bucket/missing_file.txt", temp_file)
+
     @responses.activate
     def test_get_from_cache(self):
         url = 'http://fake.datastore.com/glove.txt.gz'
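These tests lean on moto's mock_s3 decorator, so they never touch real AWS or require credentials. The same helpers can be exercised outside the test class in much the same way; below is a rough standalone sketch under that assumption, with a made-up bucket name and key.

    import tempfile

    import boto3
    from moto import mock_s3

    from allennlp.common.file_utils import s3_etag, s3_get


    @mock_s3
    def demo() -> None:
        # All boto3 calls below hit moto's in-memory S3, not real AWS.
        s3_client = boto3.client("s3")
        s3_client.create_bucket(Bucket="my-bucket")
        s3_client.put_object(Bucket="my-bucket", Key="hello.txt", Body=b"hello world\n")

        # A quoted, MD5-style ETag, as asserted in test_s3_etag above.
        print(s3_etag("s3://my-bucket/hello.txt"))

        with tempfile.NamedTemporaryFile() as temp_file:
            s3_get("s3://my-bucket/hello.txt", temp_file)
            temp_file.seek(0)
            print(temp_file.read())  # b"hello world\n"


    demo()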

requirements.txt (+6 −0)

@@ -43,6 +43,9 @@ cffi==1.11.2
 # aws commandline tools for running on Docker remotely.
 awscli>=1.11.91
 
+# Accessing files from S3 directly.
+boto3
+
 # REST interface for models
 flask==0.12.1
 flask-cors==3.0.3
@@ -106,6 +109,9 @@ codecov
 # Required to run sanic tests
 aiohttp
 
+# For mocking s3.
+moto==1.3.4
+
 #### DOC-RELATED PACKAGES ####
 
 # Builds our documentation.

setup.py (+2 −0)

@@ -111,6 +111,8 @@
         'tensorboardX==1.2',
         'cffi==1.11.2',
         'awscli>=1.11.91',
+        'boto3',
+        'moto==1.3.4',
         'flask==0.12.1',
         'flask-cors==3.0.3',
         'gevent==1.3.5',

0 commit comments