11import asyncio
22import sys
33import zlib
4+ from abc import ABC , abstractmethod
45from concurrent .futures import Executor
56from typing import Any , Final , Optional , Protocol , TypedDict , cast
67
3233 HAS_ZSTD = False
3334
3435
# Payloads at or below this size are decompressed inline on the event
# loop; larger ones are handed to an executor.
MAX_SYNC_CHUNK_SIZE = 4096

# Default cap on decompressed output size: 32 MiB.
DEFAULT_MAX_DECOMPRESS_SIZE = 2**25

# "Unlimited output" sentinels differ per backend:
# zlib's decompress() treats 0 as unlimited, zstd's treats -1 as unlimited.
ZLIB_MAX_LENGTH_UNLIMITED = 0
ZSTD_MAX_LENGTH_UNLIMITED = -1
3642
3743
3844class ZLibCompressObjProtocol (Protocol ):
@@ -144,19 +150,37 @@ def encoding_to_mode(
144150 return - ZLibBackend .MAX_WBITS if suppress_deflate_header else ZLibBackend .MAX_WBITS
145151
146152
class DecompressionBaseHandler(ABC):
    """Common machinery for the decompressor classes.

    Holds the executor configuration and implements the async
    ``decompress`` wrapper; subclasses provide the synchronous
    ``decompress_sync`` implementation.
    """

    def __init__(
        self,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        """Base class for decompression handlers.

        :param executor: executor used for large payloads
            (``None`` selects the event loop's default executor).
        :param max_sync_chunk_size: payloads strictly larger than this are
            decompressed in the executor; ``None`` disables offloading.
        """
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size

    @abstractmethod
    def decompress_sync(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress *data* synchronously.

        ``max_length`` follows the zlib convention: 0 means unlimited output.
        """

    async def decompress(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress *data*.

        If the payload exceeds ``max_sync_chunk_size`` the work is done in
        the executor; otherwise it runs inline on the event loop.
        """
        if (
            self._max_sync_chunk_size is not None
            and len(data) > self._max_sync_chunk_size
        ):
            # Use get_running_loop(): this coroutine only executes inside a
            # running loop, and get_event_loop() is deprecated in that
            # context (and slower).
            return await asyncio.get_running_loop().run_in_executor(
                self._executor, self.decompress_sync, data, max_length
            )
        return self.decompress_sync(data, max_length)
181+
158182
159- class ZLibCompressor ( ZlibBaseHandler ) :
183+ class ZLibCompressor :
160184 def __init__ (
161185 self ,
162186 encoding : Optional [str ] = None ,
@@ -167,14 +191,12 @@ def __init__(
167191 executor : Optional [Executor ] = None ,
168192 max_sync_chunk_size : Optional [int ] = MAX_SYNC_CHUNK_SIZE ,
169193 ):
170- super ().__init__ (
171- mode = (
172- encoding_to_mode (encoding , suppress_deflate_header )
173- if wbits is None
174- else wbits
175- ),
176- executor = executor ,
177- max_sync_chunk_size = max_sync_chunk_size ,
194+ self ._executor = executor
195+ self ._max_sync_chunk_size = max_sync_chunk_size
196+ self ._mode = (
197+ encoding_to_mode (encoding , suppress_deflate_header )
198+ if wbits is None
199+ else wbits
178200 )
179201 self ._zlib_backend : Final = ZLibBackendWrapper (ZLibBackend ._zlib_backend )
180202
@@ -233,41 +255,24 @@ def flush(self, mode: Optional[int] = None) -> bytes:
233255 )
234256
235257
class ZLibDecompressor(DecompressionBaseHandler):
    """zlib-based decompressor built on the shared async base handler."""

    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        """Create a decompressor for the given content *encoding*.

        ``encoding_to_mode`` presumably maps the content encoding (and the
        deflate-header suppression flag) onto zlib ``wbits`` — confirm
        against its definition.
        """
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
        wbits = encoding_to_mode(encoding, suppress_deflate_header)
        self._mode = wbits
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
        self._decompressor = self._zlib_backend.decompressobj(wbits=wbits)

    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Inflate *data* synchronously; ``max_length`` of 0 means unlimited."""
        return self._decompressor.decompress(data, max_length)
271276 def flush (self , length : int = 0 ) -> bytes :
272277 return (
273278 self ._decompressor .flush (length )
@@ -280,40 +285,64 @@ def eof(self) -> bool:
280285 return self ._decompressor .eof
281286
282287
class BrotliDecompressor(DecompressionBaseHandler):
    # Works with both the 'brotlipy' and 'Brotli' packages, which share the
    # `brotli` import name: brotlipy objects expose decompress()/flush(),
    # while Brotli objects expose process() and have no flush().
    def __init__(
        self,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        """Decompress data using the Brotli library."""
        if not HAS_BROTLI:
            raise RuntimeError(
                "The brotli decompression is not available. "
                "Please install `Brotli` module"
            )
        self._obj = brotli.Decompressor()
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)

    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data."""
        # NOTE(review): max_length is forwarded positionally to the brotli
        # object's decompress()/process() — assumes the installed brotli
        # binding accepts an output-limit argument; verify against the
        # package version in use.
        decompressor = self._obj
        if hasattr(decompressor, "decompress"):
            return cast(bytes, decompressor.decompress(data, max_length))
        return cast(bytes, decompressor.process(data, max_length))

    def flush(self) -> bytes:
        """Flush the decompressor (no-op for bindings without flush())."""
        obj = self._obj
        return cast(bytes, obj.flush()) if hasattr(obj, "flush") else b""
304319
305320
class ZSTDDecompressor(DecompressionBaseHandler):
    """Zstandard decompressor built on the shared async base handler."""

    def __init__(
        self,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        if not HAS_ZSTD:
            raise RuntimeError(
                "The zstd decompression is not available. "
                "Please install `backports.zstd` module"
            )
        self._obj = ZstdDecompressor()
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)

    def decompress_sync(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress *data*, translating the output-limit sentinel.

        Callers use the zlib convention (0 = unlimited); zstd expects -1
        for unlimited, so convert before delegating.
        """
        if max_length == ZLIB_MAX_LENGTH_UNLIMITED:
            return self._obj.decompress(data, ZSTD_MAX_LENGTH_UNLIMITED)
        return self._obj.decompress(data, max_length)

    def flush(self) -> bytes:
        """zstd streams need no explicit flush; always returns b""."""
        return b""
0 commit comments