Skip to content

Commit 968adea

Browse files
committed
add max_depth limit and improve types
1 parent bc4228f commit 968adea

6 files changed

Lines changed: 264 additions & 55 deletions

File tree

README.md

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,58 @@ result, _ = unpack(packed, ext_hook=ext_hook)
139139
assert pt == result
140140
```
141141

142+
## Depth limits
143+
144+
Use `max_depth` to reject excessively nested payloads during packing or
145+
unpacking. If the nesting limit is exceeded, a `RecursionError` is raised.
146+
147+
`max_depth` counts the root object as one level, so scalar roots work with
148+
`max_depth=1`, while each level of container nesting needs one more (for example, `[1]` needs `max_depth=2`). The default value
149+
is `-1`; any negative value disables the depth limit, and `max_depth=0` rejects every input.
150+
151+
Even with `max_depth` disabled, extremely deep payloads can hit Python's
152+
built-in recursion limit. This limit can be temporarily raised:
153+
154+
```python
155+
import sys
156+
from contextlib import contextmanager
157+
158+
from msgpack_streams import pack, unpack
159+
160+
161+
@contextmanager
162+
def recursion_limit(limit: int):
163+
previous_limit = sys.getrecursionlimit()
164+
sys.setrecursionlimit(limit)
165+
try:
166+
yield
167+
finally:
168+
sys.setrecursionlimit(previous_limit)
169+
170+
171+
data = [[[{"key": "value"}]]]
172+
173+
with recursion_limit(10_000):
174+
packed = pack(data, max_depth=9_000)
175+
unpacked, excess_data = unpack(packed, max_depth=9_000)
176+
177+
assert unpacked == data
178+
assert not excess_data
179+
```
180+
181+
Use this carefully. A recursion limit higher than the platform stack can support
182+
may crash the interpreter outright instead of raising `RecursionError`.
183+
142184
## API reference
143185

144186
```python
145-
def pack(obj: object, *, float32: bool = False, ext_hook: Callable[[object], ExtType | None] | None = None) -> bytes:
187+
def pack(
188+
obj: object,
189+
*,
190+
float32: bool = False,
191+
ext_hook: Callable[[object], ExtType | None] | None = None,
192+
max_depth: int = -1,
193+
) -> bytes:
146194
...
147195
```
148196

@@ -153,10 +201,18 @@ Pass `ext_hook` to handle types that are not natively supported. The callback
153201
receives the unsupported object and should return an `ExtType` to pack in its
154202
place. If it returns `None` a `TypeError` is raised as normal.
155203

204+
Pass `max_depth` to limit container nesting during encoding. If the limit is
205+
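exceeded, a `RecursionError` is raised. The default `-1` disables the limit. An object converted by `ext_hook` consumes one extra level, because the hook's result is packed one level deeper.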
206+
156207
---
157208

158209
```python
159-
def unpack(data: bytes, *, ext_hook: Callable[[ExtType], object | None] | None = None) -> tuple[object, bytes]:
210+
def unpack(
211+
data: bytes,
212+
*,
213+
ext_hook: Callable[[ExtType], object | None] | None = None,
214+
max_depth: int = -1,
215+
) -> tuple[object, bytes]:
160216
...
161217
```
162218

@@ -167,10 +223,20 @@ Pass `ext_hook` to convert `ExtType` values during decoding. The callback
167223
receives each `ExtType` and should return the decoded object, or `None` to leave
168224
it as an `ExtType`.
169225

226+
Pass `max_depth` to limit container nesting during decoding. If the limit is
227+
exceeded, a `RecursionError` is raised. The default `-1` disables the limit.
228+
170229
---
171230

172231
```python
173-
def pack_stream(stream: BinaryIO, obj: object, *, float32: bool = False, ext_hook: Callable[[object], ExtType | None] | None = None) -> None:
232+
def pack_stream(
233+
stream: BinaryIO,
234+
obj: object,
235+
*,
236+
float32: bool = False,
237+
ext_hook: Callable[[object], ExtType | None] | None = None,
238+
max_depth: int = -1,
239+
) -> None:
174240
...
175241
```
176242

@@ -181,10 +247,18 @@ Pass `ext_hook` to handle types that are not natively supported. The callback
181247
receives the unsupported object and should return an `ExtType` to pack in its
182248
place. If it returns `None` a `TypeError` is raised as normal.
183249

250+
Pass `max_depth` to limit container nesting during encoding. If the limit is
251+
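exceeded, a `RecursionError` is raised. The default `-1` disables the limit. An object converted by `ext_hook` consumes one extra level, because the hook's result is packed one level deeper.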
252+
184253
---
185254

186255
```python
187-
def unpack_stream(stream: BinaryIO, *, ext_hook: Callable[[ExtType], object] | None = None) -> object:
256+
def unpack_stream(
257+
stream: BinaryIO,
258+
*,
259+
ext_hook: Callable[[ExtType], object | None] | None = None,
260+
max_depth: int = -1,
261+
) -> object:
188262
...
189263
```
190264

@@ -194,3 +268,6 @@ position past the consumed bytes.
194268
Pass `ext_hook` to convert `ExtType` values during decoding. The callback
195269
receives each `ExtType` and should return the decoded object, or `None` to leave
196270
it as an `ExtType`.
271+
272+
Pass `max_depth` to limit container nesting during decoding. If the limit is
273+
exceeded, a `RecursionError` is raised. The default `-1` disables the limit.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "msgpack-streams"
3-
version = "1.0.0"
3+
version = "1.1.0"
44
description = "Fast stream based implementation of msgpack in pure Python"
55
readme = "README.md"
66
classifiers = [

src/msgpack_streams/_io.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@
55
from typing import Any
66

77
from ._ext import ExtType
8-
from ._msgpack import pack_stream, unpack_stream
8+
from ._msgpack import _pack_stream, _unpack_stream
99

1010

1111
def pack(
1212
obj: object,
1313
*,
1414
float32: bool = False,
1515
ext_hook: Callable[[object], ExtType | Any] | None = None,
16+
max_depth: int = -1,
1617
) -> bytes:
1718
"""Pack object into data."""
1819
with io.BytesIO() as stream:
19-
pack_stream(stream, obj, float32=float32, ext_hook=ext_hook)
20+
_pack_stream(stream, obj, float32, ext_hook, max_depth)
2021
data = stream.getvalue()
2122
return data
2223

@@ -25,9 +26,10 @@ def unpack(
2526
data: bytes,
2627
*,
2728
ext_hook: Callable[[ExtType], object] | None = None,
29+
max_depth: int = -1,
2830
) -> tuple[object, bytes]:
29-
"""Unpack data into object."""
31+
"""Unpack object from data."""
3032
with io.BytesIO(data) as stream:
31-
obj = unpack_stream(stream, ext_hook=ext_hook)
33+
obj = _unpack_stream(stream, ext_hook, max_depth)
3234
excess_data = stream.read()
3335
return obj, excess_data

src/msgpack_streams/_msgpack.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,27 @@ def pack_stream(
7474
*,
7575
float32: bool = False,
7676
ext_hook: Callable[[object], ExtType | None] | None = None,
77+
max_depth: int = -1,
7778
) -> None:
79+
"""Pack object into stream."""
80+
_pack_stream(stream, obj, float32, ext_hook, max_depth)
81+
82+
83+
def _pack_stream(
84+
stream: BinaryIO,
85+
obj: object,
86+
float32: bool = False,
87+
ext_hook: Callable[[object], ExtType | None] | None = None,
88+
max_depth: int = -1,
89+
/, # perf: avoid kwargs overhead in recursive calls
90+
) -> None:
91+
if max_depth == 0:
92+
raise RecursionError("max depth exceeded")
93+
max_depth -= 1
94+
7895
_type = type(obj)
7996
if _type is int: # int
80-
i = obj
97+
i: int = obj # type: ignore
8198
if 0 <= i <= 0x7F: # positive fixint
8299
stream.write(_B[i])
83100
elif -32 <= i < 0: # negative fixint
@@ -114,7 +131,7 @@ def pack_stream(
114131
elif _type is bool: # true / false
115132
stream.write(b"\xc3" if obj else b"\xc2")
116133
elif _type is str: # str
117-
s = obj.encode("utf-8")
134+
s: bytes = obj.encode("utf-8") # type: ignore
118135
sl = len(s)
119136
if sl <= 0x1F: # fixstr
120137
stream.write(_B[0xA0 | sl])
@@ -128,7 +145,7 @@ def pack_stream(
128145
raise ValueError("str too large", obj)
129146
stream.write(s)
130147
elif _type is bytes: # bin
131-
bl = len(obj)
148+
bl = len(obj) # type: ignore
132149
if bl <= 0xFF: # bin8
133150
stream.write(b"\xc4" + _B[bl])
134151
elif bl <= 0xFF_FF: # bin16
@@ -137,9 +154,9 @@ def pack_stream(
137154
stream.write(b"\xc6" + u32_b_pack(bl))
138155
else:
139156
raise ValueError("bin too large", obj)
140-
stream.write(obj)
157+
stream.write(obj) # type: ignore
141158
elif _type is dict: # map
142-
ml = len(obj)
159+
ml = len(obj) # type: ignore
143160
if ml <= 0x0F: # fixmap
144161
stream.write(_B[0x80 | ml])
145162
elif ml <= 0xFF_FF: # map16
@@ -148,11 +165,11 @@ def pack_stream(
148165
stream.write(b"\xdf" + u32_b_pack(ml))
149166
else:
150167
raise ValueError("map too large", obj)
151-
for k, v in obj.items():
152-
pack_stream(stream, k, float32=float32, ext_hook=ext_hook)
153-
pack_stream(stream, v, float32=float32, ext_hook=ext_hook)
168+
for k, v in obj.items(): # type: ignore
169+
_pack_stream(stream, k, float32, ext_hook, max_depth)
170+
_pack_stream(stream, v, float32, ext_hook, max_depth)
154171
elif _type is list: # array
155-
al = len(obj)
172+
al = len(obj) # type: ignore
156173
if al <= 0x0F: # fixarray
157174
stream.write(_B[0x90 | al])
158175
elif al <= 0xFF_FF: # array16
@@ -161,13 +178,13 @@ def pack_stream(
161178
stream.write(b"\xdd" + u32_b_pack(al))
162179
else:
163180
raise ValueError("array too large", obj)
164-
for v in obj:
165-
pack_stream(stream, v, float32=float32, ext_hook=ext_hook)
181+
for v in obj: # type: ignore
182+
_pack_stream(stream, v, float32, ext_hook, max_depth)
166183
elif _type is datetime: # timestamp
167-
if obj.tzinfo is None:
184+
if obj.tzinfo is None: # type: ignore
168185
raise ValueError("datetime object must be timezone-aware", obj)
169186

170-
seconds_from_epoch = obj.timestamp()
187+
seconds_from_epoch = obj.timestamp() # type: ignore
171188
# floor rather than int (handles negative timestamps correctly)
172189
seconds = floor(seconds_from_epoch)
173190
nanoseconds = int((seconds_from_epoch - seconds) * 1_000_000_000)
@@ -184,8 +201,8 @@ def pack_stream(
184201
# timestamp96
185202
stream.write(b"\xc7\x0c\xff" + u32_b_pack(nanoseconds) + s64_b_pack(seconds))
186203
elif _type is ExtType: # ext
187-
data = obj.data
188-
p_code = s8_b_pack(obj.code)
204+
data: bytes = obj.data # type: ignore
205+
p_code = s8_b_pack(obj.code) # type: ignore
189206
extl = len(data)
190207
if extl <= 16 and extl in _PO2: # fixext (0xD4 - 0xD8)
191208
stream.write(_PO2[extl] + p_code)
@@ -204,7 +221,7 @@ def pack_stream(
204221
if result is not None:
205222
# pack the ext type (doesn't exactly need to be an ExtType)
206223
# if the same type is returned it will cause infinite recursion
207-
pack_stream(stream, result, float32=float32, ext_hook=ext_hook)
224+
_pack_stream(stream, result, float32, ext_hook, max_depth)
208225
return
209226
raise TypeError("unsupported type", _type, obj)
210227

@@ -213,7 +230,22 @@ def unpack_stream(
213230
stream: BinaryIO,
214231
*,
215232
ext_hook: Callable[[ExtType], object | None] | None = None,
233+
max_depth: int = -1,
216234
) -> object:
235+
"""Unpack object from stream."""
236+
return _unpack_stream(stream, ext_hook, max_depth)
237+
238+
239+
def _unpack_stream(
240+
stream: BinaryIO,
241+
ext_hook: Callable[[ExtType], object | None] | None = None,
242+
max_depth: int = -1,
243+
/, # perf: avoid kwargs overhead in recursive calls
244+
) -> object:
245+
if max_depth == 0:
246+
raise RecursionError("max depth exceeded")
247+
max_depth -= 1
248+
217249
b = stream.read(1)
218250
first_byte = b[0]
219251
if first_byte <= 0x7F: # positive fixint
@@ -223,12 +255,12 @@ def unpack_stream(
223255
elif first_byte <= 0x8F: # fixmap
224256
ml = first_byte & 0x0F
225257
obj = {
226-
unpack_stream(stream, ext_hook=ext_hook): unpack_stream(stream, ext_hook=ext_hook)
258+
_unpack_stream(stream, ext_hook, max_depth): _unpack_stream(stream, ext_hook, max_depth)
227259
for _ in range(ml)
228260
}
229261
elif first_byte <= 0x9F: # fixarray
230262
al = first_byte & 0x0F
231-
obj = [unpack_stream(stream, ext_hook=ext_hook) for _ in range(al)]
263+
obj = [_unpack_stream(stream, ext_hook, max_depth) for _ in range(al)]
232264
elif first_byte <= 0xBF: # fixstr
233265
sl = first_byte & 0x1F
234266
obj = stream.read(sl).decode("utf-8")
@@ -337,14 +369,14 @@ def unpack_stream(
337369
al = u16_b_unpack(stream) # array16
338370
else:
339371
al = u32_b_unpack(stream) # array32
340-
obj = [unpack_stream(stream, ext_hook=ext_hook) for _ in range(al)]
372+
obj = [_unpack_stream(stream, ext_hook, max_depth) for _ in range(al)]
341373
elif first_byte <= 0xDF: # map
342374
if first_byte == 0xDE:
343375
ml = u16_b_unpack(stream) # map16
344376
else:
345377
ml = u32_b_unpack(stream) # map32
346378
obj = {
347-
unpack_stream(stream, ext_hook=ext_hook): unpack_stream(stream, ext_hook=ext_hook)
379+
_unpack_stream(stream, ext_hook, max_depth): _unpack_stream(stream, ext_hook, max_depth)
348380
for _ in range(ml)
349381
}
350382
else:

0 commit comments

Comments
 (0)