|
16 | 16 | import tempfile
|
17 | 17 | from binascii import unhexlify
|
18 | 18 | from io import BytesIO
|
19 |
| -from typing import Any, BinaryIO, Dict, List, Optional, Union |
| 19 | +from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union |
20 | 20 | from unittest.mock import Mock
|
21 | 21 | from urllib import parse
|
22 | 22 |
|
|
32 | 32 | from synapse.api.errors import Codes
|
33 | 33 | from synapse.events import EventBase
|
34 | 34 | from synapse.events.spamcheck import load_legacy_spam_checkers
|
| 35 | +from synapse.http.types import QueryParams |
35 | 36 | from synapse.logging.context import make_deferred_yieldable
|
36 | 37 | from synapse.module_api import ModuleApi
|
37 | 38 | from synapse.rest import admin
|
|
41 | 42 | from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
|
42 | 43 | from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
|
43 | 44 | from synapse.server import HomeServer
|
44 |
| -from synapse.types import RoomAlias |
| 45 | +from synapse.types import JsonDict, RoomAlias |
45 | 46 | from synapse.util import Clock
|
46 | 47 |
|
47 | 48 | from tests import unittest
|
@@ -201,36 +202,46 @@ class _TestImage:
|
201 | 202 | ],
|
202 | 203 | )
|
203 | 204 | class MediaRepoTests(unittest.HomeserverTestCase):
|
204 |
| - |
| 205 | + test_image: ClassVar[_TestImage] |
205 | 206 | hijack_auth = True
|
206 | 207 | user_id = "@test:user"
|
207 | 208 |
|
208 | 209 | def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
209 | 210 |
|
210 |
| - self.fetches = [] |
| 211 | + self.fetches: List[ |
| 212 | + Tuple[ |
| 213 | + "Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]]", |
| 214 | + str, |
| 215 | + str, |
| 216 | + Optional[QueryParams], |
| 217 | + ] |
| 218 | + ] = [] |
211 | 219 |
|
212 | 220 | def get_file(
|
213 | 221 | destination: str,
|
214 | 222 | path: str,
|
215 | 223 | output_stream: BinaryIO,
|
216 |
| - args: Optional[Dict[str, Union[str, List[str]]]] = None, |
| 224 | + args: Optional[QueryParams] = None, |
| 225 | + retry_on_dns_fail: bool = True, |
217 | 226 | max_size: Optional[int] = None,
|
218 |
| - ) -> Deferred: |
219 |
| - """ |
220 |
| - Returns tuple[int,dict,str,int] of file length, response headers, |
221 |
| - absolute URI, and response code. |
222 |
| - """ |
| 227 | + ignore_backoff: bool = False, |
| 228 | + ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]": |
| 229 | + """A mock for MatrixFederationHttpClient.get_file.""" |
223 | 230 |
|
224 |
| - def write_to(r): |
| 231 | + def write_to( |
| 232 | + r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]] |
| 233 | + ) -> Tuple[int, Dict[bytes, List[bytes]]]: |
225 | 234 | data, response = r
|
226 | 235 | output_stream.write(data)
|
227 | 236 | return response
|
228 | 237 |
|
229 |
| - d = Deferred() |
230 |
| - d.addCallback(write_to) |
| 238 | + d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred() |
231 | 239 | self.fetches.append((d, destination, path, args))
|
232 |
| - return make_deferred_yieldable(d) |
| 240 | + # Note that this callback changes the value held by d. |
| 241 | + d_after_callback = d.addCallback(write_to) |
| 242 | + return make_deferred_yieldable(d_after_callback) |
233 | 243 |
|
| 244 | + # Mock out the homeserver's MatrixFederationHttpClient |
234 | 245 | client = Mock()
|
235 | 246 | client.get_file = get_file
|
236 | 247 |
|
@@ -461,6 +472,7 @@ def test_thumbnail_repeated_thumbnail(self) -> None:
|
461 | 472 | # Synapse should regenerate missing thumbnails.
|
462 | 473 | origin, media_id = self.media_id.split("/")
|
463 | 474 | info = self.get_success(self.store.get_cached_remote_media(origin, media_id))
|
| 475 | + assert info is not None |
464 | 476 | file_id = info["filesystem_id"]
|
465 | 477 |
|
466 | 478 | thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
|
@@ -581,18 +593,18 @@ def test_same_quality(self, method: str, desired_size: int) -> None:
|
581 | 593 | "thumbnail_method": method,
|
582 | 594 | "thumbnail_type": self.test_image.content_type,
|
583 | 595 | "thumbnail_length": 256,
|
584 |
| - "filesystem_id": f"thumbnail1{self.test_image.extension}", |
| 596 | + "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}", |
585 | 597 | },
|
586 | 598 | {
|
587 | 599 | "thumbnail_width": 32,
|
588 | 600 | "thumbnail_height": 32,
|
589 | 601 | "thumbnail_method": method,
|
590 | 602 | "thumbnail_type": self.test_image.content_type,
|
591 | 603 | "thumbnail_length": 256,
|
592 |
| - "filesystem_id": f"thumbnail2{self.test_image.extension}", |
| 604 | + "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}", |
593 | 605 | },
|
594 | 606 | ],
|
595 |
| - file_id=f"image{self.test_image.extension}", |
| 607 | + file_id=f"image{self.test_image.extension.decode()}", |
596 | 608 | url_cache=None,
|
597 | 609 | server_name=None,
|
598 | 610 | )
|
@@ -637,6 +649,7 @@ def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
|
637 | 649 | self.config = config
|
638 | 650 | self.api = api
|
639 | 651 |
|
| 652 | + @staticmethod |
640 | 653 | def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
|
641 | 654 | return config
|
642 | 655 |
|
@@ -748,7 +761,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
748 | 761 |
|
749 | 762 | async def check_media_file_for_spam(
|
750 | 763 | self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
|
751 |
| - ) -> Union[Codes, Literal["NOT_SPAM"]]: |
| 764 | + ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]: |
752 | 765 | buf = BytesIO()
|
753 | 766 | await file_wrapper.write_chunks_to(buf.write)
|
754 | 767 |
|
|
0 commit comments