Skip to content

Commit ceeca6a

Browse files
authored
Add support for switching the zlib implementation (#10700)
1 parent 6db713e commit ceeca6a

30 files changed

+632
-88
lines changed

CHANGES/9798.feature.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Added support for setting the zlib compression backend -- by :user:`TimMenninger`.
2+
3+
This change allows the user to call :func:`aiohttp.set_zlib_backend()` with the
4+
zlib-compatible compression module of their choice. The default behavior continues to use
5+
the standard-library ``zlib`` module.

aiohttp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
WSServerHandshakeError,
4848
request,
4949
)
50+
from .compression_utils import set_zlib_backend
5051
from .connector import AddrInfoType, SocketFactoryType
5152
from .cookiejar import CookieJar, DummyCookieJar
5253
from .formdata import FormData
@@ -165,6 +166,7 @@
165166
"BasicAuth",
166167
"ChainMapProxy",
167168
"ETag",
169+
"set_zlib_backend",
168170
# http
169171
"HttpVersion",
170172
"HttpVersion10",

aiohttp/_websocket/reader_py.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,15 +243,23 @@ def _feed_data(self, data: bytes) -> None:
243243
self._decompressobj = ZLibDecompressor(
244244
suppress_deflate_header=True
245245
)
246+
# XXX: It's possible that the zlib backend (isal is known to
247+
# do this, maybe others too?) will return max_length bytes,
248+
# but internally buffer more data such that the payload is
249+
# >max_length, so we request one extra byte and if we're able
250+
# to do that, then the message is too big.
246251
payload_merged = self._decompressobj.decompress_sync(
247-
assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size
252+
assembled_payload + WS_DEFLATE_TRAILING,
253+
(
254+
self._max_msg_size + 1
255+
if self._max_msg_size
256+
else self._max_msg_size
257+
),
248258
)
249-
if self._decompressobj.unconsumed_tail:
250-
left = len(self._decompressobj.unconsumed_tail)
259+
if self._max_msg_size and len(payload_merged) > self._max_msg_size:
251260
raise WebSocketError(
252261
WSCloseCode.MESSAGE_TOO_BIG,
253-
f"Decompressed message size {self._max_msg_size + left}"
254-
f" exceeds limit {self._max_msg_size}",
262+
f"Decompressed message exceeds size limit {self._max_msg_size}",
255263
)
256264
elif type(assembled_payload) is bytes:
257265
payload_merged = assembled_payload

aiohttp/_websocket/writer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
import asyncio
44
import random
5-
import zlib
65
from functools import partial
76
from typing import Any, Final, Optional, Union
87

