-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Fix race for concurrent downloads of remote media. #8682
Changes from 5 commits
940ad1e
a5849ea
113f6a2
19f7864
0e5fa9c
404ad0a
2f91a02
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Fix exception during handling multiple concurrent requests for remote media when using multiple media repositories. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -305,15 +305,12 @@ async def _get_remote_media_impl( | |
# file_id is the ID we use to track the file locally. If we've already | ||
# seen the file then reuse the existing ID, otherwise generate a new | ||
# one. | ||
if media_info: | ||
file_id = media_info["filesystem_id"] | ||
else: | ||
file_id = random_string(24) | ||
|
||
file_info = FileInfo(server_name, file_id) | ||
|
||
# If we have an entry in the DB, try and look for it | ||
if media_info: | ||
file_id = media_info["filesystem_id"] | ||
file_info = FileInfo(server_name, file_id) | ||
|
||
if media_info["quarantined_by"]: | ||
logger.info("Media is quarantined") | ||
raise NotFoundError() | ||
|
@@ -324,14 +321,28 @@ async def _get_remote_media_impl( | |
|
||
# Failed to find the file anywhere, lets download it. | ||
|
||
media_info = await self._download_remote_file(server_name, media_id, file_id) | ||
try: | ||
media_info = await self._download_remote_file(server_name, media_id,) | ||
except SynapseError: | ||
raise | ||
except Exception as e: | ||
# An exception may be because we downloaded media in another | ||
# process, so let's check if we magically have the media. | ||
media_info = await self.store.get_cached_remote_media(server_name, media_id) | ||
if not media_info: | ||
raise e | ||
|
||
file_id = media_info["filesystem_id"] | ||
file_info = FileInfo(server_name, file_id) | ||
|
||
await self._generate_thumbnails( | ||
server_name, media_id, file_id, media_info["media_type"] | ||
) | ||
|
||
responder = await self.media_storage.fetch_media(file_info) | ||
return responder, media_info | ||
|
||
async def _download_remote_file( | ||
self, server_name: str, media_id: str, file_id: str | ||
) -> dict: | ||
async def _download_remote_file(self, server_name: str, media_id: str,) -> dict: | ||
"""Attempt to download the remote file from the given server name, | ||
using the given file_id as the local id. | ||
|
||
|
@@ -346,6 +357,8 @@ async def _download_remote_file( | |
The media info of the file. | ||
""" | ||
|
||
file_id = random_string(24) | ||
|
||
file_info = FileInfo(server_name=server_name, file_id=file_id) | ||
|
||
with self.media_storage.store_into_file(file_info) as (f, fname, finish): | ||
|
@@ -401,22 +414,32 @@ async def _download_remote_file( | |
|
||
await finish() | ||
|
||
media_type = headers[b"Content-Type"][0].decode("ascii") | ||
upload_name = get_filename_from_headers(headers) | ||
time_now_ms = self.clock.time_msec() | ||
media_type = headers[b"Content-Type"][0].decode("ascii") | ||
upload_name = get_filename_from_headers(headers) | ||
time_now_ms = self.clock.time_msec() | ||
|
||
# Multiple remote media download requests can race (when using | ||
# multiple media repos), so this may throw a constraint violation | ||
# exception. If it does we'll delete the newly downloaded file from | ||
# disk (as we're in the ctx manager). | ||
# | ||
# However: we've already called `finish()` so we may have also | ||
# written to the storage providers. This is preferable to the | ||
# alternative where we call `finish()` *after* this, where we could | ||
# end up having an entry in the DB but fail to write the files to | ||
# the storage providers. | ||
Comment on lines
+432
to
+436
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Meaning we might still end up with files that are useless? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we should file an issue to fix that though (e.g adding a |
||
await self.store.store_cached_remote_media( | ||
origin=server_name, | ||
media_id=media_id, | ||
media_type=media_type, | ||
time_now_ms=self.clock.time_msec(), | ||
upload_name=upload_name, | ||
media_length=length, | ||
filesystem_id=file_id, | ||
) | ||
|
||
logger.info("Stored remote media in file %r", fname) | ||
|
||
await self.store.store_cached_remote_media( | ||
origin=server_name, | ||
media_id=media_id, | ||
media_type=media_type, | ||
time_now_ms=self.clock.time_msec(), | ||
upload_name=upload_name, | ||
media_length=length, | ||
filesystem_id=file_id, | ||
) | ||
|
||
media_info = { | ||
"media_type": media_type, | ||
"media_length": length, | ||
|
@@ -425,8 +448,6 @@ async def _download_remote_file( | |
"filesystem_id": file_id, | ||
} | ||
|
||
await self._generate_thumbnails(server_name, media_id, file_id, media_type) | ||
|
||
return media_info | ||
|
||
def _get_thumbnail_requirements(self, media_type): | ||
|
@@ -692,42 +713,60 @@ async def _generate_thumbnails( | |
if not t_byte_source: | ||
continue | ||
|
||
try: | ||
file_info = FileInfo( | ||
server_name=server_name, | ||
file_id=file_id, | ||
thumbnail=True, | ||
thumbnail_width=t_width, | ||
thumbnail_height=t_height, | ||
thumbnail_method=t_method, | ||
thumbnail_type=t_type, | ||
url_cache=url_cache, | ||
) | ||
|
||
output_path = await self.media_storage.store_file( | ||
t_byte_source, file_info | ||
) | ||
finally: | ||
t_byte_source.close() | ||
|
||
t_len = os.path.getsize(output_path) | ||
file_info = FileInfo( | ||
server_name=server_name, | ||
file_id=file_id, | ||
thumbnail=True, | ||
thumbnail_width=t_width, | ||
thumbnail_height=t_height, | ||
thumbnail_method=t_method, | ||
thumbnail_type=t_type, | ||
url_cache=url_cache, | ||
) | ||
|
||
# Write to database | ||
if server_name: | ||
await self.store.store_remote_media_thumbnail( | ||
server_name, | ||
media_id, | ||
file_id, | ||
t_width, | ||
t_height, | ||
t_type, | ||
t_method, | ||
t_len, | ||
) | ||
else: | ||
await self.store.store_local_thumbnail( | ||
media_id, t_width, t_height, t_type, t_method, t_len | ||
) | ||
with self.media_storage.store_into_file(file_info) as (f, fname, finish): | ||
try: | ||
await self.media_storage.write_to_file(t_byte_source, f) | ||
await finish() | ||
finally: | ||
t_byte_source.close() | ||
|
||
t_len = os.path.getsize(fname) | ||
|
||
# Write to database | ||
if server_name: | ||
# Multiple remote media download requests can race (when | ||
# using multiple media repos), so this may throw a constraint | ||
# violation exception. If it does we'll delete the newly | ||
# generated thumbnail from disk (as we're in the ctx | ||
# manager). | ||
# | ||
# However: we've already called `finish()` so we may have | ||
# also written to the storage providers. This is preferable | ||
# to the alternative where we call `finish()` *after* this, | ||
# where we could end up having an entry in the DB but fail | ||
# to write the files to the storage providers. | ||
try: | ||
await self.store.store_remote_media_thumbnail( | ||
server_name, | ||
media_id, | ||
file_id, | ||
t_width, | ||
t_height, | ||
t_type, | ||
t_method, | ||
t_len, | ||
) | ||
except Exception as e: | ||
thumbnail_exists = await self.store.get_remote_media_thumbnail( | ||
server_name, media_id, t_width, t_height, t_type, | ||
) | ||
if not thumbnail_exists: | ||
raise e | ||
else: | ||
await self.store.store_local_thumbnail( | ||
media_id, t_width, t_height, t_type, t_method, t_len | ||
) | ||
|
||
return {"width": m_width, "height": m_height} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,6 +52,7 @@ def __init__( | |
storage_providers: Sequence["StorageProviderWrapper"], | ||
): | ||
self.hs = hs | ||
self.reactor = hs.get_reactor() | ||
self.local_media_directory = local_media_directory | ||
self.filepaths = filepaths | ||
self.storage_providers = storage_providers | ||
|
@@ -70,13 +71,16 @@ async def store_file(self, source: IO, file_info: FileInfo) -> str: | |
|
||
with self.store_into_file(file_info) as (f, fname, finish_cb): | ||
# Write to the main repository | ||
await defer_to_thread( | ||
self.hs.get_reactor(), _write_file_synchronously, source, f | ||
) | ||
await self.write_to_file(source, f) | ||
await finish_cb() | ||
|
||
return fname | ||
|
||
async def write_to_file(self, source: IO, output: IO): | ||
"""Asynchronously write the `source` to `output`. | ||
""" | ||
await defer_to_thread(self.reactor, _write_file_synchronously, source, output) | ||
|
||
@contextlib.contextmanager | ||
def store_into_file(self, file_info: FileInfo): | ||
"""Context manager used to get a file like object to write into, as | ||
|
@@ -112,14 +116,20 @@ def store_into_file(self, file_info: FileInfo): | |
|
||
finished_called = [False] | ||
|
||
async def finish(): | ||
for provider in self.storage_providers: | ||
await provider.store_file(path, file_info) | ||
|
||
finished_called[0] = True | ||
|
||
try: | ||
with open(fname, "wb") as f: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Huh, this being a context manager here is a bit weird when we don't really write inside of it... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, we do in that we write during the |
||
|
||
async def finish(): | ||
# Ensure that all writes have been flushed and close the | ||
# file. | ||
f.flush() | ||
f.close() | ||
|
||
for provider in self.storage_providers: | ||
await provider.store_file(path, file_info) | ||
|
||
finished_called[0] = True | ||
|
||
yield f, fname, finish | ||
except Exception: | ||
try: | ||
|
@@ -210,7 +220,7 @@ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: | |
if res: | ||
with res: | ||
consumer = BackgroundFileConsumer( | ||
open(local_path, "wb"), self.hs.get_reactor() | ||
open(local_path, "wb"), self.reactor | ||
) | ||
await res.write_to_consumer(consumer) | ||
await consumer.wait() | ||
|
Uh oh!
There was an error while loading. Please reload this page.