 from collections.abc import AsyncIterator
 from io import StringIO
 from pathlib import Path
-from typing import Optional, TextIO, Union
+from typing import List, Optional, TextIO, Union

 import pytest
 from multidict import CIMultiDict

 from aiohttp import payload
 from aiohttp.abc import AbstractStreamWriter
+from aiohttp.payload import READ_SIZE


 class BufferWriter(AbstractStreamWriter):

@@ -365,6 +366,155 @@ async def test_iobase_payload_exact_chunk_size_limit() -> None:
     assert written == data[:chunk_size]


+async def test_iobase_payload_reads_in_chunks() -> None:
+    """Test IOBasePayload reads data in chunks of READ_SIZE, not all at once."""
+    # Create a large file that's multiple times larger than READ_SIZE
+    large_data = b"x" * (READ_SIZE * 3 + 1000)  # ~192KB + 1000 bytes
+
+    # Mock the file-like object to track read calls
+    mock_file = unittest.mock.Mock(spec=io.BytesIO)
+    mock_file.tell.return_value = 0
+    mock_file.fileno.side_effect = AttributeError  # Make size return None
+
+    # Track the sizes of read() calls
+    read_sizes = []
+
+    def mock_read(size: int) -> bytes:
+        read_sizes.append(size)
+        # Return data based on how many times read was called
+        call_count = len(read_sizes)
+        if call_count == 1:
+            return large_data[:size]
+        elif call_count == 2:
+            return large_data[READ_SIZE : READ_SIZE + size]
+        elif call_count == 3:
+            return large_data[READ_SIZE * 2 : READ_SIZE * 2 + size]
+        else:
+            return large_data[READ_SIZE * 3 :]
+
+    mock_file.read.side_effect = mock_read
+
+    payload_obj = payload.IOBasePayload(mock_file)
+    writer = MockStreamWriter()
+
+    # Write with a large content_length
+    await payload_obj.write_with_length(writer, len(large_data))
+
+    # Verify that reads were limited to READ_SIZE
+    assert len(read_sizes) > 1  # Should have multiple reads
+    for read_size in read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+
+async def test_iobase_payload_large_content_length() -> None:
+    """Test IOBasePayload with very large content_length doesn't read all at once."""
+    data = b"x" * (READ_SIZE + 1000)
+
+    # Create a custom file-like object that tracks read sizes
+    class TrackingBytesIO(io.BytesIO):
+        def __init__(self, data: bytes) -> None:
+            super().__init__(data)
+            self.read_sizes: List[int] = []
+
+        def read(self, size: Optional[int] = -1) -> bytes:
+            self.read_sizes.append(size if size is not None else -1)
+            return super().read(size)
+
+    tracking_file = TrackingBytesIO(data)
+    payload_obj = payload.IOBasePayload(tracking_file)
+    writer = MockStreamWriter()
+
+    # Write with a very large content_length (simulating the bug scenario)
+    large_content_length = 10 * 1024 * 1024  # 10MB
+    await payload_obj.write_with_length(writer, large_content_length)
+
+    # Verify no single read exceeded READ_SIZE
+    for read_size in tracking_file.read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+    # Verify the correct amount of data was written
+    assert writer.get_written_bytes() == data
+
+
+async def test_textio_payload_reads_in_chunks() -> None:
+    """Test TextIOPayload reads data in chunks of READ_SIZE, not all at once."""
+    # Create a large text file that's multiple times larger than READ_SIZE
+    large_text = "x" * (READ_SIZE * 3 + 1000)  # ~192KB + 1000 chars
+
+    # Mock the file-like object to track read calls
+    mock_file = unittest.mock.Mock(spec=io.StringIO)
+    mock_file.tell.return_value = 0
+    mock_file.fileno.side_effect = AttributeError  # Make size return None
+    mock_file.encoding = "utf-8"
+
+    # Track the sizes of read() calls
+    read_sizes = []
+
+    def mock_read(size: int) -> str:
+        read_sizes.append(size)
+        # Return data based on how many times read was called
+        call_count = len(read_sizes)
+        if call_count == 1:
+            return large_text[:size]
+        elif call_count == 2:
+            return large_text[READ_SIZE : READ_SIZE + size]
+        elif call_count == 3:
+            return large_text[READ_SIZE * 2 : READ_SIZE * 2 + size]
+        else:
+            return large_text[READ_SIZE * 3 :]
+
+    mock_file.read.side_effect = mock_read
+
+    payload_obj = payload.TextIOPayload(mock_file)
+    writer = MockStreamWriter()
+
+    # Write with a large content_length
+    await payload_obj.write_with_length(writer, len(large_text.encode("utf-8")))
+
+    # Verify that reads were limited to READ_SIZE
+    assert len(read_sizes) > 1  # Should have multiple reads
+    for read_size in read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+
+async def test_textio_payload_large_content_length() -> None:
+    """Test TextIOPayload with very large content_length doesn't read all at once."""
+    text_data = "x" * (READ_SIZE + 1000)
+
+    # Create a custom file-like object that tracks read sizes
+    class TrackingStringIO(io.StringIO):
+        def __init__(self, data: str) -> None:
+            super().__init__(data)
+            self.read_sizes: List[int] = []
+
+        def read(self, size: Optional[int] = -1) -> str:
+            self.read_sizes.append(size if size is not None else -1)
+            return super().read(size)
+
+    tracking_file = TrackingStringIO(text_data)
+    payload_obj = payload.TextIOPayload(tracking_file)
+    writer = MockStreamWriter()
+
+    # Write with a very large content_length (simulating the bug scenario)
+    large_content_length = 10 * 1024 * 1024  # 10MB
+    await payload_obj.write_with_length(writer, large_content_length)
+
+    # Verify no single read exceeded READ_SIZE
+    for read_size in tracking_file.read_sizes:
+        assert (
+            read_size <= READ_SIZE
+        ), f"Read size {read_size} exceeds READ_SIZE {READ_SIZE}"
+
+    # Verify the correct amount of data was written
+    assert writer.get_written_bytes() == text_data.encode("utf-8")
+
+
 async def test_async_iterable_payload_write_with_length_no_limit() -> None:
     """Test AsyncIterablePayload writing with no content length limit."""
