Skip to content

Commit cebe526

Browse files
Fix handling of multipart/form-data (#8280) (#8302)
https://datatracker.ietf.org/doc/html/rfc7578 (cherry picked from commit 7d0be3f)
1 parent 270ae9c commit cebe526

7 files changed

+155
-120
lines changed

CHANGES/8280.bugfix.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed ``multipart/form-data`` compliance with :rfc:`7578` -- by :user:`Dreamsorcerer`.

CHANGES/8280.deprecation.rst

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Deprecated ``content_transfer_encoding`` parameter in :py:meth:`FormData.add_field()
2+
<aiohttp.FormData.add_field>` -- by :user:`Dreamsorcerer`.

aiohttp/formdata.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
import warnings
23
from typing import Any, Iterable, List, Optional
34
from urllib.parse import urlencode
45

@@ -53,7 +54,12 @@ def add_field(
5354
if isinstance(value, io.IOBase):
5455
self._is_multipart = True
5556
elif isinstance(value, (bytes, bytearray, memoryview)):
57+
msg = (
58+
"In v4, passing bytes will no longer create a file field. "
59+
"Please explicitly use the filename parameter or pass a BytesIO object."
60+
)
5661
if filename is None and content_transfer_encoding is None:
62+
warnings.warn(msg, DeprecationWarning)
5763
filename = name
5864

5965
type_options: MultiDict[str] = MultiDict({"name": name})
@@ -81,7 +87,11 @@ def add_field(
8187
"content_transfer_encoding must be an instance"
8288
" of str. Got: %s" % content_transfer_encoding
8389
)
84-
headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding
90+
msg = (
91+
"content_transfer_encoding is deprecated. "
92+
"To maintain compatibility with v4 please pass a BytesPayload."
93+
)
94+
warnings.warn(msg, DeprecationWarning)
8595
self._is_multipart = True
8696

8797
self._fields.append((type_options, headers, value))

aiohttp/multipart.py

+80-41
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,22 @@ class BodyPartReader:
256256
chunk_size = 8192
257257

258258
def __init__(
259-
self, boundary: bytes, headers: "CIMultiDictProxy[str]", content: StreamReader
259+
self,
260+
boundary: bytes,
261+
headers: "CIMultiDictProxy[str]",
262+
content: StreamReader,
263+
*,
264+
subtype: str = "mixed",
265+
default_charset: Optional[str] = None,
260266
) -> None:
261267
self.headers = headers
262268
self._boundary = boundary
263269
self._content = content
270+
self._default_charset = default_charset
264271
self._at_eof = False
265-
length = self.headers.get(CONTENT_LENGTH, None)
272+
self._is_form_data = subtype == "form-data"
273+
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
274+
length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None)
266275
self._length = int(length) if length is not None else None
267276
self._read_bytes = 0
268277
self._unread: Deque[bytes] = deque()
@@ -329,6 +338,8 @@ async def _read_chunk_from_length(self, size: int) -> bytes:
329338
assert self._length is not None, "Content-Length required for chunked read"
330339
chunk_size = min(size, self._length - self._read_bytes)
331340
chunk = await self._content.read(chunk_size)
341+
if self._content.at_eof():
342+
self._at_eof = True
332343
return chunk
333344

334345
async def _read_chunk_from_stream(self, size: int) -> bytes:
@@ -449,7 +460,8 @@ def decode(self, data: bytes) -> bytes:
449460
"""
450461
if CONTENT_TRANSFER_ENCODING in self.headers:
451462
data = self._decode_content_transfer(data)
452-
if CONTENT_ENCODING in self.headers:
463+
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
464+
if not self._is_form_data and CONTENT_ENCODING in self.headers:
453465
return self._decode_content(data)
454466
return data
455467

@@ -483,7 +495,7 @@ def get_charset(self, default: str) -> str:
483495
"""Returns charset parameter from Content-Type header or default."""
484496
ctype = self.headers.get(CONTENT_TYPE, "")
485497
mimetype = parse_mimetype(ctype)
486-
return mimetype.parameters.get("charset", default)
498+
return mimetype.parameters.get("charset", self._default_charset or default)
487499

488500
@reify
489501
def name(self) -> Optional[str]:
@@ -538,9 +550,17 @@ class MultipartReader:
538550
part_reader_cls = BodyPartReader
539551

540552
def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
553+
self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
554+
assert self._mimetype.type == "multipart", "multipart/* content type expected"
555+
if "boundary" not in self._mimetype.parameters:
556+
raise ValueError(
557+
"boundary missed for Content-Type: %s" % headers[CONTENT_TYPE]
558+
)
559+
541560
self.headers = headers
542561
self._boundary = ("--" + self._get_boundary()).encode()
543562
self._content = content
563+
self._default_charset: Optional[str] = None
544564
self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
545565
self._at_eof = False
546566
self._at_bof = True
@@ -592,7 +612,24 @@ async def next(
592612
await self._read_boundary()
593613
if self._at_eof: # we just read the last boundary, nothing to do there
594614
return None
595-
self._last_part = await self.fetch_next_part()
615+
616+
part = await self.fetch_next_part()
617+
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
618+
if (
619+
self._last_part is None
620+
and self._mimetype.subtype == "form-data"
621+
and isinstance(part, BodyPartReader)
622+
):
623+
_, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
624+
if params.get("name") == "_charset_":
625+
# Longest encoding in https://encoding.spec.whatwg.org/encodings.json
626+
# is 19 characters, so 32 should be more than enough for any valid encoding.
627+
charset = await part.read_chunk(32)
628+
if len(charset) > 31:
629+
raise RuntimeError("Invalid default charset")
630+
self._default_charset = charset.strip().decode()
631+
part = await self.fetch_next_part()
632+
self._last_part = part
596633
return self._last_part
597634

598635
async def release(self) -> None:
@@ -628,19 +665,16 @@ def _get_part_reader(
628665
return type(self)(headers, self._content)
629666
return self.multipart_reader_cls(headers, self._content)
630667
else:
631-
return self.part_reader_cls(self._boundary, headers, self._content)
632-
633-
def _get_boundary(self) -> str:
634-
mimetype = parse_mimetype(self.headers[CONTENT_TYPE])
635-
636-
assert mimetype.type == "multipart", "multipart/* content type expected"
637-
638-
if "boundary" not in mimetype.parameters:
639-
raise ValueError(
640-
"boundary missed for Content-Type: %s" % self.headers[CONTENT_TYPE]
668+
return self.part_reader_cls(
669+
self._boundary,
670+
headers,
671+
self._content,
672+
subtype=self._mimetype.subtype,
673+
default_charset=self._default_charset,
641674
)
642675

643-
boundary = mimetype.parameters["boundary"]
676+
def _get_boundary(self) -> str:
677+
boundary = self._mimetype.parameters["boundary"]
644678
if len(boundary) > 70:
645679
raise ValueError("boundary %r is too long (70 chars max)" % boundary)
646680

@@ -731,6 +765,7 @@ def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> No
731765
super().__init__(None, content_type=ctype)
732766

733767
self._parts: List[_Part] = []
768+
self._is_form_data = subtype == "form-data"
734769

735770
def __enter__(self) -> "MultipartWriter":
736771
return self
@@ -808,32 +843,36 @@ def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Paylo
808843

809844
def append_payload(self, payload: Payload) -> Payload:
810845
"""Adds a new body part to multipart writer."""
811-
# compression
812-
encoding: Optional[str] = payload.headers.get(
813-
CONTENT_ENCODING,
814-
"",
815-
).lower()
816-
if encoding and encoding not in ("deflate", "gzip", "identity"):
817-
raise RuntimeError(f"unknown content encoding: {encoding}")
818-
if encoding == "identity":
819-
encoding = None
820-
821-
# te encoding
822-
te_encoding: Optional[str] = payload.headers.get(
823-
CONTENT_TRANSFER_ENCODING,
824-
"",
825-
).lower()
826-
if te_encoding not in ("", "base64", "quoted-printable", "binary"):
827-
raise RuntimeError(
828-
"unknown content transfer encoding: {}" "".format(te_encoding)
846+
encoding: Optional[str] = None
847+
te_encoding: Optional[str] = None
848+
if self._is_form_data:
849+
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
850+
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
851+
assert CONTENT_DISPOSITION in payload.headers
852+
assert "name=" in payload.headers[CONTENT_DISPOSITION]
853+
assert (
854+
not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
855+
& payload.headers.keys()
829856
)
830-
if te_encoding == "binary":
831-
te_encoding = None
832-
833-
# size
834-
size = payload.size
835-
if size is not None and not (encoding or te_encoding):
836-
payload.headers[CONTENT_LENGTH] = str(size)
857+
else:
858+
# compression
859+
encoding = payload.headers.get(CONTENT_ENCODING, "").lower()
860+
if encoding and encoding not in ("deflate", "gzip", "identity"):
861+
raise RuntimeError(f"unknown content encoding: {encoding}")
862+
if encoding == "identity":
863+
encoding = None
864+
865+
# te encoding
866+
te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
867+
if te_encoding not in ("", "base64", "quoted-printable", "binary"):
868+
raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
869+
if te_encoding == "binary":
870+
te_encoding = None
871+
872+
# size
873+
size = payload.size
874+
if size is not None and not (encoding or te_encoding):
875+
payload.headers[CONTENT_LENGTH] = str(size)
837876

838877
self._parts.append((payload, encoding, te_encoding)) # type: ignore[arg-type]
839878
return payload

tests/test_client_functional.py

+1-43
Original file line numberDiff line numberDiff line change
@@ -1317,48 +1317,6 @@ async def handler(request):
13171317
resp.close()
13181318

13191319

1320-
async def test_POST_DATA_with_context_transfer_encoding(aiohttp_client) -> None:
1321-
async def handler(request):
1322-
data = await request.post()
1323-
assert data["name"] == "text"
1324-
return web.Response(text=data["name"])
1325-
1326-
app = web.Application()
1327-
app.router.add_post("/", handler)
1328-
client = await aiohttp_client(app)
1329-
1330-
form = aiohttp.FormData()
1331-
form.add_field("name", "text", content_transfer_encoding="base64")
1332-
1333-
resp = await client.post("/", data=form)
1334-
assert 200 == resp.status
1335-
content = await resp.text()
1336-
assert content == "text"
1337-
resp.close()
1338-
1339-
1340-
async def test_POST_DATA_with_content_type_context_transfer_encoding(aiohttp_client):
1341-
async def handler(request):
1342-
data = await request.post()
1343-
assert data["name"] == "text"
1344-
return web.Response(body=data["name"])
1345-
1346-
app = web.Application()
1347-
app.router.add_post("/", handler)
1348-
client = await aiohttp_client(app)
1349-
1350-
form = aiohttp.FormData()
1351-
form.add_field(
1352-
"name", "text", content_type="text/plain", content_transfer_encoding="base64"
1353-
)
1354-
1355-
resp = await client.post("/", data=form)
1356-
assert 200 == resp.status
1357-
content = await resp.text()
1358-
assert content == "text"
1359-
resp.close()
1360-
1361-
13621320
async def test_POST_MultiDict(aiohttp_client) -> None:
13631321
async def handler(request):
13641322
data = await request.post()
@@ -1410,7 +1368,7 @@ async def handler(request):
14101368

14111369
with fname.open("rb") as f:
14121370
async with client.post(
1413-
"/", data={"some": f, "test": b"data"}, chunked=True
1371+
"/", data={"some": f, "test": io.BytesIO(b"data")}, chunked=True
14141372
) as resp:
14151373
assert 200 == resp.status
14161374

0 commit comments

Comments
 (0)