Skip to content

Commit cfcad08

Browse files
fix: Fix zstd decompression of multi-frame responses (#12290)
1 parent 3ef17c6 commit cfcad08

5 files changed

Lines changed: 159 additions & 2 deletions

File tree

CHANGES/12234.bugfix.rst

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,2 @@
1+
Fixed zstd decompression failing with ``ClientPayloadError`` when the server
2+
sends a response as multiple zstd frames -- by :user:`josu-moreno`.

CONTRIBUTORS.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -214,6 +214,7 @@ Jordan Borean
214214
Josep Cugat
215215
Josh Junon
216216
Joshu Coats
217+
Josu Moreno
217218
Julia Tsemusheva
218219
Julien Duponchelle
219220
Jungkook Park

aiohttp/compression_utils.py

Lines changed: 28 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -330,6 +330,7 @@ def __init__(
330330
"Please install `backports.zstd` module"
331331
)
332332
self._obj = ZstdDecompressor()
333+
self._pending_unused_data: bytes | None = None
333334
super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
334335

335336
def decompress_sync(
@@ -342,7 +343,33 @@ def decompress_sync(
342343
if max_length == ZLIB_MAX_LENGTH_UNLIMITED
343344
else max_length
344345
)
345-
return self._obj.decompress(data, zstd_max_length)
346+
if self._pending_unused_data is not None:
347+
data = self._pending_unused_data + data
348+
self._pending_unused_data = None
349+
result = self._obj.decompress(data, zstd_max_length)
350+
351+
# Handle multi-frame zstd streams.
352+
# https://datatracker.ietf.org/doc/html/rfc8878#section-3.1.1
353+
# ZstdDecompressor handles one frame only. When a frame ends,
354+
# eof becomes True and any trailing data goes to unused_data.
355+
# We create a fresh decompressor to continue with the next frame.
356+
while self._obj.eof and self._obj.unused_data:
357+
unused_data = self._obj.unused_data
358+
self._obj = ZstdDecompressor()
359+
if zstd_max_length != ZSTD_MAX_LENGTH_UNLIMITED:
360+
zstd_max_length -= len(result)
361+
if zstd_max_length <= 0:
362+
self._pending_unused_data = unused_data
363+
break
364+
result += self._obj.decompress(unused_data, zstd_max_length)
365+
366+
# Frame ended exactly at chunk boundary — no unused_data, but the
367+
# next feed_data() call would fail on the spent decompressor.
368+
# Prepare a fresh one for the next chunk.
369+
if self._obj.eof:
370+
self._obj = ZstdDecompressor()
371+
372+
return result
346373

347374
def flush(self) -> bytes:
348375
return b""

tests/test_compression_utils.py

Lines changed: 55 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,23 @@
11
"""Tests for compression utils."""
22

3+
import sys
4+
35
import pytest
46

5-
from aiohttp.compression_utils import ZLibBackend, ZLibCompressor, ZLibDecompressor
7+
from aiohttp.compression_utils import (
8+
ZLibBackend,
9+
ZLibCompressor,
10+
ZLibDecompressor,
11+
ZSTDDecompressor,
12+
)
13+
14+
try:
15+
if sys.version_info >= (3, 14):
16+
import compression.zstd as zstandard # noqa: I900
17+
else:
18+
import backports.zstd as zstandard
19+
except ImportError: # pragma: no cover
20+
zstandard = None # type: ignore[assignment]
621

722

823
@pytest.mark.usefixtures("parametrize_zlib_backend")
@@ -33,3 +48,42 @@ async def test_compression_round_trip_in_event_loop() -> None:
3348
compressed_data = await compressor.compress(data) + compressor.flush()
3449
decompressed_data = await decompressor.decompress(compressed_data)
3550
assert data == decompressed_data
51+
52+
53+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
54+
def test_zstd_multi_frame_unlimited() -> None:
55+
d = ZSTDDecompressor()
56+
frame1 = zstandard.compress(b"AAAA")
57+
frame2 = zstandard.compress(b"BBBB")
58+
result = d.decompress_sync(frame1 + frame2)
59+
assert result == b"AAAABBBB"
60+
61+
62+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
63+
def test_zstd_multi_frame_max_length_partial() -> None:
64+
d = ZSTDDecompressor()
65+
frame1 = zstandard.compress(b"AAAA")
66+
frame2 = zstandard.compress(b"BBBB")
67+
result = d.decompress_sync(frame1 + frame2, max_length=6)
68+
assert result == b"AAAABB"
69+
70+
71+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
72+
def test_zstd_multi_frame_max_length_exhausted() -> None:
73+
d = ZSTDDecompressor()
74+
frame1 = zstandard.compress(b"AAAA")
75+
frame2 = zstandard.compress(b"BBBB")
76+
result = d.decompress_sync(frame1 + frame2, max_length=4)
77+
assert result == b"AAAA"
78+
79+
80+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
81+
def test_zstd_multi_frame_max_length_exhausted_preserves_unused_data() -> None:
82+
d = ZSTDDecompressor()
83+
frame1 = zstandard.compress(b"AAAA")
84+
frame2 = zstandard.compress(b"BBBB")
85+
frame3 = zstandard.compress(b"CCCC")
86+
result1 = d.decompress_sync(frame1 + frame2, max_length=4)
87+
assert result1 == b"AAAA"
88+
result2 = d.decompress_sync(frame3)
89+
assert result2 == b"BBBBCCCC"

tests/test_http_parser.py

Lines changed: 73 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2081,6 +2081,79 @@ async def test_http_payload_zstandard(self, protocol: BaseProtocol) -> None:
20812081
assert b"zstd data" == out._buffer[0]
20822082
assert out.is_eof()
20832083

2084+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
2085+
async def test_http_payload_zstandard_multi_frame(
2086+
self, protocol: BaseProtocol
2087+
) -> None:
2088+
frame1 = zstandard.compress(b"first")
2089+
frame2 = zstandard.compress(b"second")
2090+
payload = frame1 + frame2
2091+
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
2092+
p = HttpPayloadParser(
2093+
out,
2094+
length=len(payload),
2095+
compression="zstd",
2096+
headers_parser=HeadersParser(),
2097+
)
2098+
p.feed_data(payload)
2099+
assert b"firstsecond" == b"".join(out._buffer)
2100+
assert out.is_eof()
2101+
2102+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
2103+
async def test_http_payload_zstandard_multi_frame_chunked(
2104+
self, protocol: BaseProtocol
2105+
) -> None:
2106+
frame1 = zstandard.compress(b"chunk1")
2107+
frame2 = zstandard.compress(b"chunk2")
2108+
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
2109+
p = HttpPayloadParser(
2110+
out,
2111+
length=len(frame1) + len(frame2),
2112+
compression="zstd",
2113+
headers_parser=HeadersParser(),
2114+
)
2115+
p.feed_data(frame1)
2116+
p.feed_data(frame2)
2117+
assert b"chunk1chunk2" == b"".join(out._buffer)
2118+
assert out.is_eof()
2119+
2120+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
2121+
async def test_http_payload_zstandard_frame_split_mid_chunk(
2122+
self, protocol: BaseProtocol
2123+
) -> None:
2124+
frame1 = zstandard.compress(b"AAAA")
2125+
frame2 = zstandard.compress(b"BBBB")
2126+
combined = frame1 + frame2
2127+
split_point = len(frame1) + 3 # 3 bytes into frame2
2128+
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
2129+
p = HttpPayloadParser(
2130+
out,
2131+
length=len(combined),
2132+
compression="zstd",
2133+
headers_parser=HeadersParser(),
2134+
)
2135+
p.feed_data(combined[:split_point])
2136+
p.feed_data(combined[split_point:])
2137+
assert b"AAAABBBB" == b"".join(out._buffer)
2138+
assert out.is_eof()
2139+
2140+
@pytest.mark.skipif(zstandard is None, reason="zstandard is not installed")
2141+
async def test_http_payload_zstandard_many_small_frames(
2142+
self, protocol: BaseProtocol
2143+
) -> None:
2144+
parts = [f"part{i}".encode() for i in range(10)]
2145+
payload = b"".join(zstandard.compress(p) for p in parts)
2146+
out = aiohttp.StreamReader(protocol, 2**16, loop=asyncio.get_running_loop())
2147+
p = HttpPayloadParser(
2148+
out,
2149+
length=len(payload),
2150+
compression="zstd",
2151+
headers_parser=HeadersParser(),
2152+
)
2153+
p.feed_data(payload)
2154+
assert b"".join(parts) == b"".join(out._buffer)
2155+
assert out.is_eof()
2156+
20842157

20852158
class TestDeflateBuffer:
20862159
async def test_feed_data(self, protocol: BaseProtocol) -> None:

0 commit comments

Comments
 (0)