Skip to content

Commit 2b920c3

Browse files
Use decompressor max_length parameter (#11898) (#11918)
(cherry picked from commit 92477c5) --------- Co-authored-by: J. Nick Koston <[email protected]>
1 parent 4ed97a4 commit 2b920c3

12 files changed

Lines changed: 335 additions & 88 deletions

CHANGES/11898.breaking.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
``Brotli`` and ``brotlicffi`` minimum version is now 1.2.
2+
Decompression now has a default maximum output size of 32MiB per decompress call -- by :user:`Dreamsorcerer`.

aiohttp/compression_utils.py

Lines changed: 75 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import sys
33
import zlib
4+
from abc import ABC, abstractmethod
45
from concurrent.futures import Executor
56
from typing import Any, Final, Optional, Protocol, TypedDict, cast
67

@@ -32,7 +33,12 @@
3233
HAS_ZSTD = False
3334

3435

35-
MAX_SYNC_CHUNK_SIZE = 1024
36+
MAX_SYNC_CHUNK_SIZE = 4096
37+
DEFAULT_MAX_DECOMPRESS_SIZE = 2**25 # 32MiB
38+
39+
# Unlimited decompression constants - different libraries use different conventions
40+
ZLIB_MAX_LENGTH_UNLIMITED = 0 # zlib uses 0 to mean unlimited
41+
ZSTD_MAX_LENGTH_UNLIMITED = -1 # zstd uses -1 to mean unlimited
3642

3743

3844
class ZLibCompressObjProtocol(Protocol):
@@ -144,19 +150,37 @@ def encoding_to_mode(
144150
return -ZLibBackend.MAX_WBITS if suppress_deflate_header else ZLibBackend.MAX_WBITS
145151

146152

147-
class ZlibBaseHandler:
153+
class DecompressionBaseHandler(ABC):
148154
def __init__(
149155
self,
150-
mode: int,
151156
executor: Optional[Executor] = None,
152157
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
153158
):
154-
self._mode = mode
159+
"""Base class for decompression handlers."""
155160
self._executor = executor
156161
self._max_sync_chunk_size = max_sync_chunk_size
157162

163+
@abstractmethod
164+
def decompress_sync(
165+
self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
166+
) -> bytes:
167+
"""Decompress the given data."""
168+
169+
async def decompress(
170+
self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
171+
) -> bytes:
172+
"""Decompress the given data."""
173+
if (
174+
self._max_sync_chunk_size is not None
175+
and len(data) > self._max_sync_chunk_size
176+
):
177+
return await asyncio.get_event_loop().run_in_executor(
178+
self._executor, self.decompress_sync, data, max_length
179+
)
180+
return self.decompress_sync(data, max_length)
181+
158182

159-
class ZLibCompressor(ZlibBaseHandler):
183+
class ZLibCompressor:
160184
def __init__(
161185
self,
162186
encoding: Optional[str] = None,
@@ -167,14 +191,12 @@ def __init__(
167191
executor: Optional[Executor] = None,
168192
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
169193
):
170-
super().__init__(
171-
mode=(
172-
encoding_to_mode(encoding, suppress_deflate_header)
173-
if wbits is None
174-
else wbits
175-
),
176-
executor=executor,
177-
max_sync_chunk_size=max_sync_chunk_size,
194+
self._executor = executor
195+
self._max_sync_chunk_size = max_sync_chunk_size
196+
self._mode = (
197+
encoding_to_mode(encoding, suppress_deflate_header)
198+
if wbits is None
199+
else wbits
178200
)
179201
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
180202

@@ -233,41 +255,24 @@ def flush(self, mode: Optional[int] = None) -> bytes:
233255
)
234256

235257

236-
class ZLibDecompressor(ZlibBaseHandler):
258+
class ZLibDecompressor(DecompressionBaseHandler):
237259
def __init__(
238260
self,
239261
encoding: Optional[str] = None,
240262
suppress_deflate_header: bool = False,
241263
executor: Optional[Executor] = None,
242264
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
243265
):
244-
super().__init__(
245-
mode=encoding_to_mode(encoding, suppress_deflate_header),
246-
executor=executor,
247-
max_sync_chunk_size=max_sync_chunk_size,
248-
)
266+
super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
267+
self._mode = encoding_to_mode(encoding, suppress_deflate_header)
249268
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
250269
self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)
251270

252-
def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
271+
def decompress_sync(
272+
self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
273+
) -> bytes:
253274
return self._decompressor.decompress(data, max_length)
254275

255-
async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
256-
"""Decompress the data and return the decompressed bytes.
257-
258-
If the data size is larger than the max_sync_chunk_size, the decompression
259-
will be done in the executor. Otherwise, the decompression will be done
260-
in the event loop.
261-
"""
262-
if (
263-
self._max_sync_chunk_size is not None
264-
and len(data) > self._max_sync_chunk_size
265-
):
266-
return await asyncio.get_running_loop().run_in_executor(
267-
self._executor, self._decompressor.decompress, data, max_length
268-
)
269-
return self.decompress_sync(data, max_length)
270-
271276
def flush(self, length: int = 0) -> bytes:
272277
return (
273278
self._decompressor.flush(length)
@@ -280,40 +285,64 @@ def eof(self) -> bool:
280285
return self._decompressor.eof
281286

282287

283-
class BrotliDecompressor:
288+
class BrotliDecompressor(DecompressionBaseHandler):
284289
# Supports both 'brotlipy' and 'Brotli' packages
285290
# since they share an import name. The top branches
286291
# are for 'brotlipy' and bottom branches for 'Brotli'
287-
def __init__(self) -> None:
292+
def __init__(
293+
self,
294+
executor: Optional[Executor] = None,
295+
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
296+
) -> None:
297+
"""Decompress data using the Brotli library."""
288298
if not HAS_BROTLI:
289299
raise RuntimeError(
290300
"The brotli decompression is not available. "
291301
"Please install `Brotli` module"
292302
)
293303
self._obj = brotli.Decompressor()
304+
super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
294305

295-
def decompress_sync(self, data: bytes) -> bytes:
306+
def decompress_sync(
307+
self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
308+
) -> bytes:
309+
"""Decompress the given data."""
296310
if hasattr(self._obj, "decompress"):
297-
return cast(bytes, self._obj.decompress(data))
298-
return cast(bytes, self._obj.process(data))
311+
return cast(bytes, self._obj.decompress(data, max_length))
312+
return cast(bytes, self._obj.process(data, max_length))
299313

300314
def flush(self) -> bytes:
315+
"""Flush the decompressor."""
301316
if hasattr(self._obj, "flush"):
302317
return cast(bytes, self._obj.flush())
303318
return b""
304319

305320

306-
class ZSTDDecompressor:
307-
def __init__(self) -> None:
321+
class ZSTDDecompressor(DecompressionBaseHandler):
322+
def __init__(
323+
self,
324+
executor: Optional[Executor] = None,
325+
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
326+
) -> None:
308327
if not HAS_ZSTD:
309328
raise RuntimeError(
310329
"The zstd decompression is not available. "
311330
"Please install `backports.zstd` module"
312331
)
313332
self._obj = ZstdDecompressor()
314-
315-
def decompress_sync(self, data: bytes) -> bytes:
316-
return self._obj.decompress(data)
333+
super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
334+
335+
def decompress_sync(
336+
self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
337+
) -> bytes:
338+
# zstd uses -1 for unlimited, while zlib uses 0 for unlimited
339+
# Convert the zlib convention (0=unlimited) to zstd convention (-1=unlimited)
340+
zstd_max_length = (
341+
ZSTD_MAX_LENGTH_UNLIMITED
342+
if max_length == ZLIB_MAX_LENGTH_UNLIMITED
343+
else max_length
344+
)
345+
return self._obj.decompress(data, zstd_max_length)
317346

318347
def flush(self) -> bytes:
319348
return b""

aiohttp/http_exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ class ContentLengthError(PayloadEncodingError):
7474
"""Not enough data to satisfy content length header."""
7575

7676

77+
class DecompressSizeError(PayloadEncodingError):
78+
"""Decompressed size exceeds the configured limit."""
79+
80+
7781
class LineTooLong(BadHttpMessage):
7882
def __init__(
7983
self, line: str, limit: str = "Unknown", actual_size: str = "Unknown"

aiohttp/http_parser.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from . import hdrs
2828
from .base_protocol import BaseProtocol
2929
from .compression_utils import (
30+
DEFAULT_MAX_DECOMPRESS_SIZE,
3031
HAS_BROTLI,
3132
HAS_ZSTD,
3233
BrotliDecompressor,
@@ -48,6 +49,7 @@
4849
BadStatusLine,
4950
ContentEncodingError,
5051
ContentLengthError,
52+
DecompressSizeError,
5153
InvalidHeader,
5254
InvalidURLError,
5355
LineTooLong,
@@ -963,7 +965,12 @@ class DeflateBuffer:
963965

964966
decompressor: Any
965967

966-
def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
968+
def __init__(
969+
self,
970+
out: StreamReader,
971+
encoding: Optional[str],
972+
max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE,
973+
) -> None:
967974
self.out = out
968975
self.size = 0
969976
out.total_compressed_bytes = self.size
@@ -988,6 +995,8 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
988995
else:
989996
self.decompressor = ZLibDecompressor(encoding=encoding)
990997

998+
self._max_decompress_size = max_decompress_size
999+
9911000
def set_exception(
9921001
self,
9931002
exc: BaseException,
@@ -1017,14 +1026,24 @@ def feed_data(self, chunk: bytes, size: int) -> None:
10171026
)
10181027

10191028
try:
1020-
chunk = self.decompressor.decompress_sync(chunk)
1029+
# Decompress with limit + 1 so we can detect if output exceeds limit
1030+
chunk = self.decompressor.decompress_sync(
1031+
chunk, max_length=self._max_decompress_size + 1
1032+
)
10211033
except Exception:
10221034
raise ContentEncodingError(
10231035
"Can not decode content-encoding: %s" % self.encoding
10241036
)
10251037

10261038
self._started_decoding = True
10271039

1040+
# Check if decompression limit was exceeded
1041+
if len(chunk) > self._max_decompress_size:
1042+
raise DecompressSizeError(
1043+
"Decompressed data exceeds the configured limit of %d bytes"
1044+
% self._max_decompress_size
1045+
)
1046+
10281047
if chunk:
10291048
self.out.feed_data(chunk, len(chunk))
10301049

aiohttp/multipart.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@
2525

2626
from multidict import CIMultiDict, CIMultiDictProxy
2727

28-
from .compression_utils import ZLibCompressor, ZLibDecompressor
28+
from .abc import AbstractStreamWriter
29+
from .compression_utils import (
30+
DEFAULT_MAX_DECOMPRESS_SIZE,
31+
ZLibCompressor,
32+
ZLibDecompressor,
33+
)
2934
from .hdrs import (
3035
CONTENT_DISPOSITION,
3136
CONTENT_ENCODING,
@@ -273,6 +278,7 @@ def __init__(
273278
*,
274279
subtype: str = "mixed",
275280
default_charset: Optional[str] = None,
281+
max_decompress_size: int = DEFAULT_MAX_DECOMPRESS_SIZE,
276282
) -> None:
277283
self.headers = headers
278284
self._boundary = boundary
@@ -289,6 +295,7 @@ def __init__(
289295
self._prev_chunk: Optional[bytes] = None
290296
self._content_eof = 0
291297
self._cache: Dict[str, Any] = {}
298+
self._max_decompress_size = max_decompress_size
292299

293300
def __aiter__(self: Self) -> Self:
294301
return self
@@ -318,7 +325,7 @@ async def read(self, *, decode: bool = False) -> bytes:
318325
while not self._at_eof:
319326
data.extend(await self.read_chunk(self.chunk_size))
320327
if decode:
321-
return self.decode(data)
328+
return await self.decode(data)
322329
return data
323330

324331
async def read_chunk(self, size: int = chunk_size) -> bytes:
@@ -496,7 +503,7 @@ def at_eof(self) -> bool:
496503
"""Returns True if the boundary was reached or False otherwise."""
497504
return self._at_eof
498505

499-
def decode(self, data: bytes) -> bytes:
506+
async def decode(self, data: bytes) -> bytes:
500507
"""Decodes data.
501508
502509
Decoding is done according the specified Content-Encoding
@@ -506,18 +513,18 @@ def decode(self, data: bytes) -> bytes:
506513
data = self._decode_content_transfer(data)
507514
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
508515
if not self._is_form_data and CONTENT_ENCODING in self.headers:
509-
return self._decode_content(data)
516+
return await self._decode_content(data)
510517
return data
511518

512-
def _decode_content(self, data: bytes) -> bytes:
519+
async def _decode_content(self, data: bytes) -> bytes:
513520
encoding = self.headers.get(CONTENT_ENCODING, "").lower()
514521
if encoding == "identity":
515522
return data
516523
if encoding in {"deflate", "gzip"}:
517-
return ZLibDecompressor(
524+
return await ZLibDecompressor(
518525
encoding=encoding,
519526
suppress_deflate_header=True,
520-
).decompress_sync(data)
527+
).decompress(data, max_length=self._max_decompress_size)
521528

522529
raise RuntimeError(f"unknown content encoding: {encoding}")
523530

@@ -588,11 +595,11 @@ async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> byt
588595
"""
589596
raise TypeError("Unable to read body part as bytes. Use write() to consume.")
590597

591-
async def write(self, writer: Any) -> None:
598+
async def write(self, writer: AbstractStreamWriter) -> None:
592599
field = self._value
593600
chunk = await field.read_chunk(size=2**16)
594601
while chunk:
595-
await writer.write(field.decode(chunk))
602+
await writer.write(await field.decode(chunk))
596603
chunk = await field.read_chunk(size=2**16)
597604

598605

@@ -1032,7 +1039,9 @@ async def as_bytes(self, encoding: str = "utf-8", errors: str = "strict") -> byt
10321039

10331040
return b"".join(parts)
10341041

1035-
async def write(self, writer: Any, close_boundary: bool = True) -> None:
1042+
async def write(
1043+
self, writer: AbstractStreamWriter, close_boundary: bool = True
1044+
) -> None:
10361045
"""Write body."""
10371046
for part, encoding, te_encoding in self._parts:
10381047
if self._is_form_data:
@@ -1086,7 +1095,7 @@ async def close(self) -> None:
10861095

10871096

10881097
class MultipartPayloadWriter:
1089-
def __init__(self, writer: Any) -> None:
1098+
def __init__(self, writer: AbstractStreamWriter) -> None:
10901099
self._writer = writer
10911100
self._encoding: Optional[str] = None
10921101
self._compress: Optional[ZLibCompressor] = None

0 commit comments

Comments
 (0)