98
from ..base_protocol import BaseProtocol
109
from ..client_exceptions import ClientConnectionResetError
11-
from ..compression_utils import ZLibCompressor
10+
from ..compression_utils import ZLibBackend, ZLibCompressor
1211
from .helpers import (
1312
MASK_LEN,
1413
MSG_SIZE,
@@ -95,7 +94,9 @@ async def send_frame(
9594
message = (
9695
await compressobj.compress(message)
9796
+ compressobj.flush(
98-
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
97+
ZLibBackend.Z_FULL_FLUSH
98+
if self.notakeover
99+
else ZLibBackend.Z_SYNC_FLUSH
99100
)
100101
).removesuffix(WS_DEFLATE_TRAILING)
101102
# It's critical that we do not return control to the event
@@ -160,7 +161,7 @@ async def send_frame(
160161

161162
def _make_compress_obj(self, compress: int) -> ZLibCompressor:
162163
return ZLibCompressor(
163-
level=zlib.Z_BEST_SPEED,
164+
level=ZLibBackend.Z_BEST_SPEED,
164165
wbits=-compress,
165166
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
166167
)

aiohttp/abc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
import socket
3-
import zlib
43
from abc import ABC, abstractmethod
54
from collections.abc import Sized
65
from http.cookies import BaseCookie, Morsel
@@ -217,7 +216,7 @@ async def drain(self) -> None:
217216

218217
@abstractmethod
219218
def enable_compression(
220-
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
219+
self, encoding: str = "deflate", strategy: Optional[int] = None
221220
) -> None:
222221
"""Enable HTTP body compression"""
223222

aiohttp/compression_utils.py

Lines changed: 118 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33
import zlib
44
from concurrent.futures import Executor
5-
from typing import Optional, cast
5+
from typing import Any, Final, Optional, Protocol, TypedDict, cast
66

77
if sys.version_info >= (3, 12):
88
from collections.abc import Buffer
@@ -24,14 +24,113 @@
2424
MAX_SYNC_CHUNK_SIZE = 1024
2525

2626

27+
class ZLibCompressObjProtocol(Protocol):
28+
def compress(self, data: Buffer) -> bytes: ...
29+
def flush(self, mode: int = ..., /) -> bytes: ...
30+
31+
32+
class ZLibDecompressObjProtocol(Protocol):
33+
def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
34+
def flush(self, length: int = ..., /) -> bytes: ...
35+
36+
@property
37+
def eof(self) -> bool: ...
38+
39+
40+
class ZLibBackendProtocol(Protocol):
41+
MAX_WBITS: int
42+
Z_FULL_FLUSH: int
43+
Z_SYNC_FLUSH: int
44+
Z_BEST_SPEED: int
45+
Z_FINISH: int
46+
47+
def compressobj(
48+
self,
49+
level: int = ...,
50+
method: int = ...,
51+
wbits: int = ...,
52+
memLevel: int = ...,
53+
strategy: int = ...,
54+
zdict: Optional[Buffer] = ...,
55+
) -> ZLibCompressObjProtocol: ...
56+
def decompressobj(
57+
self, wbits: int = ..., zdict: Buffer = ...
58+
) -> ZLibDecompressObjProtocol: ...
59+
60+
def compress(
61+
self, data: Buffer, /, level: int = ..., wbits: int = ...
62+
) -> bytes: ...
63+
def decompress(
64+
self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
65+
) -> bytes: ...
66+
67+
68+
class CompressObjArgs(TypedDict, total=False):
69+
wbits: int
70+
strategy: int
71+
level: int
72+
73+
74+
class ZLibBackendWrapper:
75+
def __init__(self, _zlib_backend: ZLibBackendProtocol):
76+
self._zlib_backend: ZLibBackendProtocol = _zlib_backend
77+
78+
@property
79+
def name(self) -> str:
80+
return getattr(self._zlib_backend, "__name__", "undefined")
81+
82+
@property
83+
def MAX_WBITS(self) -> int:
84+
return self._zlib_backend.MAX_WBITS
85+
86+
@property
87+
def Z_FULL_FLUSH(self) -> int:
88+
return self._zlib_backend.Z_FULL_FLUSH
89+
90+
@property
91+
def Z_SYNC_FLUSH(self) -> int:
92+
return self._zlib_backend.Z_SYNC_FLUSH
93+
94+
@property
95+
def Z_BEST_SPEED(self) -> int:
96+
return self._zlib_backend.Z_BEST_SPEED
97+
98+
@property
99+
def Z_FINISH(self) -> int:
100+
return self._zlib_backend.Z_FINISH
101+
102+
def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
103+
return self._zlib_backend.compressobj(*args, **kwargs)
104+
105+
def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
106+
return self._zlib_backend.decompressobj(*args, **kwargs)
107+
108+
def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
109+
return self._zlib_backend.compress(data, *args, **kwargs)
110+
111+
def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
112+
return self._zlib_backend.decompress(data, *args, **kwargs)
113+
114+
# Everything not explicitly listed in the Protocol we just pass through
115+
def __getattr__(self, attrname: str) -> Any:
116+
return getattr(self._zlib_backend, attrname)
117+
118+
119+
ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)
120+
121+
122+
def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
123+
ZLibBackend._zlib_backend = new_zlib_backend
124+
125+
27126
def encoding_to_mode(
28127
encoding: Optional[str] = None,
29128
suppress_deflate_header: bool = False,
30129
) -> int:
31130
if encoding == "gzip":
32-
return 16 + zlib.MAX_WBITS
131+
return 16 + ZLibBackend.MAX_WBITS
33132

34-
return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS
133+
return -ZLibBackend.MAX_WBITS if suppress_deflate_header else ZLibBackend.MAX_WBITS
35134

36135

37136
class ZlibBaseHandler:
@@ -53,7 +152,7 @@ def __init__(
53152
suppress_deflate_header: bool = False,
54153
level: Optional[int] = None,
55154
wbits: Optional[int] = None,
56-
strategy: int = zlib.Z_DEFAULT_STRATEGY,
155+
strategy: Optional[int] = None,
57156
executor: Optional[Executor] = None,
58157
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
59158
):
@@ -66,12 +165,15 @@ def __init__(
66165
executor=executor,
67166
max_sync_chunk_size=max_sync_chunk_size,
68167
)
69-
if level is None:
70-
self._compressor = zlib.compressobj(wbits=self._mode, strategy=strategy)
71-
else:
72-
self._compressor = zlib.compressobj(
73-
wbits=self._mode, strategy=strategy, level=level
74-
)
168+
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
169+
170+
kwargs: CompressObjArgs = {}
171+
kwargs["wbits"] = self._mode
172+
if strategy is not None:
173+
kwargs["strategy"] = strategy
174+
if level is not None:
175+
kwargs["level"] = level
176+
self._compressor = self._zlib_backend.compressobj(**kwargs)
75177
self._compress_lock = asyncio.Lock()
76178

77179
def compress_sync(self, data: Buffer) -> bytes:
@@ -100,8 +202,10 @@ async def compress(self, data: Buffer) -> bytes:
100202
)
101203
return self.compress_sync(data)
102204

103-
def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
104-
return self._compressor.flush(mode)
205+
def flush(self, mode: Optional[int] = None) -> bytes:
206+
return self._compressor.flush(
207+
mode if mode is not None else self._zlib_backend.Z_FINISH
208+
)
105209

106210

107211
class ZLibDecompressor(ZlibBaseHandler):
@@ -117,7 +221,8 @@ def __init__(
117221
executor=executor,
118222
max_sync_chunk_size=max_sync_chunk_size,
119223
)
120-
self._decompressor = zlib.decompressobj(wbits=self._mode)
224+
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
225+
self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)
121226

