Skip to content

Commit f8b1e8f

Browse files
Fix/ws heartbeat reset on data (#12030) - fix backport merge (#12074)
(cherry picked from commit a640f4f) Co-authored-by: Sam Bull <[email protected]>
1 parent dc89aec commit f8b1e8f

13 files changed

Lines changed: 403 additions & 20 deletions

CHANGES/12030.bugfix.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Reset the WebSocket heartbeat timer on inbound data to avoid false ping/pong timeouts while receiving large frames
2+
-- by :user:`hoffmang9`.

CONTRIBUTORS.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ Dmitry Trofimov
115115
Dmytro Bohomiakov
116116
Dmytro Kuznetsov
117117
Dustin J. Mitchell
118+
Earle Lowe
118119
Eduard Iskandarov
119120
Eli Ribble
120121
Elizabeth Leddy
@@ -139,6 +140,7 @@ Gabriel Tremblay
139140
Gang Ji
140141
Gary Leung
141142
Gary Wilson Jr.
143+
Gene Hoffman
142144
Gennady Andreyev
143145
Georges Dubus
144146
Greg Holt

aiohttp/client.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,9 +1269,6 @@ async def _ws_connect(
12691269
transport = conn.transport
12701270
assert transport is not None
12711271
reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop)
1272-
conn_proto.set_parser(
1273-
WebSocketReader(reader, max_msg_size, decode_text=decode_text), reader
1274-
)
12751272
writer = WebSocketWriter(
12761273
conn_proto,
12771274
transport,
@@ -1283,7 +1280,7 @@ async def _ws_connect(
12831280
resp.close()
12841281
raise
12851282
else:
1286-
return self._ws_response_class(
1283+
ws_resp = self._ws_response_class(
12871284
reader,
12881285
writer,
12891286
protocol,
@@ -1296,6 +1293,10 @@ async def _ws_connect(
12961293
compress=compress,
12971294
client_notakeover=notakeover,
12981295
)
1296+
parser = WebSocketReader(reader, max_msg_size, decode_text=decode_text)
1297+
cb = None if heartbeat is None else ws_resp._on_data_received
1298+
conn_proto.set_parser(parser, reader, data_received_cb=cb)
1299+
return ws_resp
12991300

13001301
def _prepare_headers(self, headers: LooseHeaders | None) -> "CIMultiDict[str]":
13011302
"""Add default headers and transform it to CIMultiDict"""

aiohttp/client_proto.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from contextlib import suppress
3-
from typing import Any
3+
from typing import Any, Callable
44

55
from .base_protocol import BaseProtocol
66
from .client_exceptions import (
@@ -34,6 +34,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
3434
self._payload: StreamReader | None = None
3535
self._skip_payload = False
3636
self._payload_parser = None
37+
self._data_received_cb: Callable[[], None] | None = None
3738

3839
self._timer = None
3940

@@ -203,14 +204,20 @@ def set_exception(
203204
self._drop_timeout()
204205
super().set_exception(exc, exc_cause)
205206

206-
def set_parser(self, parser: Any, payload: Any) -> None:
207+
def set_parser(
208+
self,
209+
parser: Any,
210+
payload: Any,
211+
data_received_cb: Callable[[], None] | None = None,
212+
) -> None:
207213
# TODO: actual types are:
208214
# parser: WebSocketReader
209215
# payload: WebSocketDataQueue
210216
# but they are not generi enough
211217
# Need an ABC for both types
212218
self._payload = payload
213219
self._payload_parser = parser
220+
self._data_received_cb = data_received_cb
214221

215222
self._drop_timeout()
216223

@@ -298,6 +305,8 @@ def data_received(self, data: bytes) -> None:
298305

299306
# custom payload parser - currently always WebSocketReader
300307
if self._payload_parser is not None:
308+
if self._data_received_cb is not None:
309+
self._data_received_cb()
301310
eof, tail = self._payload_parser.feed_data(data)
302311
if eof:
303312
self._payload = None

aiohttp/client_ws.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,17 @@ def __init__(
9898
self._compress = compress
9999
self._client_notakeover = client_notakeover
100100
self._ping_task: asyncio.Task[None] | None = None
101+
self._need_heartbeat_reset = False
102+
self._heartbeat_reset_handle: asyncio.Handle | None = None
101103

102104
self._reset_heartbeat()
103105

104106
def _cancel_heartbeat(self) -> None:
105107
self._cancel_pong_response_cb()
108+
if self._heartbeat_reset_handle is not None:
109+
self._heartbeat_reset_handle.cancel()
110+
self._heartbeat_reset_handle = None
111+
self._need_heartbeat_reset = False
106112
if self._heartbeat_cb is not None:
107113
self._heartbeat_cb.cancel()
108114
self._heartbeat_cb = None
@@ -115,6 +121,23 @@ def _cancel_pong_response_cb(self) -> None:
115121
self._pong_response_cb.cancel()
116122
self._pong_response_cb = None
117123

124+
def _on_data_received(self) -> None:
125+
if self._heartbeat is None or self._need_heartbeat_reset:
126+
return
127+
loop = self._loop
128+
assert loop is not None
129+
# Coalesce multiple chunks received in the same loop tick into a single
130+
# heartbeat reset. Resetting immediately per chunk increases timer churn.
131+
self._need_heartbeat_reset = True
132+
self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset)
133+
134+
def _flush_heartbeat_reset(self) -> None:
135+
self._heartbeat_reset_handle = None
136+
if not self._need_heartbeat_reset:
137+
return
138+
self._reset_heartbeat()
139+
self._need_heartbeat_reset = False
140+
118141
def _reset_heartbeat(self) -> None:
119142
if self._heartbeat is None:
120143
return
@@ -138,6 +161,12 @@ def _reset_heartbeat(self) -> None:
138161

139162
def _send_heartbeat(self) -> None:
140163
self._heartbeat_cb = None
164+
165+
# If heartbeat reset is pending (data is being received), skip sending
166+
# the ping and let the reset callback handle rescheduling the heartbeat.
167+
if self._need_heartbeat_reset:
168+
return
169+
141170
loop = self._loop
142171
now = loop.time()
143172
if now < self._heartbeat_when:
@@ -365,7 +394,6 @@ async def receive(
365394
msg = await self._reader.read()
366395
else:
367396
msg = await self._reader.read()
368-
self._reset_heartbeat()
369397
finally:
370398
self._waiting = False
371399
if self._close_wait:

aiohttp/web_protocol.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ class RequestHandler(BaseProtocol):
148148
"_task_handler",
149149
"_upgrade",
150150
"_payload_parser",
151+
"_data_received_cb",
151152
"_request_parser",
152153
"_reading_paused",
153154
"logger",
@@ -203,6 +204,7 @@ def __init__(
203204

204205
self._messages: deque[_MsgType] = deque()
205206
self._message_tail = b""
207+
self._data_received_cb: Callable[[], None] | None = None
206208

207209
self._waiter: asyncio.Future[None] | None = None
208210
self._handler_waiter: asyncio.Future[None] | None = None
@@ -373,11 +375,14 @@ def connection_lost(self, exc: BaseException | None) -> None:
373375
self._payload_parser.feed_eof()
374376
self._payload_parser = None
375377

376-
def set_parser(self, parser: Any) -> None:
378+
def set_parser(
379+
self, parser: Any, data_received_cb: Callable[[], None] | None = None
380+
) -> None:
377381
# Actual type is WebReader
378382
assert self._payload_parser is None
379383

380384
self._payload_parser = parser
385+
self._data_received_cb = data_received_cb
381386

382387
if self._message_tail:
383388
self._payload_parser.feed_data(self._message_tail)
@@ -421,6 +426,8 @@ def data_received(self, data: bytes) -> None:
421426

422427
# feed payload
423428
elif data:
429+
if self._data_received_cb is not None:
430+
self._data_received_cb()
424431
eof, tail = self._payload_parser.feed_data(data)
425432
if eof:
426433
self.close()

aiohttp/web_ws.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ class WebSocketResponse(StreamResponse, Generic[_DecodeText]):
9090
_heartbeat_cb: asyncio.TimerHandle | None = None
9191
_pong_response_cb: asyncio.TimerHandle | None = None
9292
_ping_task: asyncio.Task[None] | None = None
93+
_need_heartbeat_reset: bool = False
94+
_heartbeat_reset_handle: asyncio.Handle | None = None
9395

9496
def __init__(
9597
self,
@@ -118,9 +120,15 @@ def __init__(
118120
self._max_msg_size = max_msg_size
119121
self._writer_limit = writer_limit
120122
self._decode_text = decode_text
123+
self._need_heartbeat_reset = False
124+
self._heartbeat_reset_handle = None
121125

122126
def _cancel_heartbeat(self) -> None:
123127
self._cancel_pong_response_cb()
128+
if self._heartbeat_reset_handle is not None:
129+
self._heartbeat_reset_handle.cancel()
130+
self._heartbeat_reset_handle = None
131+
self._need_heartbeat_reset = False
124132
if self._heartbeat_cb is not None:
125133
self._heartbeat_cb.cancel()
126134
self._heartbeat_cb = None
@@ -133,6 +141,23 @@ def _cancel_pong_response_cb(self) -> None:
133141
self._pong_response_cb.cancel()
134142
self._pong_response_cb = None
135143

144+
def _on_data_received(self) -> None:
145+
if self._heartbeat is None or self._need_heartbeat_reset:
146+
return
147+
loop = self._loop
148+
assert loop is not None
149+
# Coalesce multiple chunks received in the same loop tick into a single
150+
# heartbeat reset. Resetting immediately per chunk increases timer churn.
151+
self._need_heartbeat_reset = True
152+
self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset)
153+
154+
def _flush_heartbeat_reset(self) -> None:
155+
self._heartbeat_reset_handle = None
156+
if not self._need_heartbeat_reset:
157+
return
158+
self._reset_heartbeat()
159+
self._need_heartbeat_reset = False
160+
136161
def _reset_heartbeat(self) -> None:
137162
if self._heartbeat is None:
138163
return
@@ -156,6 +181,12 @@ def _reset_heartbeat(self) -> None:
156181

157182
def _send_heartbeat(self) -> None:
158183
self._heartbeat_cb = None
184+
185+
# If heartbeat reset is pending (data is being received), skip sending
186+
# the ping and let the reset callback handle rescheduling the heartbeat.
187+
if self._need_heartbeat_reset:
188+
return
189+
159190
loop = self._loop
160191
assert loop is not None and self._writer is not None
161192
now = loop.time()
@@ -349,14 +380,14 @@ def _post_start(
349380
loop = self._loop
350381
assert loop is not None
351382
self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop)
352-
request.protocol.set_parser(
353-
WebSocketReader(
354-
self._reader,
355-
self._max_msg_size,
356-
compress=bool(self._compress),
357-
decode_text=self._decode_text,
358-
)
383+
parser = WebSocketReader(
384+
self._reader,
385+
self._max_msg_size,
386+
compress=bool(self._compress),
387+
decode_text=self._decode_text,
359388
)
389+
cb = None if self._heartbeat is None else self._on_data_received
390+
request.protocol.set_parser(parser, data_received_cb=cb)
360391
# disable HTTP keepalive for WebSocket
361392
request.protocol.keep_alive(False)
362393

@@ -576,7 +607,6 @@ async def receive(
576607
msg = await self._reader.read()
577608
else:
578609
msg = await self._reader.read()
579-
self._reset_heartbeat()
580610
finally:
581611
self._waiting = False
582612
if self._close_wait:

docs/client_reference.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -777,8 +777,9 @@ The client session supports the context manager protocol for self closing.
777777
:param float heartbeat: Send *ping* message every *heartbeat*
778778
seconds and wait *pong* response, if
779779
*pong* response is not received then
780-
close connection. The timer is reset on any data
781-
reception.(optional)
780+
close connection. The timer is reset on any
781+
inbound data reception (coalesced per event loop
782+
iteration). (optional)
782783

783784
:param str origin: Origin header to send to server(optional)
784785

docs/web_reference.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -999,7 +999,8 @@ and :ref:`aiohttp-web-signals` handlers::
999999
:param float heartbeat: Send `ping` message every `heartbeat`
10001000
seconds and wait `pong` response, close
10011001
connection if `pong` response is not
1002-
received. The timer is reset on any data reception.
1002+
received. The timer is reset on any inbound data
1003+
reception (coalesced per event loop iteration).
10031004

10041005
:param float timeout: Timeout value for the ``close``
10051006
operation. After sending the close websocket message,

0 commit comments

Comments
 (0)