122227
def decompress_sync(self, data: Buffer, max_length: int = 0) -> bytes:
123228
return self._decompressor.decompress(data, max_length)
@@ -149,14 +254,6 @@ def flush(self, length: int = 0) -> bytes:
149254
def eof(self) -> bool:
150255
return self._decompressor.eof
151256

152-
@property
153-
def unconsumed_tail(self) -> bytes:
154-
return self._decompressor.unconsumed_tail
155-
156-
@property
157-
def unused_data(self) -> bytes:
158-
return self._decompressor.unused_data
159-
160257

161258
class BrotliDecompressor:
162259
# Supports both 'brotlipy' and 'Brotli' packages

aiohttp/http_writer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import asyncio
44
import sys
5-
import zlib
65
from typing import ( # noqa
76
Any,
87
Awaitable,
@@ -85,7 +84,7 @@ def enable_chunking(self) -> None:
8584
self.chunked = True
8685

8786
def enable_compression(
88-
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
87+
self, encoding: str = "deflate", strategy: Optional[int] = None
8988
) -> None:
9089
self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)
9190

aiohttp/multipart.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import sys
66
import uuid
77
import warnings
8-
import zlib
98
from collections import deque
109
from types import TracebackType
1110
from typing import (
@@ -1032,7 +1031,7 @@ def enable_encoding(self, encoding: str) -> None:
10321031
self._encoding = "quoted-printable"
10331032

10341033
def enable_compression(
1035-
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
1034+
self, encoding: str = "deflate", strategy: Optional[int] = None
10361035
) -> None:
10371036
self._compress = ZLibCompressor(
10381037
encoding=encoding,

aiohttp/web_response.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import math
77
import time
88
import warnings
9-
import zlib
109
from concurrent.futures import Executor
1110
from http import HTTPStatus
1211
from typing import (
@@ -83,7 +82,7 @@ class StreamResponse(BaseClass, HeadersMixin, CookieMixin):
8382
_keep_alive: Optional[bool] = None
8483
_chunked: bool = False
8584
_compression: bool = False
86-
_compression_strategy: int = zlib.Z_DEFAULT_STRATEGY
85+
_compression_strategy: Optional[int] = None
8786
_compression_force: Optional[ContentCoding] = None
8887
_req: Optional["BaseRequest"] = None
8988
_payload_writer: Optional[AbstractStreamWriter] = None
@@ -184,7 +183,7 @@ def enable_chunked_encoding(self) -> None:
184183
def enable_compression(
185184
self,
186185
force: Optional[ContentCoding] = None,
187-
strategy: int = zlib.Z_DEFAULT_STRATEGY,
186+
strategy: Optional[int] = None,
188187
) -> None:
189188
"""Enables response compression encoding."""
190189
self._compression = True

docs/client_reference.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2145,6 +2145,30 @@ Utilities
21452145

21462146
.. versionadded:: 3.0
21472147

2148+
.. function:: set_zlib_backend(lib)
2149+
2150+
Sets the compression backend for zlib-based operations.
2151+
2152+
This function allows you to override the default zlib backend
2153+
used internally by passing a module that implements the standard
2154+
compression interface.
2155+
2156+
The module should implement, at a minimum, the interface offered by the
2157+
latest version of zlib.
2158+
2159+
:param types.ModuleType lib: A module that implements the zlib-compatible compression API.
2160+
2161+
Example usage::
2162+
2163+
import zlib_ng.zlib_ng as zng
2164+
import aiohttp
2165+
2166+
aiohttp.set_zlib_backend(zng)
2167+
2168+
.. note:: aiohttp has been tested internally with :mod:`zlib`, :mod:`zlib_ng.zlib_ng`, and :mod:`isal.isal_zlib`.
2169+
2170+
.. versionadded:: 3.12
2171+
21482172
FormData
21492173
^^^^^^^^
21502174

docs/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@
8484
"aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None),
8585
"aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None),
8686
"aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/latest/", None),
87+
"isal": ("https://python-isal.readthedocs.io/en/stable/", None),
88+
"zlib_ng": ("https://python-zlib-ng.readthedocs.io/en/stable/", None),
8789
}
8890

8991
# Add any paths that contain templates here, relative to this directory.

0 commit comments

Comments
 (0)