From 49249861e956356c38c4eebd74e00afe9b40df50 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 24 Oct 2025 07:59:26 +0200 Subject: [PATCH 01/13] Rename variables for clarity. ..._waiter is internal asyncio terminology that I don't find intuitive. --- docs/topics/keepalive.rst | 4 +- src/websockets/asyncio/connection.py | 74 ++++++++++++++-------------- src/websockets/sync/connection.py | 34 ++++++------- tests/asyncio/test_connection.py | 42 ++++++++-------- tests/sync/test_connection.py | 24 ++++----- 5 files changed, 89 insertions(+), 89 deletions(-) diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index e63c2f8f5..fd8300183 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -136,8 +136,8 @@ measured during the last exchange of Ping and Pong frames:: Alternatively, you can measure the latency at any time by calling :attr:`~asyncio.connection.Connection.ping` and awaiting its result:: - pong_waiter = await websocket.ping() - latency = await pong_waiter + pong_received = await websocket.ping() + latency = await pong_received Latency between a client and a server may increase for two reasons: diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 592480f91..3540648cd 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -101,10 +101,10 @@ def __init__( self.close_deadline: float | None = None # Protect sending fragmented messages. - self.fragmented_send_waiter: asyncio.Future[None] | None = None + self.send_in_progress: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + self.pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} self.latency: float = 0 """ @@ -468,8 +468,8 @@ async def send( """ # While sending a fragmented message, prevent sending other messages # until all fragments are sent. - while self.fragmented_send_waiter is not None: - await asyncio.shield(self.fragmented_send_waiter) + while self.send_in_progress is not None: + await asyncio.shield(self.send_in_progress) # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. @@ -502,8 +502,8 @@ async def send( except StopIteration: return - assert self.fragmented_send_waiter is None - self.fragmented_send_waiter = self.loop.create_future() + assert self.send_in_progress is None + self.send_in_progress = self.loop.create_future() try: # First fragment. if isinstance(chunk, str): @@ -549,8 +549,8 @@ async def send( raise finally: - self.fragmented_send_waiter.set_result(None) - self.fragmented_send_waiter = None + self.send_in_progress.set_result(None) + self.send_in_progress = None # Fragmented message -- async iterator. @@ -561,8 +561,8 @@ async def send( except StopAsyncIteration: return - assert self.fragmented_send_waiter is None - self.fragmented_send_waiter = self.loop.create_future() + assert self.send_in_progress is None + self.send_in_progress = self.loop.create_future() try: # First fragment. if isinstance(chunk, str): @@ -610,8 +610,8 @@ async def send( raise finally: - self.fragmented_send_waiter.set_result(None) - self.fragmented_send_waiter = None + self.send_in_progress.set_result(None) + self.send_in_progress = None else: raise TypeError("data must be str, bytes, iterable, or async iterable") @@ -639,7 +639,7 @@ async def close( # The context manager takes care of waiting for the TCP connection # to terminate after calling a method that sends a close frame. async with self.send_context(): - if self.fragmented_send_waiter is not None: + if self.send_in_progress is not None: self.protocol.fail( CloseCode.INTERNAL_ERROR, "close during fragmented message", @@ -681,9 +681,9 @@ async def ping(self, data: DataLike | None = None) -> Awaitable[float]: :: - pong_waiter = await ws.ping() + pong_received = await ws.ping() # only if you want to wait for the corresponding pong - latency = await pong_waiter + latency = await pong_received Raises: ConnectionClosed: When the connection is closed. @@ -700,19 +700,19 @@ async def ping(self, data: DataLike | None = None) -> Awaitable[float]: async with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pong_waiters: + if data in self.pending_pings: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pong_waiters: + while data is None or data in self.pending_pings: data = struct.pack("!I", random.getrandbits(32)) - pong_waiter = self.loop.create_future() + pong_received = self.loop.create_future() # The event loop's default clock is time.monotonic(). Its resolution # is a bit low on Windows (~16ms). This is improved in Python 3.13. - self.pong_waiters[data] = (pong_waiter, self.loop.time()) + self.pending_pings[data] = (pong_received, self.loop.time()) self.protocol.send_ping(data) - return pong_waiter + return pong_received async def pong(self, data: DataLike = b"") -> None: """ @@ -761,7 +761,7 @@ def acknowledge_pings(self, data: bytes) -> None: """ # Ignore unsolicited pong. - if data not in self.pong_waiters: + if data not in self.pending_pings: return pong_timestamp = self.loop.time() @@ -770,22 +770,22 @@ def acknowledge_pings(self, data: bytes) -> None: # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + for ping_id, (pong_received, ping_timestamp) in self.pending_pings.items(): ping_ids.append(ping_id) latency = pong_timestamp - ping_timestamp - if not pong_waiter.done(): - pong_waiter.set_result(latency) + if not pong_received.done(): + pong_received.set_result(latency) if ping_id == data: self.latency = latency break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pong_waiters. + # Remove acknowledged pings from self.pending_pings. for ping_id in ping_ids: - del self.pong_waiters[ping_id] + del self.pending_pings[ping_id] - def abort_pings(self) -> None: + def terminate_pending_pings(self) -> None: """ Raise ConnectionClosed in pending pings. @@ -795,16 +795,16 @@ def abort_pings(self) -> None: assert self.protocol.state is CLOSED exc = self.protocol.close_exc - for pong_waiter, _ping_timestamp in self.pong_waiters.values(): - if not pong_waiter.done(): - pong_waiter.set_exception(exc) + for pong_received, _ping_timestamp in self.pending_pings.values(): + if not pong_received.done(): + pong_received.set_exception(exc) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. # Given that ping is done (with an exception), canceling it does # nothing, but it prevents logging the exception. - pong_waiter.cancel() + pong_received.cancel() - self.pong_waiters.clear() + self.pending_pings.clear() async def keepalive(self) -> None: """ @@ -825,7 +825,7 @@ async def keepalive(self) -> None: # connection to be closed before raising ConnectionClosed. # However, connection_lost() cancels keepalive_task before # it gets a chance to resume excuting. - pong_waiter = await self.ping() + pong_received = await self.ping() if self.debug: self.logger.debug("% sent keepalive ping") @@ -834,9 +834,9 @@ async def keepalive(self) -> None: async with asyncio_timeout(self.ping_timeout): # connection_lost cancels keepalive immediately # after setting a ConnectionClosed exception on - # pong_waiter. A CancelledError is raised here, + # pong_received. A CancelledError is raised here, # not a ConnectionClosed exception. - latency = await pong_waiter + latency = await pong_received self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: @@ -1022,7 +1022,7 @@ def connection_lost(self, exc: Exception | None) -> None: # Abort recv() and pending pings with a ConnectionClosed exception. self.recv_messages.close() - self.abort_pings() + self.terminate_pending_pings() if self.keepalive_task is not None: self.keepalive_task.cancel() @@ -1205,7 +1205,7 @@ def broadcast( if connection.protocol.state is not OPEN: continue - if connection.fragmented_send_waiter is not None: + if connection.send_in_progress is not None: if raise_exceptions: exception = ConcurrencyError("sending a fragmented message") exceptions.append(exception) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 6ef1ef039..351679b90 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -104,7 +104,7 @@ def __init__( self.send_in_progress = False # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {} + self.pending_pings: dict[bytes, tuple[threading.Event, float, bool]] = {} self.latency: float = 0 """ @@ -651,17 +651,17 @@ def ping( with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pong_waiters: + if data in self.pending_pings: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pong_waiters: + while data is None or data in self.pending_pings: data = struct.pack("!I", random.getrandbits(32)) - pong_waiter = threading.Event() - self.pong_waiters[data] = (pong_waiter, time.monotonic(), ack_on_close) + pong_received = threading.Event() + self.pending_pings[data] = (pong_received, time.monotonic(), ack_on_close) self.protocol.send_ping(data) - return pong_waiter + return pong_received def pong(self, data: DataLike = b"") -> None: """ @@ -711,7 +711,7 @@ def acknowledge_pings(self, data: bytes) -> None: """ with self.protocol_mutex: # Ignore unsolicited pong. - if data not in self.pong_waiters: + if data not in self.pending_pings: return pong_timestamp = time.monotonic() @@ -721,21 +721,21 @@ def acknowledge_pings(self, data: bytes) -> None: ping_id = None ping_ids = [] for ping_id, ( - pong_waiter, + pong_received, ping_timestamp, _ack_on_close, - ) in self.pong_waiters.items(): + ) in self.pending_pings.items(): ping_ids.append(ping_id) - pong_waiter.set() + pong_received.set() if ping_id == data: self.latency = pong_timestamp - ping_timestamp break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pong_waiters. + # Remove acknowledged pings from self.pending_pings. for ping_id in ping_ids: - del self.pong_waiters[ping_id] + del self.pending_pings[ping_id] def acknowledge_pending_pings(self) -> None: """ @@ -744,11 +744,11 @@ def acknowledge_pending_pings(self) -> None: """ assert self.protocol.state is CLOSED - for pong_waiter, _ping_timestamp, ack_on_close in self.pong_waiters.values(): + for pong_received, _ping_timestamp, ack_on_close in self.pending_pings.values(): if ack_on_close: - pong_waiter.set() + pong_received.set() - self.pong_waiters.clear() + self.pending_pings.clear() def keepalive(self) -> None: """ @@ -766,7 +766,7 @@ def keepalive(self) -> None: break try: - pong_waiter = self.ping(ack_on_close=True) + pong_received = self.ping(ack_on_close=True) except ConnectionClosed: break if self.debug: @@ -774,7 +774,7 @@ def keepalive(self) -> None: if self.ping_timeout is not None: # - if pong_waiter.wait(self.ping_timeout): + if pong_received.wait(self.ping_timeout): if self.debug: self.logger.debug("% received keepalive pong") else: diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 39fc953dc..92dbf5392 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -554,7 +554,7 @@ async def test_send_connection_closed_error(self): async def test_send_while_send_blocked(self): """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed + # This test fails if the guard with send_in_progress is removed # from send() in the case when message is an Iterable. self.connection.pause_writing() asyncio.create_task(self.connection.send(["⏳", "⌛️"])) @@ -579,7 +579,7 @@ async def test_send_while_send_blocked(self): async def test_send_while_send_async_blocked(self): """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed + # This test fails if the guard with send_in_progress is removed # from send() in the case when message is an AsyncIterable. self.connection.pause_writing() @@ -609,7 +609,7 @@ async def fragments(): async def test_send_during_send_async(self): """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed + # This test fails if the guard with send_in_progress is removed # from send() in the case when message is an AsyncIterable. gate = asyncio.get_running_loop().create_future() @@ -884,54 +884,54 @@ async def test_ping_explicit_binary(self): async def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.remote_connection.pong("this") async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_canceled_ping(self): """ping is acknowledged by a pong with the same payload after being canceled.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") - pong_waiter.cancel() + pong_received = await self.connection.ping("this") + pong_received.cancel() await self.remote_connection.pong("this") with self.assertRaises(asyncio.CancelledError): - await pong_waiter + await pong_received async def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.remote_connection.pong("that") with self.assertRaises(TimeoutError): async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong for a later ping.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.connection.ping("that") await self.remote_connection.pong("that") async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_previous_canceled_ping(self): """ping is acknowledged by a pong for a later ping after being canceled.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") - pong_waiter_2 = await self.connection.ping("that") - pong_waiter.cancel() + pong_received = await self.connection.ping("this") + pong_received_2 = await self.connection.ping("that") + pong_received.cancel() await self.remote_connection.pong("that") async with asyncio_timeout(MS): - await pong_waiter_2 + await pong_received_2 with self.assertRaises(asyncio.CancelledError): - await pong_waiter + await pong_received async def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("idem") + pong_received = await self.connection.ping("idem") with self.assertRaises(ConcurrencyError) as raised: await self.connection.ping("idem") @@ -942,7 +942,7 @@ async def test_ping_duplicate_payload(self): await self.remote_connection.pong("idem") async with asyncio_timeout(MS): - await pong_waiter + await pong_received await self.connection.ping("idem") # doesn't raise an exception @@ -1060,9 +1060,9 @@ async def test_keepalive_reports_errors(self): await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for 1 ms. # 3 ms: inject a fault: raise an exception in the pending pong waiter. - pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] + pong_received = next(iter(self.connection.pending_pings.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: - pong_waiter.set_exception(Exception("BOOM")) + pong_received.set_exception(Exception("BOOM")) await asyncio.sleep(0) self.assertEqual( [record.getMessage() for record in logs.records], diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 5558b662c..b81102079 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -628,38 +628,38 @@ def test_ping_explicit_binary(self): def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.remote_connection.pong("this") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.remote_connection.pong("that") - self.assertFalse(pong_waiter.wait(MS)) + self.assertFalse(pong_received.wait(MS)) def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong for as a later ping.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.connection.ping("that") self.remote_connection.pong("that") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) def test_acknowledge_ping_on_close(self): """ping with ack_on_close is acknowledged when the connection is closed.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter_ack_on_close = self.connection.ping("this", ack_on_close=True) - pong_waiter = self.connection.ping("that") + pong_received_aoc = self.connection.ping("this", ack_on_close=True) + pong_received = self.connection.ping("that") self.connection.close() - self.assertTrue(pong_waiter_ack_on_close.wait(MS)) - self.assertFalse(pong_waiter.wait(MS)) + self.assertTrue(pong_received_aoc.wait(MS)) + self.assertFalse(pong_received.wait(MS)) def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("idem") + pong_received = self.connection.ping("idem") with self.assertRaises(ConcurrencyError) as raised: self.connection.ping("idem") @@ -669,7 +669,7 @@ def test_ping_duplicate_payload(self): ) self.remote_connection.pong("idem") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) self.connection.ping("idem") # doesn't raise an exception From 759b2a2d9247bc5d3eec069a5cee7ef4d343cc87 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 24 Oct 2025 08:03:59 +0200 Subject: [PATCH 02/13] Clarify which argument is high or low limit. --- src/websockets/asyncio/connection.py | 18 +++++++++++------- src/websockets/sync/connection.py | 8 ++++---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 3540648cd..bf733adb9 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -63,14 +63,14 @@ def __init__( self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout - self.max_queue: tuple[int | None, int | None] if isinstance(max_queue, int) or max_queue is None: - self.max_queue = (max_queue, None) + self.max_queue_high, self.max_queue_low = max_queue, None else: - self.max_queue = max_queue + self.max_queue_high, self.max_queue_low = max_queue if isinstance(write_limit, int): - write_limit = (write_limit, None) - self.write_limit = write_limit + self.write_limit_high, self.write_limit_low = write_limit, None + else: + self.write_limit_high, self.write_limit_low = write_limit # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -1005,11 +1005,15 @@ def set_recv_exc(self, exc: BaseException | None) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None: transport = cast(asyncio.Transport, transport) self.recv_messages = Assembler( - *self.max_queue, + self.max_queue_high, + self.max_queue_low, pause=transport.pause_reading, resume=transport.resume_reading, ) - transport.set_write_buffer_limits(*self.write_limit) + transport.set_write_buffer_limits( + self.write_limit_high, + self.write_limit_low, + ) self.transport = transport def connection_lost(self, exc: Exception | None) -> None: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 351679b90..aab584f58 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -59,11 +59,10 @@ def __init__( self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.close_timeout = close_timeout - self.max_queue: tuple[int | None, int | None] if isinstance(max_queue, int) or max_queue is None: - self.max_queue = (max_queue, None) + max_queue_high, max_queue_low = max_queue, None else: - self.max_queue = max_queue + max_queue_high, max_queue_low = max_queue # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -92,7 +91,8 @@ def __init__( # Assembler turning frames into messages and serializing reads. self.recv_messages = Assembler( - *self.max_queue, + max_queue_high, + max_queue_low, pause=self.recv_flow_control.acquire, resume=self.recv_flow_control.release, ) From 2715b35c5c65f40d6fed68e8b80c4dff96c1e67d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Nov 2025 08:35:45 +0100 Subject: [PATCH 03/13] Align the asyncio and sync connection modules. --- src/websockets/asyncio/connection.py | 88 +++++++++++---------- src/websockets/sync/connection.py | 111 +++++++++++++-------------- 2 files changed, 101 insertions(+), 98 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index bf733adb9..205a2be50 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -100,19 +100,19 @@ def __init__( # Deadline for the closing handshake. self.close_deadline: float | None = None - # Protect sending fragmented messages. + # Whether we are busy sending a fragmented message. self.send_in_progress: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. self.pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} - self.latency: float = 0 + self.latency: float = 0.0 """ Latency of the connection, in seconds. Latency is defined as the round-trip time of the connection. It is measured by sending a Ping frame and waiting for a matching Pong frame. - Before the first measurement, :attr:`latency` is ``0``. + Before the first measurement, :attr:`latency` is ``0.0``. By default, websockets enables a :ref:`keepalive ` mechanism that sends Ping frames automatically at regular intervals. You can also @@ -130,7 +130,7 @@ def __init__( # connection state becomes CLOSED. self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() - # Adapted from asyncio.FlowControlMixin + # Adapted from asyncio.FlowControlMixin. self.paused: bool = False self.drain_waiters: collections.deque[asyncio.Future[None]] = ( collections.deque() @@ -291,9 +291,9 @@ async def recv(self, decode: bool | None = None) -> Data: return a bytestring (:class:`bytes`). This improves performance when decoding isn't needed, for example if the message contains JSON and you're using a JSON library that expects a bytestring. - * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames - and return a string (:class:`str`). This may be useful for - servers that send binary frames instead of text frames. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and + return strings (:class:`str`). This may be useful for servers that + send binary frames instead of text frames. Raises: ConnectionClosed: When the connection is closed. @@ -363,12 +363,12 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data You may override this behavior with the ``decode`` argument: - * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames - and return bytestrings (:class:`bytes`). This may be useful to - optimize performance when decoding isn't needed. - * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames - and return strings (:class:`str`). This is useful for servers - that send binary frames instead of text frames. + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + yield bytestrings (:class:`bytes`). This improves performance + when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and + yield strings (:class:`str`). This may be useful for servers that + send binary frames instead of text frames. Raises: ConnectionClosed: When the connection is closed. @@ -417,8 +417,8 @@ async def send( You may override this behavior with the ``text`` argument: - * Set ``text=True`` to send a bytestring or bytes-like object - (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + * Set ``text=True`` to send an UTF-8 bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) in a Text_ frame. This improves performance when the message is already UTF-8 encoded, for example if the message contains JSON and you're using a JSON library that produces a bytestring. @@ -426,7 +426,7 @@ async def send( frame. This may be useful for servers that expect binary frames instead of text frames. - :meth:`send` also accepts an iterable or an asynchronous iterable of + :meth:`send` also accepts an iterable or asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. All items must be of the same type, or else :meth:`send` will raise a @@ -441,8 +441,8 @@ async def send( Canceling :meth:`send` is discouraged. Instead, you should close the connection with :meth:`close`. Indeed, there are only two situations where :meth:`send` may yield control to the event loop and then get - canceled; in both cases, :meth:`close` has the same effect and is - more clear: + canceled; in both cases, :meth:`close` has the same effect and the + effect is more obvious: 1. The write buffer is full. If you don't want to wait until enough data is sent, your only alternative is to close the connection. @@ -708,9 +708,10 @@ async def ping(self, data: DataLike | None = None) -> Awaitable[float]: data = struct.pack("!I", random.getrandbits(32)) pong_received = self.loop.create_future() + ping_timestamp = self.loop.time() # The event loop's default clock is time.monotonic(). Its resolution # is a bit low on Windows (~16ms). This is improved in Python 3.13. - self.pending_pings[data] = (pong_received, self.loop.time()) + self.pending_pings[data] = (pong_received, ping_timestamp) self.protocol.send_ping(data) return pong_received @@ -787,9 +788,7 @@ def acknowledge_pings(self, data: bytes) -> None: def terminate_pending_pings(self) -> None: """ - Raise ConnectionClosed in pending pings. - - They'll never receive a pong once the connection is closed. + Raise ConnectionClosed in pending pings when the connection is closed. """ assert self.protocol.state is CLOSED @@ -837,7 +836,8 @@ async def keepalive(self) -> None: # pong_received. A CancelledError is raised here, # not a ConnectionClosed exception. latency = await pong_received - self.logger.debug("% received keepalive pong") + if self.debug: + self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: self.logger.debug("- timed out waiting for keepalive pong") @@ -908,20 +908,22 @@ async def send_context( # Check if the connection is expected to close soon. if self.protocol.close_expected(): wait_for_close = True - # If the connection is expected to close soon, set the - # close deadline based on the close timeout. - # Since we tested earlier that protocol.state was OPEN + # Set the close deadline based on the close timeout. + # Since we tested earlier that protocol.state is OPEN # (or CONNECTING), self.close_deadline is still None. + assert self.close_deadline is None if self.close_timeout is not None: - assert self.close_deadline is None self.close_deadline = self.loop.time() + self.close_timeout - # Write outgoing data to the socket and enforce flow control. + # Write outgoing data to the socket with flow control. try: self.send_data() await self.drain() except Exception as exc: if self.debug: - self.logger.debug("! error while sending data", exc_info=True) + self.logger.debug( + "! error while sending data", + exc_info=True, + ) # While the only expected exception here is OSError, # other exceptions would be treated identically. wait_for_close = False @@ -933,8 +935,8 @@ async def send_context( # will be closing soon if it isn't in the expected state. wait_for_close = True # Calculate close_deadline if it wasn't set yet. - if self.close_timeout is not None: - if self.close_deadline is None: + if self.close_deadline is None: + if self.close_timeout is not None: self.close_deadline = self.loop.time() + self.close_timeout raise_close_exc = True @@ -945,7 +947,7 @@ async def send_context( async with asyncio_timeout_at(self.close_deadline): await asyncio.shield(self.connection_lost_waiter) except TimeoutError: - # There's no risk to overwrite another error because + # There's no risk of overwriting another error because # original_exc is never set when wait_for_close is True. assert original_exc is None original_exc = TimeoutError("timed out while closing connection") @@ -966,9 +968,6 @@ def send_data(self) -> None: """ Send outgoing data. - Raises: - OSError: When a socket operations fails. - """ for data in self.protocol.data_to_send(): if data: @@ -982,7 +981,7 @@ def send_data(self) -> None: # OSError is plausible. uvloop can raise RuntimeError here. try: self.transport.write_eof() - except (OSError, RuntimeError): # pragma: no cover + except Exception: # pragma: no cover pass # Else, close the TCP connection. else: # pragma: no cover @@ -994,6 +993,8 @@ def set_recv_exc(self, exc: BaseException | None) -> None: """ Set recv_exc, if not set yet. + This method must be called only from connection callbacks. + """ if self.recv_exc is None: self.recv_exc = exc @@ -1096,26 +1097,29 @@ def data_received(self, data: bytes) -> None: self.logger.debug("! error while sending data", exc_info=True) self.set_recv_exc(exc) + # If needed, set the close deadline based on the close timeout. if self.protocol.close_expected(): - # If the connection is expected to close soon, set the - # close deadline based on the close timeout. - if self.close_timeout is not None: - if self.close_deadline is None: + if self.close_deadline is None: + if self.close_timeout is not None: self.close_deadline = self.loop.time() + self.close_timeout + # If self.send_data raised an exception, then events are lost. + # Given that automatic responses write small amounts of data, + # this should be uncommon, so we don't handle the edge case. + for event in events: # This isn't expected to raise an exception. self.process_event(event) def eof_received(self) -> None: - # Feed the end of the data stream to the connection. + # Feed the end of the data stream to the protocol. self.protocol.receive_eof() # This isn't expected to raise an exception. events = self.protocol.events_received() # There is no error handling because send_data() can only write - # the end of the data stream here and it shouldn't raise errors. + # the end of the data stream and it handles errors by itself. self.send_data() # This code path is triggered when receiving an HTTP response diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index aab584f58..052a8fef4 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -106,13 +106,13 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.pending_pings: dict[bytes, tuple[threading.Event, float, bool]] = {} - self.latency: float = 0 + self.latency: float = 0.0 """ Latency of the connection, in seconds. Latency is defined as the round-trip time of the connection. It is measured by sending a Ping frame and waiting for a matching Pong frame. - Before the first measurement, :attr:`latency` is ``0``. + Before the first measurement, :attr:`latency` is ``0.0``. By default, websockets enables a :ref:`keepalive ` mechanism that sends Ping frames automatically at regular intervals. You can also @@ -122,8 +122,8 @@ def __init__( # Thread that sends keepalive pings. None when ping_interval is None. self.keepalive_thread: threading.Thread | None = None - # Exception raised in recv_events, to be chained to ConnectionClosed - # in the user thread in order to show why the TCP connection dropped. + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. self.recv_exc: BaseException | None = None # Receiving events from the socket. This thread is marked as daemon to @@ -284,8 +284,8 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data is ``0`` or negative, check if a message has been received already and return it, else raise :exc:`TimeoutError`. - If the message is fragmented, wait until all fragments are received, - reassemble them, and return the whole message. + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. Args: timeout: Timeout for receiving a message in seconds. @@ -305,9 +305,9 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data return a bytestring (:class:`bytes`). This improves performance when decoding isn't needed, for example if the message contains JSON and you're using a JSON library that expects a bytestring. - * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames - and return a string (:class:`str`). This may be useful for - servers that send binary frames instead of text frames. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and + return strings (:class:`str`). This may be useful for servers that + send binary frames instead of text frames. Raises: ConnectionClosed: When the connection is closed. @@ -372,12 +372,12 @@ def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: You may override this behavior with the ``decode`` argument: - * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames - and return bytestrings (:class:`bytes`). This may be useful to - optimize performance when decoding isn't needed. - * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames - and return strings (:class:`str`). This is useful for servers - that send binary frames instead of text frames. + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + yield bytestrings (:class:`bytes`). This improves performance + when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames and + yield strings (:class:`str`). This may be useful for servers that + send binary frames instead of text frames. Raises: ConnectionClosed: When the connection is closed. @@ -425,8 +425,8 @@ def send( You may override this behavior with the ``text`` argument: - * Set ``text=True`` to send a bytestring or bytes-like object - (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + * Set ``text=True`` to send an UTF-8 bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) in a Text_ frame. This improves performance when the message is already UTF-8 encoded, for example if the message contains JSON and you're using a JSON library that produces a bytestring. @@ -530,7 +530,7 @@ def send( self.protocol.send_binary(chunk, fin=False) encode = False else: - raise TypeError("data iterable must contain bytes or str") + raise TypeError("iterable must contain bytes or str") # Other fragments for chunk in chunks: @@ -543,7 +543,7 @@ def send( assert self.send_in_progress self.protocol.send_continuation(chunk, fin=False) else: - raise TypeError("data iterable must contain uniform types") + raise TypeError("iterable must contain uniform types") # Final fragment. with self.send_context(): @@ -576,9 +576,8 @@ def close( """ Perform the closing handshake. - :meth:`close` waits for the other end to complete the handshake, for the - TCP connection to terminate, and for all incoming messages to be read - with :meth:`recv`. + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. :meth:`close` is idempotent: it doesn't do anything once the connection is closed. @@ -633,8 +632,9 @@ def ping( :: - pong_event = ws.ping() - pong_event.wait() # only if you want to wait for the pong + pong_received = ws.ping() + # only if you want to wait for the corresponding pong + pong_received.wait() Raises: ConnectionClosed: When the connection is closed. @@ -659,7 +659,8 @@ def ping( data = struct.pack("!I", random.getrandbits(32)) pong_received = threading.Event() - self.pending_pings[data] = (pong_received, time.monotonic(), ack_on_close) + ping_timestamp = time.monotonic() + self.pending_pings[data] = (pong_received, ping_timestamp, ack_on_close) self.protocol.send_ping(data) return pong_received @@ -737,7 +738,7 @@ def acknowledge_pings(self, data: bytes) -> None: for ping_id in ping_ids: del self.pending_pings[ping_id] - def acknowledge_pending_pings(self) -> None: + def terminate_pending_pings(self) -> None: """ Acknowledge pending pings when the connection is closed. @@ -773,7 +774,6 @@ def keepalive(self) -> None: self.logger.debug("% sent keepalive ping") if self.ping_timeout is not None: - # if pong_received.wait(self.ping_timeout): if self.debug: self.logger.debug("% received keepalive pong") @@ -808,15 +808,17 @@ def recv_events(self) -> None: Run this method in a thread as long as the connection is alive. - ``recv_events()`` exits immediately when the ``self.socket`` is closed. + ``recv_events()`` exits immediately when ``self.socket`` is closed. """ try: while True: try: + # If the assembler buffer is full, block until it drains. with self.recv_flow_control: - if self.close_deadline is not None: - self.socket.settimeout(self.close_deadline.timeout()) + pass + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: @@ -859,9 +861,8 @@ def recv_events(self) -> None: self.set_recv_exc(exc) break + # If needed, set the close deadline based on the close timeout. if self.protocol.close_expected(): - # If the connection is expected to close soon, set the - # close deadline based on the close timeout. if self.close_deadline is None: self.close_deadline = Deadline(self.close_timeout) @@ -878,6 +879,7 @@ def recv_events(self) -> None: # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. + with self.protocol_mutex: # Feed the end of the data stream to the protocol. self.protocol.receive_eof() @@ -886,7 +888,7 @@ def recv_events(self) -> None: events = self.protocol.events_received() # There is no error handling because send_data() can only write - # the end of the data stream here and it handles errors itself. + # the end of the data stream and it handles errors by itself. self.send_data() # This code path is triggered when receiving an HTTP response @@ -918,7 +920,7 @@ def send_context( On entry, :meth:`send_context` acquires the connection lock and checks that the connection is open; on exit, it writes outgoing data to the - socket:: + socket and releases the connection lock:: with self.send_context(): self.protocol.send_text(message.encode()) @@ -957,11 +959,10 @@ def send_context( # Check if the connection is expected to close soon. if self.protocol.close_expected(): wait_for_close = True - # If the connection is expected to close soon, set the - # close deadline based on the close timeout. - # Since we tested earlier that protocol.state was OPEN + # Set the close deadline based on the close timeout. + # Since we tested earlier that protocol.state is OPEN # (or CONNECTING) and we didn't release protocol_mutex, - # it is certain that self.close_deadline is still None. + # self.close_deadline is still None. assert self.close_deadline is None self.close_deadline = Deadline(self.close_timeout) # Write outgoing data to the socket. @@ -983,6 +984,9 @@ def send_context( # Minor layering violation: we assume that the connection # will be closing soon if it isn't in the expected state. wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_deadline is None: + self.close_deadline = Deadline(self.close_timeout) raise_close_exc = True # To avoid a deadlock, release the connection lock by exiting the @@ -991,15 +995,12 @@ def send_context( # If the connection is expected to close soon and the close timeout # elapses, close the socket to terminate the connection. if wait_for_close: - if self.close_deadline is None: - timeout = self.close_timeout - else: - # Thread.join() returns immediately if timeout is negative. - timeout = self.close_deadline.timeout(raise_if_elapsed=False) + # Thread.join() returns immediately if timeout is negative. + assert self.close_deadline is not None + timeout = self.close_deadline.timeout(raise_if_elapsed=False) self.recv_events_thread.join(timeout) - if self.recv_events_thread.is_alive(): - # There's no risk to overwrite another error because + # There's no risk of overwriting another error because # original_exc is never set when wait_for_close is True. assert original_exc is None original_exc = TimeoutError("timed out while closing connection") @@ -1023,9 +1024,6 @@ def send_data(self) -> None: This method requires holding protocol_mutex. - Raises: - OSError: When a socket operations fails. - """ assert self.protocol_mutex.locked() for data in self.protocol.data_to_send(): @@ -1043,11 +1041,12 @@ def set_recv_exc(self, exc: BaseException | None) -> None: """ Set recv_exc, if not set yet. - This method requires holding protocol_mutex. + This method requires holding protocol_mutex and must be called only from + the thread running recv_events(). """ assert self.protocol_mutex.locked() - if self.recv_exc is None: # pragma: no branch + if self.recv_exc is None: self.recv_exc = exc def close_socket(self) -> None: @@ -1061,8 +1060,8 @@ def close_socket(self) -> None: # shutdown() is required to interrupt recv() on Linux. try: self.socket.shutdown(socket.SHUT_RDWR) - except OSError: - pass # socket is already closed + except OSError: # socket already closed + pass self.socket.close() # Calling protocol.receive_eof() is safe because it's idempotent. @@ -1071,8 +1070,8 @@ def close_socket(self) -> None: self.protocol.receive_eof() assert self.protocol.state is CLOSED - # Abort recv() with a ConnectionClosed exception. - self.recv_messages.close() + # Abort recv() with a ConnectionClosed exception. + self.recv_messages.close() - # Acknowledge pings sent with the ack_on_close option. - self.acknowledge_pending_pings() + # Acknowledge pings sent with the ack_on_close option. + self.terminate_pending_pings() From 7d5b63c6bccd51059172e81665f216a2f4876935 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Nov 2025 09:00:58 +0100 Subject: [PATCH 04/13] Make test connection class less error prone. --- tests/asyncio/connection.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/asyncio/connection.py b/tests/asyncio/connection.py index ad1c121bf..5cd673d97 100644 --- a/tests/asyncio/connection.py +++ b/tests/asyncio/connection.py @@ -21,7 +21,7 @@ def delay_frames_sent(self, delay): """ Add a delay before sending frames. - This can result in out-of-order writes, which is unrealistic. + Misuse can result in out-of-order writes, which is unrealistic. """ assert self.transport.delay_write is None @@ -36,7 +36,7 @@ def delay_eof_sent(self, delay): """ Add a delay before sending EOF. - This can result in out-of-order writes, which is unrealistic. + Misuse can result in out-of-order writes, which is unrealistic. """ assert self.transport.delay_write_eof is None @@ -83,9 +83,9 @@ class InterceptingTransport: This is coupled to the implementation, which relies on these two methods. - Since ``write()`` and ``write_eof()`` are not coroutines, this effect is - achieved by scheduling writes at a later time, after the methods return. - This can easily result in out-of-order writes, which is unrealistic. + Since ``write()`` and ``write_eof()`` are synchronous, we can only schedule + writes at a later time, after they return. This is unrealistic and can lead + to out-of-order writes if tests aren't written carefully. """ @@ -101,15 +101,15 @@ def __getattr__(self, name): return getattr(self.transport, name) def write(self, data): - if not self.drop_write: - if self.delay_write is not None: - self.loop.call_later(self.delay_write, self.transport.write, data) - else: - self.transport.write(data) + if self.delay_write is not None: + assert not self.drop_write + self.loop.call_later(self.delay_write, self.transport.write, data) + elif not self.drop_write: + self.transport.write(data) def write_eof(self): - if not self.drop_write_eof: - if self.delay_write_eof is not None: - self.loop.call_later(self.delay_write_eof, self.transport.write_eof) - else: - self.transport.write_eof() + if self.delay_write_eof is not None: + assert not self.drop_write_eof + self.loop.call_later(self.delay_write_eof, self.transport.write_eof) + elif not self.drop_write_eof: + self.transport.write_eof() From 37a8f1b6621d0c9f92c15721ea9a7c51d02223db Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Nov 2025 09:03:48 +0100 Subject: [PATCH 05/13] Close iterators in connection tests. --- tests/asyncio/test_connection.py | 51 ++++++++++++++++++-------------- tests/sync/test_connection.py | 41 ++++++++++++++----------- 2 files changed, 51 insertions(+), 41 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 92dbf5392..cce389ce0 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -124,41 +124,46 @@ async def test_exit_with_exception(self): async def test_aiter_text(self): """__aiter__ yields text messages.""" - aiterator = aiter(self.connection) - await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") - await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") async def test_aiter_binary(self): """__aiter__ yields binary messages.""" - aiterator = aiter(self.connection) - await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") - await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") async def test_aiter_mixed(self): """__aiter__ yields a mix of text and binary messages.""" - aiterator = aiter(self.connection) - await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") - await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") async def test_aiter_connection_closed_ok(self): """__aiter__ terminates after a normal closure.""" - aiterator = aiter(self.connection) - await self.remote_connection.close() - with self.assertRaises(StopAsyncIteration): - await anext(aiterator) + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.close() + with self.assertRaises(StopAsyncIteration): + await anext(iterator) async def test_aiter_connection_closed_error(self): """__aiter__ raises ConnectionClosedError after an error.""" - aiterator = aiter(self.connection) - await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) - with self.assertRaises(ConnectionClosedError): - await anext(aiterator) + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await anext(iterator) # Test recv. diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index b81102079..8c6d15d7b 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -107,40 +107,45 @@ def test_exit_with_exception(self): def test_iter_text(self): """__iter__ yields text messages.""" iterator = iter(self.connection) - self.remote_connection.send("😀") - self.assertEqual(next(iterator), "😀") - self.remote_connection.send("😀") - self.assertEqual(next(iterator), "😀") + with contextlib.closing(iterator): + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") def test_iter_binary(self): """__iter__ yields binary messages.""" iterator = iter(self.connection) - self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") - self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + with contextlib.closing(iterator): + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") def test_iter_mixed(self): """__iter__ yields a mix of text and binary messages.""" iterator = iter(self.connection) - self.remote_connection.send("😀") - self.assertEqual(next(iterator), "😀") - self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") + with contextlib.closing(iterator): + self.remote_connection.send("😀") + self.assertEqual(next(iterator), "😀") + self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(next(iterator), b"\x01\x02\xfe\xff") def test_iter_connection_closed_ok(self): """__iter__ terminates after a normal closure.""" iterator = iter(self.connection) - self.remote_connection.close() - with self.assertRaises(StopIteration): - next(iterator) + with contextlib.closing(iterator): + self.remote_connection.close() + with self.assertRaises(StopIteration): + next(iterator) def test_iter_connection_closed_error(self): """__iter__ raises ConnectionClosedError after an error.""" iterator = iter(self.connection) - self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) - with self.assertRaises(ConnectionClosedError): - next(iterator) + with contextlib.closing(iterator): + self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + next(iterator) # Test recv. From 65c7f35414e8b356898adf882ba59e87e847b1ee Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Nov 2025 09:05:36 +0100 Subject: [PATCH 06/13] Clarify semantic of waits in connection tests. --- tests/asyncio/test_connection.py | 41 ++++++++++++++------------------ tests/sync/test_connection.py | 24 ++++++++++++------- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index cce389ce0..732977fd5 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -49,27 +49,25 @@ async def asyncTearDown(self): # Test helpers built upon RecordingProtocol and InterceptingConnection. - async def assertFrameSent(self, frame): - """Check that a single frame was sent.""" - # Let the remote side process messages. + async def wait_for_remote_side(self): + """Wait for the remote side to process messages.""" # Two runs of the event loop are required for answering pings. await asyncio.sleep(0) await asyncio.sleep(0) + + async def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + await self.wait_for_remote_side() self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) async def assertFramesSent(self, frames): """Check that several frames were sent.""" - # Let the remote side process messages. - # Two runs of the event loop are required for answering pings. - await asyncio.sleep(0) - await asyncio.sleep(0) + await self.wait_for_remote_side() self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) async def assertNoFrameSent(self): """Check that no frame was sent.""" - # Run the event loop twice for consistency with assertFrameSent. - await asyncio.sleep(0) - await asyncio.sleep(0) + await self.wait_for_remote_side() self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) @contextlib.asynccontextmanager @@ -77,28 +75,28 @@ async def delay_frames_rcvd(self, delay): """Delay frames before they're received by the connection.""" with self.remote_connection.delay_frames_sent(delay): yield - await asyncio.sleep(MS) # let the remote side process messages + await self.wait_for_remote_side() @contextlib.asynccontextmanager async def delay_eof_rcvd(self, delay): """Delay EOF before it's received by the connection.""" with self.remote_connection.delay_eof_sent(delay): yield - await asyncio.sleep(MS) # let the remote side process messages + await self.wait_for_remote_side() @contextlib.asynccontextmanager async def drop_frames_rcvd(self): """Drop frames before they're received by the connection.""" with self.remote_connection.drop_frames_sent(): yield - await asyncio.sleep(MS) # let the remote side process messages + await self.wait_for_remote_side() @contextlib.asynccontextmanager async def drop_eof_rcvd(self): """Drop EOF before it's received by the connection.""" with self.remote_connection.drop_eof_sent(): yield - await asyncio.sleep(MS) # let the remote side process messages + await self.wait_for_remote_side() # Test __aenter__ and __aexit__. @@ -1009,9 +1007,8 @@ async def test_keepalive_times_out(self, getrandbits): async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. - await asyncio.sleep(4 * MS) - # Exiting the context manager sleeps for 1 ms. # 4.x ms: a pong frame is dropped. + await asyncio.sleep(5 * MS) # 6 ms: no pong frame is received; the connection is closed. await asyncio.sleep(2 * MS) # 7 ms: check that the connection is closed. @@ -1026,8 +1023,7 @@ async def test_keepalive_ignores_timeout(self, getrandbits): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. # 4.x ms: a pong frame is dropped. - await asyncio.sleep(4 * MS) - # Exiting the context manager sleeps for 1 ms. + await asyncio.sleep(5 * MS) # 6 ms: no pong frame is received; the connection remains open. await asyncio.sleep(2 * MS) # 7 ms: check that the connection is still open. @@ -1038,19 +1034,19 @@ async def test_keepalive_terminates_while_sleeping(self): self.connection.ping_interval = 3 * MS self.connection.start_keepalive() await asyncio.sleep(MS) + self.assertFalse(self.connection.keepalive_task.done()) await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) async def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" self.connection.ping_interval = MS - self.connection.ping_timeout = 3 * MS + self.connection.ping_timeout = 4 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 1 ms: keepalive() sends a ping frame. # 1.x ms: a pong frame is dropped. - await asyncio.sleep(MS) - # Exiting the context manager sleeps for 1 ms. + await asyncio.sleep(2 * MS) # 2 ms: close the connection before ping_timeout elapses. await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) @@ -1062,8 +1058,7 @@ async def test_keepalive_reports_errors(self): self.connection.start_keepalive() # 2 ms: keepalive() sends a ping frame. # 2.x ms: a pong frame is dropped. - await asyncio.sleep(2 * MS) - # Exiting the context manager sleeps for 1 ms. + await asyncio.sleep(3 * MS) # 3 ms: inject a fault: raise an exception in the pending pong waiter. pong_received = next(iter(self.connection.pending_pings.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 8c6d15d7b..8895b6420 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -44,14 +44,20 @@ def tearDown(self): # Test helpers built upon RecordingProtocol and InterceptingConnection. + def wait_for_remote_side(self): + """Wait for the remote side to process messages.""" + # We don't have a way to tell if the remote side is blocked on I/O. + # The sync tests still run faster than the asyncio and trio tests :-) + time.sleep(MS) + def assertFrameSent(self, frame): """Check that a single frame was sent.""" - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) def assertNoFrameSent(self): """Check that no frame was sent.""" - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) @contextlib.contextmanager @@ -59,28 +65,28 @@ def delay_frames_rcvd(self, delay): """Delay frames before they're received by the connection.""" with self.remote_connection.delay_frames_sent(delay): yield - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() @contextlib.contextmanager def delay_eof_rcvd(self, delay): """Delay EOF before it's received by the connection.""" with self.remote_connection.delay_eof_sent(delay): yield - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() @contextlib.contextmanager def drop_frames_rcvd(self): """Drop frames before they're received by the connection.""" with self.remote_connection.drop_frames_sent(): yield - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() @contextlib.contextmanager def drop_eof_rcvd(self): """Drop EOF before it's received by the connection.""" with self.remote_connection.drop_eof_sent(): yield - time.sleep(MS) # let the remote side process messages + self.wait_for_remote_side() # Test __enter__ and __exit__. @@ -736,9 +742,9 @@ def test_keepalive_times_out(self, getrandbits): with self.drop_frames_rcvd(): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. + # 4.x ms: a pong frame is dropped. time.sleep(4 * MS) # Exiting the context manager sleeps for 1 ms. - # 4.x ms: a pong frame is dropped. # 6 ms: no pong frame is received; the connection is closed. time.sleep(2 * MS) # 7 ms: check that the connection is closed. @@ -765,16 +771,18 @@ def test_keepalive_terminates_while_sleeping(self): self.connection.ping_interval = 3 * MS self.connection.start_keepalive() time.sleep(MS) + self.assertTrue(self.connection.keepalive_thread.is_alive()) self.connection.close() self.connection.keepalive_thread.join(MS) self.assertFalse(self.connection.keepalive_thread.is_alive()) def test_keepalive_terminates_when_sending_ping_fails(self): """keepalive task terminates when sending a ping fails.""" - self.connection.ping_interval = 1 * MS + self.connection.ping_interval = MS self.connection.start_keepalive() with self.drop_eof_rcvd(), self.drop_frames_rcvd(): self.connection.close() + # Exiting the context managers sleeps for 2 ms. self.assertFalse(self.connection.keepalive_thread.is_alive()) def test_keepalive_terminates_while_waiting_for_pong(self): From 85dcbf94947cd22759edd4873c70f004e8943bdb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Nov 2025 09:08:43 +0100 Subject: [PATCH 07/13] Actually check for waits in connection tests. --- tests/asyncio/test_connection.py | 46 ++++++++++++++++++++++++++++---- tests/sync/test_connection.py | 26 +++++++++++++++++- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 732977fd5..8f6783dab 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -32,13 +32,13 @@ class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): REMOTE = SERVER async def asyncSetUp(self): - loop = asyncio.get_running_loop() + self.loop = asyncio.get_running_loop() socket_, remote_socket = socket.socketpair() - self.transport, self.connection = await loop.create_connection( + self.transport, self.connection = await self.loop.create_connection( lambda: Connection(Protocol(self.LOCAL), close_timeout=2 * MS), sock=socket_, ) - self.remote_transport, self.remote_connection = await loop.create_connection( + _remote_transport, self.remote_connection = await self.loop.create_connection( lambda: InterceptingConnection(RecordingProtocol(self.REMOTE)), sock=remote_socket, ) @@ -710,9 +710,15 @@ async def test_close_explicit_code_reason(self): await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) async def test_close_waits_for_close_frame(self): - """close waits for a close frame (then EOF) before returning.""" + """close waits for a close frame then EOF before returning.""" + t0 = self.loop.time() async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -726,8 +732,14 @@ async def test_close_waits_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = self.loop.time() async with self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -737,11 +749,17 @@ async def test_close_waits_for_connection_closed(self): self.assertIsNone(exc.__cause__) async def test_close_no_timeout_waits_for_close_frame(self): - """close without timeout waits for a close frame (then EOF) before returning.""" + """close without timeout waits for a close frame then EOF before returning.""" self.connection.close_timeout = None + t0 = self.loop.time() async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -757,8 +775,14 @@ async def test_close_no_timeout_waits_for_connection_closed(self): self.connection.close_timeout = None + t0 = self.loop.time() async with self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -769,8 +793,14 @@ async def test_close_no_timeout_waits_for_connection_closed(self): async def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" + t0 = self.loop.time() async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() @@ -784,8 +814,14 @@ async def test_close_timeout_waiting_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = self.loop.time() async with self.drop_eof_rcvd(): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 8895b6420..3872de441 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -477,9 +477,15 @@ def test_close_explicit_code_reason(self): self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) def test_close_waits_for_close_frame(self): - """close waits for a close frame (then EOF) before returning.""" + """close waits for a close frame then EOF before returning.""" + t0 = time.time() with self.delay_frames_rcvd(MS): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -493,8 +499,14 @@ def test_close_waits_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = time.time() with self.delay_eof_rcvd(MS): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -505,8 +517,14 @@ def test_close_waits_for_connection_closed(self): def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" + t0 = time.time() with self.drop_frames_rcvd(), self.drop_eof_rcvd(): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() @@ -520,8 +538,14 @@ def test_close_timeout_waiting_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = time.time() with self.drop_eof_rcvd(): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() From 3a5488de37f07aa1472df8bdb46fda45ff8c3473 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Nov 2025 09:11:14 +0100 Subject: [PATCH 08/13] Eliminate the risk of a busy loop in tests. If these tests reached the point where websockets tried to send a second ping, it would loop on trying to find a payload not sent yet. --- tests/asyncio/test_connection.py | 13 +++++++++---- tests/sync/test_connection.py | 13 +++++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 8f6783dab..8c92de024 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import itertools import logging import socket import sys @@ -903,9 +904,10 @@ async def test_wait_closed(self): # Test ping. - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") async def test_ping(self, getrandbits): """ping sends a ping frame with a random payload.""" + getrandbits.side_effect = itertools.count(1918987876) await self.connection.ping() getrandbits.assert_called_once_with(32) await self.assertFrameSent(Frame(Opcode.PING, b"rand")) @@ -1014,9 +1016,10 @@ async def test_pong_unsupported_type(self): # Test keepalive. - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") async def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping_interval = 3 * MS self.connection.start_keepalive() self.assertIsNotNone(self.connection.keepalive_task) @@ -1035,9 +1038,10 @@ async def test_disable_keepalive(self): self.connection.start_keepalive() self.assertIsNone(self.connection.keepalive_task) - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") async def test_keepalive_times_out(self, getrandbits): """keepalive closes the connection if ping_timeout elapses.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping_interval = 4 * MS self.connection.ping_timeout = 2 * MS async with self.drop_frames_rcvd(): @@ -1050,9 +1054,10 @@ async def test_keepalive_times_out(self, getrandbits): # 7 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") async def test_keepalive_ignores_timeout(self, getrandbits): """keepalive ignores timeouts if ping_timeout isn't set.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping_interval = 4 * MS self.connection.ping_timeout = None async with self.drop_frames_rcvd(): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 3872de441..92474fc16 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -1,4 +1,5 @@ import contextlib +import itertools import logging import socket import sys @@ -643,9 +644,10 @@ def fragments(): # Test ping. - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") def test_ping(self, getrandbits): """ping sends a ping frame with a random payload.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping() getrandbits.assert_called_once_with(32) self.assertFrameSent(Frame(Opcode.PING, b"rand")) @@ -737,9 +739,10 @@ def test_pong_unsupported_type(self): # Test keepalive. - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping_interval = 4 * MS self.connection.start_keepalive() self.assertIsNotNone(self.connection.keepalive_thread) @@ -758,9 +761,10 @@ def test_disable_keepalive(self): self.connection.start_keepalive() self.assertIsNone(self.connection.keepalive_thread) - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") def test_keepalive_times_out(self, getrandbits): """keepalive closes the connection if ping_timeout elapses.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping_interval = 4 * MS self.connection.ping_timeout = 2 * MS with self.drop_frames_rcvd(): @@ -774,9 +778,10 @@ def test_keepalive_times_out(self, getrandbits): # 7 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) - @patch("random.getrandbits", return_value=1918987876) + @patch("random.getrandbits") def test_keepalive_ignores_timeout(self, getrandbits): """keepalive ignores timeouts if ping_timeout isn't set.""" + getrandbits.side_effect = itertools.count(1918987876) self.connection.ping_interval = 4 * MS self.connection.ping_timeout = None with self.drop_frames_rcvd(): From 1845f1d32805301fe1ce8f1f3352ee73a5c2e87a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Nov 2025 09:18:25 +0100 Subject: [PATCH 09/13] Improve mocking of transport or socket layer. --- tests/asyncio/test_connection.py | 32 +++++++++++++++++++------------- tests/sync/test_connection.py | 22 +++++----------------- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 8c92de024..3a238b51a 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1118,21 +1118,28 @@ async def test_keepalive_reports_errors(self): async def test_close_timeout(self): """close_timeout parameter configures close timeout.""" - connection = Connection(Protocol(self.LOCAL), close_timeout=42 * MS) + connection = Connection( + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) self.assertEqual(connection.close_timeout, 42 * MS) async def test_max_queue(self): """max_queue configures high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=4) - transport = Mock() - connection.connection_made(transport) + connection = Connection( + Protocol(self.LOCAL), + max_queue=4, + ) + connection.connection_made(Mock(spec=asyncio.Transport)) self.assertEqual(connection.recv_messages.high, 4) async def test_max_queue_none(self): """max_queue disables high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=None) - transport = Mock() - connection.connection_made(transport) + connection = Connection( + Protocol(self.LOCAL), + max_queue=None, + ) + connection.connection_made(Mock(spec=asyncio.Transport)) self.assertEqual(connection.recv_messages.high, None) self.assertEqual(connection.recv_messages.low, None) @@ -1142,8 +1149,7 @@ async def test_max_queue_tuple(self): Protocol(self.LOCAL), max_queue=(4, 2), ) - transport = Mock() - connection.connection_made(transport) + connection.connection_made(Mock(spec=asyncio.Transport)) self.assertEqual(connection.recv_messages.high, 4) self.assertEqual(connection.recv_messages.low, 2) @@ -1153,7 +1159,7 @@ async def test_write_limit(self): Protocol(self.LOCAL), write_limit=4096, ) - transport = Mock() + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, None) @@ -1163,7 +1169,7 @@ async def test_write_limits(self): Protocol(self.LOCAL), write_limit=(4096, 2048), ) - transport = Mock() + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) @@ -1177,13 +1183,13 @@ async def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - @patch("asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234)) + @patch("asyncio.Transport.get_extra_info", return_value=("sock", 1234)) async def test_local_address(self, get_extra_info): """Connection provides a local_address attribute.""" self.assertEqual(self.connection.local_address, ("sock", 1234)) get_extra_info.assert_called_with("sockname") - @patch("asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234)) + @patch("asyncio.Transport.get_extra_info", return_value=("peer", 1234)) async def test_remote_address(self, get_extra_info): """Connection provides a remote_address attribute.""" self.assertEqual(self.connection.remote_address, ("peer", 1234)) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 92474fc16..e67149eb8 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -7,7 +7,7 @@ import time import unittest import uuid -from unittest.mock import patch +from unittest.mock import Mock, patch from websockets.exceptions import ( ConcurrencyError, @@ -853,11 +853,8 @@ def test_keepalive_reports_errors(self): def test_close_timeout(self): """close_timeout parameter configures close timeout.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), close_timeout=42 * MS, ) @@ -865,11 +862,8 @@ def test_close_timeout(self): def test_max_queue(self): """max_queue configures high-water mark of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=4, ) @@ -877,11 +871,8 @@ def test_max_queue(self): def test_max_queue_none(self): """max_queue disables high-water mark of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=None, ) @@ -890,11 +881,8 @@ def test_max_queue_none(self): def test_max_queue_tuple(self): """max_queue configures high-water and low-water marks of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=(4, 2), ) From 5b3332fe4fd1382ef782a32ae5c621db0f905b34 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Nov 2025 09:19:05 +0100 Subject: [PATCH 10/13] Remove platform-specific test that isn't very useful. --- tests/sync/test_connection.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index e67149eb8..05e7d729d 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -2,10 +2,8 @@ import itertools import logging import socket -import sys import threading import time -import unittest import uuid from unittest.mock import Mock, patch @@ -937,17 +935,6 @@ def test_close_reason(self): # Test reporting of network errors. - @unittest.skipUnless(sys.platform == "darwin", "works only on BSD") - def test_reading_in_recv_events_fails(self): - """Error when reading incoming frames is correctly reported.""" - # Inject a fault by closing the socket. This works only on BSD. - # I cannot find a way to achieve the same effect on Linux. - self.connection.socket.close() - # The connection closed exception reports the injected fault. - with self.assertRaises(ConnectionClosedError) as raised: - self.connection.recv() - self.assertIsInstance(raised.exception.__cause__, IOError) - def test_writing_in_recv_events_fails(self): """Error when responding to incoming frames is correctly reported.""" # Inject a fault by shutting down the socket for writing — but not by From 5ce814b2551127c9bf59dd5594b4c1eadd530a1e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Nov 2025 09:20:51 +0100 Subject: [PATCH 11/13] Simplify tests of safety net in connection classes. --- tests/asyncio/test_connection.py | 16 ++-------------- tests/sync/test_connection.py | 16 ++-------------- 2 files changed, 4 insertions(+), 28 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 3a238b51a..eddb505dc 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1250,23 +1250,14 @@ async def test_writing_in_send_context_fails(self): # Test safety nets — catching all exceptions in case of bugs. - # Inject a fault in a random call in data_received(). - # This test is tightly coupled to the implementation. @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) async def test_unexpected_failure_in_data_received(self, events_received): """Unexpected internal error in data_received() is correctly reported.""" - # Receive a message to trigger the fault. await self.remote_connection.send("😀") - with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, AssertionError) - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) - - # Inject a fault in a random call in send_context(). - # This test is tightly coupled to the implementation. @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) async def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" @@ -1274,10 +1265,7 @@ async def test_unexpected_failure_in_send_context(self, send_text): # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.send("😀") - - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) # Test broadcast. diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 05e7d729d..6fb852c78 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -960,23 +960,14 @@ def test_writing_in_send_context_fails(self): # Test safety nets — catching all exceptions in case of bugs. - # Inject a fault in a random call in recv_events(). - # This test is tightly coupled to the implementation. @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) def test_unexpected_failure_in_recv_events(self, events_received): """Unexpected internal error in recv_events() is correctly reported.""" - # Receive a message to trigger the fault. self.remote_connection.send("😀") - with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, AssertionError) - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) - - # Inject a fault in a random call in send_context(). - # This test is tightly coupled to the implementation. @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" @@ -984,10 +975,7 @@ def test_unexpected_failure_in_send_context(self, send_text): # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: self.connection.send("😀") - - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) class ServerConnectionTests(ClientConnectionTests): From 5917b4b67d2519695bf6ee58a6a3131b36e72e6b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Nov 2025 09:22:09 +0100 Subject: [PATCH 12/13] Clean up the asyncio and sync connection tests. --- tests/asyncio/test_connection.py | 295 +++++++++++++++++-------------- tests/sync/test_connection.py | 199 +++++++++++---------- tests/sync/utils.py | 4 +- 3 files changed, 276 insertions(+), 222 deletions(-) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index eddb505dc..f0342776c 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -35,12 +35,14 @@ class ClientConnectionTests(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.loop = asyncio.get_running_loop() socket_, remote_socket = socket.socketpair() + protocol = Protocol(self.LOCAL) + remote_protocol = RecordingProtocol(self.REMOTE) self.transport, self.connection = await self.loop.create_connection( - lambda: Connection(Protocol(self.LOCAL), close_timeout=2 * MS), + lambda: Connection(protocol, close_timeout=2 * MS), sock=socket_, ) _remote_transport, self.remote_connection = await self.loop.create_connection( - lambda: InterceptingConnection(RecordingProtocol(self.REMOTE)), + lambda: InterceptingConnection(remote_protocol), sock=remote_socket, ) @@ -112,8 +114,8 @@ async def test_aexit(self): await self.assertNoFrameSent() await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) - async def test_exit_with_exception(self): - """__exit__ with an exception closes the connection with code 1011.""" + async def test_aexit_with_exception(self): + """__aexit__ with an exception closes the connection with code 1011.""" with self.assertRaises(RuntimeError): async with self.connection: raise RuntimeError @@ -211,20 +213,19 @@ async def test_recv_connection_closed_error(self): async def test_recv_non_utf8_text(self): """recv receives a non-UTF-8 text message.""" await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) - with self.assertRaises(ConnectionClosedError): + with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() - await self.assertFrameSent( - Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") - ) + self.assertEqual(raised.exception.sent.code, CloseCode.INVALID_DATA) async def test_recv_during_recv(self): """recv raises ConcurrencyError when called concurrently.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task - self.addCleanup(recv_task.cancel) - - with self.assertRaises(ConcurrencyError) as raised: - await self.connection.recv() + try: + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.recv() + finally: + recv_task.cancel() self.assertEqual( str(raised.exception), "cannot call recv while another coroutine " @@ -237,10 +238,11 @@ async def test_recv_during_recv_streaming(self): alist(self.connection.recv_streaming()) ) await asyncio.sleep(0) # let the event loop start recv_streaming_task - self.addCleanup(recv_streaming_task.cancel) - - with self.assertRaises(ConcurrencyError) as raised: - await self.connection.recv() + try: + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.recv() + finally: + recv_streaming_task.cancel() self.assertEqual( str(raised.exception), "cannot call recv while another coroutine " @@ -248,10 +250,9 @@ async def test_recv_during_recv_streaming(self): ) async def test_recv_cancellation_before_receiving(self): - """recv can be canceled before receiving a frame.""" + """recv can be canceled before receiving a message.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task - recv_task.cancel() await asyncio.sleep(0) # let the event loop cancel recv_task @@ -260,25 +261,25 @@ async def test_recv_cancellation_before_receiving(self): self.assertEqual(await self.connection.recv(), "😀") async def test_recv_cancellation_while_receiving(self): - """recv cannot be canceled after receiving a frame.""" - recv_task = asyncio.create_task(self.connection.recv()) - await asyncio.sleep(0) # let the event loop start recv_task - - gate = asyncio.get_running_loop().create_future() + """recv can be canceled while receiving a fragmented message.""" + gate = asyncio.Event() async def fragments(): yield "⏳" - await gate + await gate.wait() yield "⌛️" asyncio.create_task(self.remote_connection.send(fragments())) - await asyncio.sleep(MS) + await asyncio.sleep(0) + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task recv_task.cancel() await asyncio.sleep(0) # let the event loop cancel recv_task + gate.set() + # Running recv again receives the complete message. - gate.set_result(None) self.assertEqual(await self.connection.recv(), "⏳⌛️") # Test recv_streaming. @@ -350,21 +351,20 @@ async def test_recv_streaming_connection_closed_error(self): async def test_recv_streaming_non_utf8_text(self): """recv_streaming receives a non-UTF-8 text message.""" await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) - with self.assertRaises(ConnectionClosedError): + with self.assertRaises(ConnectionClosedError) as raised: await alist(self.connection.recv_streaming()) - await self.assertFrameSent( - Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") - ) + self.assertEqual(raised.exception.sent.code, CloseCode.INVALID_DATA) async def test_recv_streaming_during_recv(self): """recv_streaming raises ConcurrencyError when called concurrently with recv.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task - self.addCleanup(recv_task.cancel) - - with self.assertRaises(ConcurrencyError) as raised: - async for _ in self.connection.recv_streaming(): - self.fail("did not raise") + try: + with self.assertRaises(ConcurrencyError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + finally: + recv_task.cancel() self.assertEqual( str(raised.exception), "cannot call recv_streaming while another coroutine " @@ -377,11 +377,12 @@ async def test_recv_streaming_during_recv_streaming(self): alist(self.connection.recv_streaming()) ) await asyncio.sleep(0) # let the event loop start recv_streaming_task - self.addCleanup(recv_streaming_task.cancel) - - with self.assertRaises(ConcurrencyError) as raised: - async for _ in self.connection.recv_streaming(): - self.fail("did not raise") + try: + with self.assertRaises(ConcurrencyError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + finally: + recv_streaming_task.cancel() self.assertEqual( str(raised.exception), r"cannot call recv_streaming while another coroutine " @@ -394,7 +395,6 @@ async def test_recv_streaming_cancellation_before_receiving(self): alist(self.connection.recv_streaming()) ) await asyncio.sleep(0) # let the event loop start recv_streaming_task - recv_streaming_task.cancel() await asyncio.sleep(0) # let the event loop cancel recv_streaming_task @@ -407,28 +407,32 @@ async def test_recv_streaming_cancellation_before_receiving(self): async def test_recv_streaming_cancellation_while_receiving(self): """recv_streaming cannot be canceled after receiving a frame.""" - recv_streaming_task = asyncio.create_task( - alist(self.connection.recv_streaming()) - ) - await asyncio.sleep(0) # let the event loop start recv_streaming_task - - gate = asyncio.get_running_loop().create_future() + gate = asyncio.Event() async def fragments(): yield "⏳" - await gate + await gate.wait() yield "⌛️" asyncio.create_task(self.remote_connection.send(fragments())) - await asyncio.sleep(MS) + await asyncio.sleep(0) + recv_streaming_task = asyncio.create_task( + alist(self.connection.recv_streaming()) + ) + await asyncio.sleep(0) # let the event loop start recv_streaming_task + await asyncio.sleep(0) # experimentally, two runs of the event loop + await asyncio.sleep(0) # are needed to receive the first fragment recv_streaming_task.cancel() await asyncio.sleep(0) # let the event loop cancel recv_streaming_task - gate.set_result(None) + gate.set() + await asyncio.sleep(0) + # Running recv_streaming again fails. with self.assertRaises(ConcurrencyError): - await alist(self.connection.recv_streaming()) + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") # Test send. @@ -556,23 +560,31 @@ async def test_send_connection_closed_error(self): with self.assertRaises(ConnectionClosedError): await self.connection.send("😀") - async def test_send_while_send_blocked(self): + async def test_send_during_send(self): """send waits for a previous call to send to complete.""" # This test fails if the guard with send_in_progress is removed - # from send() in the case when message is an Iterable. - self.connection.pause_writing() - asyncio.create_task(self.connection.send(["⏳", "⌛️"])) - await asyncio.sleep(MS) + # from send() in the case when message is an AsyncIterable. + gate = asyncio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + asyncio.create_task(self.connection.send(fragments())) + await asyncio.sleep(0) # let the event loop start the task await self.assertFrameSent( Frame(Opcode.TEXT, "⏳".encode(), fin=False), ) asyncio.create_task(self.connection.send("✅")) - await asyncio.sleep(MS) + await asyncio.sleep(0) # let the event loop start the task await self.assertNoFrameSent() - self.connection.resume_writing() - await asyncio.sleep(MS) + gate.set() + await asyncio.sleep(0) # run the event loop + await asyncio.sleep(0) # three times in order + await asyncio.sleep(0) # to send three frames await self.assertFramesSent( [ Frame(Opcode.CONT, "⌛️".encode(), fin=False), @@ -581,28 +593,26 @@ async def test_send_while_send_blocked(self): ] ) - async def test_send_while_send_async_blocked(self): - """send waits for a previous call to send to complete.""" + async def test_send_while_send_blocked(self): + """send waits for a blocked call to send to complete.""" # This test fails if the guard with send_in_progress is removed - # from send() in the case when message is an AsyncIterable. + # from send() in the case when message is an Iterable. self.connection.pause_writing() - async def fragments(): - yield "⏳" - yield "⌛️" - - asyncio.create_task(self.connection.send(fragments())) - await asyncio.sleep(MS) + asyncio.create_task(self.connection.send(["⏳", "⌛️"])) + await asyncio.sleep(0) # let the event loop start the task await self.assertFrameSent( Frame(Opcode.TEXT, "⏳".encode(), fin=False), ) asyncio.create_task(self.connection.send("✅")) - await asyncio.sleep(MS) + await asyncio.sleep(0) # let the event loop start the task await self.assertNoFrameSent() self.connection.resume_writing() - await asyncio.sleep(MS) + await asyncio.sleep(0) # run the event loop + await asyncio.sleep(0) # three times in order + await asyncio.sleep(0) # to send three frames await self.assertFramesSent( [ Frame(Opcode.CONT, "⌛️".encode(), fin=False), @@ -611,29 +621,30 @@ async def fragments(): ] ) - async def test_send_during_send_async(self): - """send waits for a previous call to send to complete.""" + async def test_send_while_send_async_blocked(self): + """send waits for a blocked call to send to complete.""" # This test fails if the guard with send_in_progress is removed # from send() in the case when message is an AsyncIterable. - gate = asyncio.get_running_loop().create_future() + self.connection.pause_writing() async def fragments(): yield "⏳" - await gate yield "⌛️" asyncio.create_task(self.connection.send(fragments())) - await asyncio.sleep(MS) + await asyncio.sleep(0) # let the event loop start the task await self.assertFrameSent( Frame(Opcode.TEXT, "⏳".encode(), fin=False), ) asyncio.create_task(self.connection.send("✅")) - await asyncio.sleep(MS) + await asyncio.sleep(0) # let the event loop start the task await self.assertNoFrameSent() - gate.set_result(None) - await asyncio.sleep(MS) + self.connection.resume_writing() + await asyncio.sleep(0) # run the event loop + await asyncio.sleep(0) # three times in order + await asyncio.sleep(0) # to send three frames await self.assertFramesSent( [ Frame(Opcode.CONT, "⌛️".encode(), fin=False), @@ -676,8 +687,10 @@ async def fragments(): yield "😀" yield b"\xfe\xff" - with self.assertRaises(TypeError): - await self.connection.send(fragments()) + iterator = fragments() + async with contextlib.aclosing(iterator): + with self.assertRaises(TypeError): + await self.connection.send(iterator) async def test_send_unsupported_async_iterable(self): """send raises TypeError when called with an iterable of unsupported type.""" @@ -685,8 +698,10 @@ async def test_send_unsupported_async_iterable(self): async def fragments(): yield None - with self.assertRaises(TypeError): - await self.connection.send(fragments()) + iterator = fragments() + async with contextlib.aclosing(iterator): + with self.assertRaises(TypeError): + await self.connection.send(iterator) async def test_send_dict(self): """send raises TypeError when called with a dict.""" @@ -837,13 +852,9 @@ async def test_close_preserves_queued_messages(self): await self.connection.close() self.assertEqual(await self.connection.recv(), "😀") - with self.assertRaises(ConnectionClosedOK) as raised: + with self.assertRaises(ConnectionClosedOK): await self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") - self.assertIsNone(exc.__cause__) - async def test_close_idempotency(self): """close does nothing if the connection is already closed.""" await self.connection.close() @@ -854,11 +865,30 @@ async def test_close_idempotency(self): async def test_close_during_recv(self): """close aborts recv when called concurrently with recv.""" - recv_task = asyncio.create_task(self.connection.recv()) - await asyncio.sleep(MS) - await self.connection.close() + + async def closer(): + await asyncio.sleep(MS) + await self.connection.close() + + asyncio.create_task(closer()) with self.assertRaises(ConnectionClosedOK) as raised: - await recv_task + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_during_recv_streaming(self): + """close aborts recv_streaming when called concurrently with recv_streaming.""" + + async def closer(): + await asyncio.sleep(MS) + await self.connection.close() + + asyncio.create_task(closer()) + with self.assertRaises(ConnectionClosedOK) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") @@ -866,23 +896,25 @@ async def test_close_during_recv(self): async def test_close_during_send(self): """close fails the connection when called concurrently with send.""" - gate = asyncio.get_running_loop().create_future() + close_gate = asyncio.Event() + exit_gate = asyncio.Event() + + async def closer(): + await close_gate.wait() + await self.connection.close() + exit_gate.set() async def fragments(): yield "⏳" - await gate + close_gate.set() + await exit_gate.wait() yield "⌛️" - send_task = asyncio.create_task(self.connection.send(fragments())) - await asyncio.sleep(MS) - - asyncio.create_task(self.connection.close()) - await asyncio.sleep(MS) - - gate.set_result(None) - - with self.assertRaises(ConnectionClosedError) as raised: - await send_task + asyncio.create_task(closer()) + iterator = fragments() + async with contextlib.aclosing(iterator): + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send(iterator) exc = raised.exception self.assertEqual( @@ -969,6 +1001,14 @@ async def test_acknowledge_previous_canceled_ping(self): with self.assertRaises(asyncio.CancelledError): await pong_received + async def test_terminate_ping_on_close(self): + """ping is canceled when the connection is closed.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.connection.close() + with self.assertRaises(ConnectionClosedOK): + await pong_received + async def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" async with self.drop_frames_rcvd(): # drop automatic response to ping @@ -1050,8 +1090,8 @@ async def test_keepalive_times_out(self, getrandbits): # 4.x ms: a pong frame is dropped. await asyncio.sleep(5 * MS) # 6 ms: no pong frame is received; the connection is closed. - await asyncio.sleep(2 * MS) - # 7 ms: check that the connection is closed. + await asyncio.sleep(3 * MS) + # 8 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) @patch("random.getrandbits") @@ -1066,8 +1106,8 @@ async def test_keepalive_ignores_timeout(self, getrandbits): # 4.x ms: a pong frame is dropped. await asyncio.sleep(5 * MS) # 6 ms: no pong frame is received; the connection remains open. - await asyncio.sleep(2 * MS) - # 7 ms: check that the connection is still open. + await asyncio.sleep(3 * MS) + # 8 ms: check that the connection is still open. self.assertEqual(self.connection.state, State.OPEN) async def test_keepalive_terminates_while_sleeping(self): @@ -1079,6 +1119,9 @@ async def test_keepalive_terminates_while_sleeping(self): await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) + # test_keepalive_terminates_when_sending_ping_fails is not implemented + # because sending a ping cannot fail in the asyncio implementation. + async def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" self.connection.ping_interval = MS @@ -1095,8 +1138,8 @@ async def test_keepalive_terminates_while_waiting_for_pong(self): async def test_keepalive_reports_errors(self): """keepalive reports unexpected errors in logs.""" self.connection.ping_interval = 2 * MS + self.connection.start_keepalive() async with self.drop_frames_rcvd(): - self.connection.start_keepalive() # 2 ms: keepalive() sends a ping frame. # 2.x ms: a pong frame is dropped. await asyncio.sleep(3 * MS) @@ -1183,15 +1226,17 @@ async def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - @patch("asyncio.Transport.get_extra_info", return_value=("sock", 1234)) + @patch("asyncio.Transport.get_extra_info") async def test_local_address(self, get_extra_info): - """Connection provides a local_address attribute.""" + """Connection has a local_address attribute.""" + get_extra_info.return_value = ("sock", 1234) self.assertEqual(self.connection.local_address, ("sock", 1234)) get_extra_info.assert_called_with("sockname") - @patch("asyncio.Transport.get_extra_info", return_value=("peer", 1234)) + @patch("asyncio.Transport.get_extra_info") async def test_remote_address(self, get_extra_info): - """Connection provides a remote_address attribute.""" + """Connection has a remote_address attribute.""" + get_extra_info.return_value = ("peer", 1234) self.assertEqual(self.connection.remote_address, ("peer", 1234)) get_extra_info.assert_called_with("peername") @@ -1228,12 +1273,9 @@ async def test_writing_in_data_received_fails(self): self.transport.write_eof() # Receive a ping. Responding with a pong will fail. await self.remote_connection.ping() - # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() - cause = raised.exception.__cause__ - self.assertEqual(str(cause), "Cannot call write() after write_eof()") - self.assertIsInstance(cause, RuntimeError) + self.assertIsInstance(raised.exception.__cause__, RuntimeError) async def test_writing_in_send_context_fails(self): """Error when sending outgoing frame is correctly reported.""" @@ -1241,12 +1283,9 @@ async def test_writing_in_send_context_fails(self): # closing it because that would terminate the connection. self.transport.write_eof() # Sending a pong will fail. - # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.pong() - cause = raised.exception.__cause__ - self.assertEqual(str(cause), "Cannot call write() after write_eof()") - self.assertIsInstance(cause, RuntimeError) + self.assertIsInstance(raised.exception.__cause__, RuntimeError) # Test safety nets — catching all exceptions in case of bugs. @@ -1254,6 +1293,7 @@ async def test_writing_in_send_context_fails(self): async def test_unexpected_failure_in_data_received(self, events_received): """Unexpected internal error in data_received() is correctly reported.""" await self.remote_connection.send("😀") + # Reading the message will trigger the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() self.assertIsInstance(raised.exception.__cause__, AssertionError) @@ -1261,8 +1301,7 @@ async def test_unexpected_failure_in_data_received(self, events_received): @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) async def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" - # Send a message to trigger the fault. - # The connection closed exception reports the injected fault. + # Sending a message will trigger the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.send("😀") self.assertIsInstance(raised.exception.__cause__, AssertionError) @@ -1336,11 +1375,11 @@ async def test_broadcast_skips_closing_connection(self): async def test_broadcast_skips_connection_with_send_blocked(self): """broadcast logs a warning when a connection is blocked in send.""" - gate = asyncio.get_running_loop().create_future() + gate = asyncio.Event() async def fragments(): yield "⏳" - await gate + await gate.wait() send_task = asyncio.create_task(self.connection.send(fragments())) await asyncio.sleep(MS) @@ -1354,7 +1393,7 @@ async def fragments(): ["skipped broadcast: sending a fragmented message"], ) - gate.set_result(None) + gate.set() await send_task @unittest.skipIf( @@ -1363,11 +1402,11 @@ async def fragments(): ) async def test_broadcast_reports_connection_with_send_blocked(self): """broadcast raises exceptions for connections blocked in send.""" - gate = asyncio.get_running_loop().create_future() + gate = asyncio.Event() async def fragments(): yield "⏳" - await gate + await gate.wait() send_task = asyncio.create_task(self.connection.send(fragments())) await asyncio.sleep(MS) @@ -1381,7 +1420,7 @@ async def fragments(): self.assertEqual(str(exc), "sending a fragmented message") self.assertIsInstance(exc, ConcurrencyError) - gate.set_result(None) + gate.set() await send_task async def test_broadcast_skips_connection_failing_to_send(self): diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 6fb852c78..7fa88e94c 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -199,30 +199,32 @@ def test_recv_connection_closed_error(self): def test_recv_non_utf8_text(self): """recv receives a non-UTF-8 text message.""" self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) - with self.assertRaises(ConnectionClosedError): + with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() - self.assertFrameSent( - Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") - ) + self.assertEqual(raised.exception.sent.code, CloseCode.INVALID_DATA) def test_recv_during_recv(self): """recv raises ConcurrencyError when called concurrently.""" with self.run_in_thread(self.connection.recv): - with self.assertRaises(ConcurrencyError) as raised: - self.connection.recv() - self.remote_connection.send("") - self.assertEqual( - str(raised.exception), - "cannot call recv while another thread " - "is already running recv or recv_streaming", - ) + try: + with self.assertRaises(ConcurrencyError) as raised: + self.connection.recv() + finally: + self.remote_connection.send("") + self.assertEqual( + str(raised.exception), + "cannot call recv while another thread " + "is already running recv or recv_streaming", + ) def test_recv_during_recv_streaming(self): """recv raises ConcurrencyError when called concurrently with recv_streaming.""" with self.run_in_thread(lambda: list(self.connection.recv_streaming())): - with self.assertRaises(ConcurrencyError) as raised: - self.connection.recv() - self.remote_connection.send("") + try: + with self.assertRaises(ConcurrencyError) as raised: + self.connection.recv() + finally: + self.remote_connection.send("") self.assertEqual( str(raised.exception), "cannot call recv while another thread " @@ -298,19 +300,19 @@ def test_recv_streaming_connection_closed_error(self): def test_recv_streaming_non_utf8_text(self): """recv_streaming receives a non-UTF-8 text message.""" self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) - with self.assertRaises(ConnectionClosedError): + with self.assertRaises(ConnectionClosedError) as raised: list(self.connection.recv_streaming()) - self.assertFrameSent( - Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") - ) + self.assertEqual(raised.exception.sent.code, CloseCode.INVALID_DATA) def test_recv_streaming_during_recv(self): """recv_streaming raises ConcurrencyError when called concurrently with recv.""" with self.run_in_thread(self.connection.recv): - with self.assertRaises(ConcurrencyError) as raised: - for _ in self.connection.recv_streaming(): - self.fail("did not raise") - self.remote_connection.send("") + try: + with self.assertRaises(ConcurrencyError) as raised: + for _ in self.connection.recv_streaming(): + self.fail("did not raise") + finally: + self.remote_connection.send("") self.assertEqual( str(raised.exception), "cannot call recv_streaming while another thread " @@ -320,14 +322,16 @@ def test_recv_streaming_during_recv(self): def test_recv_streaming_during_recv_streaming(self): """recv_streaming raises ConcurrencyError when called concurrently.""" with self.run_in_thread(lambda: list(self.connection.recv_streaming())): - with self.assertRaises(ConcurrencyError) as raised: - for _ in self.connection.recv_streaming(): - self.fail("did not raise") - self.remote_connection.send("") + try: + with self.assertRaises(ConcurrencyError) as raised: + for _ in self.connection.recv_streaming(): + self.fail("did not raise") + finally: + self.remote_connection.send("") self.assertEqual( str(raised.exception), - r"cannot call recv_streaming while another thread " - r"is already running recv or recv_streaming", + "cannot call recv_streaming while another thread " + "is already running recv or recv_streaming", ) # Test send. @@ -412,30 +416,25 @@ def fragments(): exit_gate.wait() yield "😀" - send_thread = threading.Thread( - target=self.connection.send, - args=(fragments(),), - ) - send_thread.start() - - send_gate.wait() - # The check happens in four code paths, depending on the argument. - for message in [ - "😀", - b"\x01\x02\xfe\xff", - ["😀", "😀"], - [b"\x01\x02", b"\xfe\xff"], - ]: - with self.subTest(message=message): - with self.assertRaises(ConcurrencyError) as raised: - self.connection.send(message) - self.assertEqual( - str(raised.exception), - "cannot call send while another thread is already running send", - ) - - exit_gate.set() - send_thread.join() + with self.run_in_thread(self.connection.send, args=(fragments(),)): + send_gate.wait() + # The check happens in four code paths, depending on the argument. + for message in [ + "😀", + b"\x01\x02\xfe\xff", + ["😀", "😀"], + [b"\x01\x02", b"\xfe\xff"], + ]: + with self.subTest(message=message): + with self.assertRaises(ConcurrencyError) as raised: + self.connection.send(message) + self.assertEqual( + str(raised.exception), + "cannot call send while another thread " + "is already running send", + ) + + exit_gate.set() def test_send_empty_iterable(self): """send does nothing when called with an empty iterable.""" @@ -559,13 +558,9 @@ def test_close_preserves_queued_messages(self): self.connection.close() self.assertEqual(self.connection.recv(), "😀") - with self.assertRaises(ConnectionClosedOK) as raised: + with self.assertRaises(ConnectionClosedOK): self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") - self.assertIsNone(exc.__cause__) - def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() @@ -576,7 +571,6 @@ def test_close_idempotency(self): def test_close_idempotency_race_condition(self): """close waits if the connection is already closing.""" - self.connection.close_timeout = 6 * MS def closer(): @@ -600,7 +594,13 @@ def closer(): def test_close_during_recv(self): """close aborts recv when called concurrently with recv.""" - with self.run_in_thread(self.connection.close): + + def closer(): + # Wait 2 * MS because run_in_thread() waits for MS. + time.sleep(2 * MS) + self.connection.close() + + with self.run_in_thread(closer): with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -608,6 +608,23 @@ def test_close_during_recv(self): self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) + def test_aclose_during_recv_streaming(self): + """aclose aborts recv_streaming when called concurrently with recv_streaming.""" + + def closer(): + # Wait 2 * MS because run_in_thread() waits for MS. + time.sleep(2 * MS) + self.connection.close() + + with self.run_in_thread(closer): + with self.assertRaises(ConnectionClosedOK) as raised: + for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + def test_close_during_send(self): """close fails the connection when called concurrently with send.""" close_gate = threading.Event() @@ -619,16 +636,16 @@ def closer(): exit_gate.set() def fragments(): - yield "😀" + yield "⏳" close_gate.set() exit_gate.wait() - yield "😀" + yield "⌛️" - close_thread = threading.Thread(target=closer) - close_thread.start() - - with self.assertRaises(ConnectionClosedError) as raised: - self.connection.send(fragments()) + with self.run_in_thread(closer): + iterator = fragments() + with contextlib.closing(iterator): + with self.assertRaises(ConnectionClosedError) as raised: + self.connection.send(iterator) exc = raised.exception self.assertEqual( @@ -638,8 +655,6 @@ def fragments(): ) self.assertIsNone(exc.__cause__) - close_thread.join() - # Test ping. @patch("random.getrandbits") @@ -741,7 +756,7 @@ def test_pong_unsupported_type(self): def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" getrandbits.side_effect = itertools.count(1918987876) - self.connection.ping_interval = 4 * MS + self.connection.ping_interval = 3 * MS self.connection.start_keepalive() self.assertIsNotNone(self.connection.keepalive_thread) self.assertEqual(self.connection.latency, 0) @@ -772,8 +787,8 @@ def test_keepalive_times_out(self, getrandbits): time.sleep(4 * MS) # Exiting the context manager sleeps for 1 ms. # 6 ms: no pong frame is received; the connection is closed. - time.sleep(2 * MS) - # 7 ms: check that the connection is closed. + time.sleep(3 * MS) + # 8 ms: check that the connection is closed. self.assertEqual(self.connection.state, State.CLOSED) @patch("random.getrandbits") @@ -785,19 +800,19 @@ def test_keepalive_ignores_timeout(self, getrandbits): with self.drop_frames_rcvd(): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. + # 4.x ms: a pong frame is dropped. time.sleep(4 * MS) # Exiting the context manager sleeps for 1 ms. - # 4.x ms: a pong frame is dropped. # 6 ms: no pong frame is received; the connection remains open. - time.sleep(2 * MS) - # 7 ms: check that the connection is still open. + time.sleep(3 * MS) + # 8 ms: check that the connection is still open. self.assertEqual(self.connection.state, State.OPEN) def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" self.connection.ping_interval = 3 * MS self.connection.start_keepalive() - time.sleep(MS) + self.connection.keepalive_thread.join(MS) self.assertTrue(self.connection.keepalive_thread.is_alive()) self.connection.close() self.connection.keepalive_thread.join(MS) @@ -807,6 +822,7 @@ def test_keepalive_terminates_when_sending_ping_fails(self): """keepalive task terminates when sending a ping fails.""" self.connection.ping_interval = MS self.connection.start_keepalive() + self.assertTrue(self.connection.keepalive_thread.is_alive()) with self.drop_eof_rcvd(), self.drop_frames_rcvd(): self.connection.close() # Exiting the context managers sleeps for 2 ms. @@ -830,14 +846,13 @@ def test_keepalive_terminates_while_waiting_for_pong(self): def test_keepalive_reports_errors(self): """keepalive reports unexpected errors in logs.""" self.connection.ping_interval = 2 * MS - with self.drop_frames_rcvd(): - self.connection.start_keepalive() - # 2 ms: keepalive() sends a ping frame. - # 2.x ms: a pong frame is dropped. - with self.assertLogs("websockets", logging.ERROR) as logs: - with patch("threading.Event.wait", side_effect=Exception("BOOM")): - time.sleep(3 * MS) - # Exiting the context manager sleeps for 1 ms. + self.connection.start_keepalive() + # Inject a fault when waiting to receive a pong. + with self.assertLogs("websockets", logging.ERROR) as logs: + with patch("threading.Event.wait", side_effect=Exception("BOOM")): + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is dropped. + time.sleep(3 * MS) self.assertEqual( [record.getMessage() for record in logs.records], ["keepalive ping failed"], @@ -897,15 +912,17 @@ def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - @patch("socket.socket.getsockname", return_value=("sock", 1234)) + @patch("socket.socket.getsockname") def test_local_address(self, getsockname): - """Connection provides a local_address attribute.""" + """Connection has a local_address attribute.""" + getsockname.return_value = ("sock", 1234) self.assertEqual(self.connection.local_address, ("sock", 1234)) getsockname.assert_called_with() - @patch("socket.socket.getpeername", return_value=("peer", 1234)) + @patch("socket.socket.getpeername") def test_remote_address(self, getpeername): - """Connection provides a remote_address attribute.""" + """Connection has a remote_address attribute.""" + getpeername.return_value = ("peer", 1234) self.assertEqual(self.connection.remote_address, ("peer", 1234)) getpeername.assert_called_with() @@ -942,7 +959,6 @@ def test_writing_in_recv_events_fails(self): self.connection.socket.shutdown(socket.SHUT_WR) # Receive a ping. Responding with a pong will fail. self.remote_connection.ping() - # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() self.assertIsInstance(raised.exception.__cause__, BrokenPipeError) @@ -953,7 +969,6 @@ def test_writing_in_send_context_fails(self): # closing it because that would terminate the connection. self.connection.socket.shutdown(socket.SHUT_WR) # Sending a pong will fail. - # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: self.connection.pong() self.assertIsInstance(raised.exception.__cause__, BrokenPipeError) @@ -964,6 +979,7 @@ def test_writing_in_send_context_fails(self): def test_unexpected_failure_in_recv_events(self, events_received): """Unexpected internal error in recv_events() is correctly reported.""" self.remote_connection.send("😀") + # Reading the message will trigger the injected fault. with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() self.assertIsInstance(raised.exception.__cause__, AssertionError) @@ -971,8 +987,7 @@ def test_unexpected_failure_in_recv_events(self, events_received): @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) def test_unexpected_failure_in_send_context(self, send_text): """Unexpected internal error in send_context() is correctly reported.""" - # Send a message to trigger the fault. - # The connection closed exception reports the injected fault. + # Sending a message will trigger the injected fault. with self.assertRaises(ConnectionClosedError) as raised: self.connection.send("😀") self.assertIsInstance(raised.exception.__cause__, AssertionError) diff --git a/tests/sync/utils.py b/tests/sync/utils.py index 8903cd349..9f734504a 100644 --- a/tests/sync/utils.py +++ b/tests/sync/utils.py @@ -8,7 +8,7 @@ class ThreadTestCase(unittest.TestCase): @contextlib.contextmanager - def run_in_thread(self, target): + def run_in_thread(self, target, args=(), kwargs=None): """ Run ``target`` function without arguments in a thread. @@ -16,7 +16,7 @@ def run_in_thread(self, target): for 1ms on entry and joins the thread with a 1ms timeout on exit. """ - thread = threading.Thread(target=target) + thread = threading.Thread(target=target, args=args, kwargs=kwargs) thread.start() time.sleep(MS) try: From fa35f5396db0e480b5b68896e209c8fad4cf64ca Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 1 Nov 2025 09:23:45 +0100 Subject: [PATCH 13/13] Add trio connection. --- src/websockets/trio/connection.py | 1124 +++++++++++++++++++++++++ tests/trio/connection.py | 116 +++ tests/trio/test_connection.py | 1265 +++++++++++++++++++++++++++++ 3 files changed, 2505 insertions(+) create mode 100644 src/websockets/trio/connection.py create mode 100644 tests/trio/connection.py create mode 100644 tests/trio/test_connection.py diff --git a/src/websockets/trio/connection.py b/src/websockets/trio/connection.py new file mode 100644 index 000000000..61532316f --- /dev/null +++ b/src/websockets/trio/connection.py @@ -0,0 +1,1124 @@ +from __future__ import annotations + +import contextlib +import logging +import random +import struct +import uuid +from collections.abc import AsyncIterable, AsyncIterator, Iterable, Mapping +from types import TracebackType +from typing import Any, Literal, overload + +import trio +import trio.abc + +from ..asyncio.compatibility import ( + TimeoutError, + aiter, + anext, +) +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedOK, + ProtocolError, +) +from ..frames import DATA_OPCODES, CloseCode, Frame, Opcode +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol +from .messages import Assembler + + +__all__ = ["Connection"] + + +class Connection(trio.abc.AsyncResource): + """ + :mod:`trio` implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.trio.client.ClientConnection` or + :class:`~websockets.trio.server.ServerConnection`. + + """ + + def __init__( + self, + nursery: trio.Nursery, + stream: trio.abc.Stream, + protocol: Protocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.nursery = nursery + self.stream = stream + self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + if isinstance(max_queue, int) or max_queue is None: + max_queue_high, max_queue_low = max_queue, None + else: + max_queue_high, max_queue_low = max_queue + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Request | None = None + """Opening handshake request.""" + self.response: Response | None = None + """Opening handshake response.""" + + # Lock stopping reads when the assembler buffer is full. + self.recv_flow_control = trio.Lock() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages = Assembler( + max_queue_high, + max_queue_low, + pause=self.recv_flow_control.acquire_nowait, + resume=self.recv_flow_control.release, + ) + + # Deadline for the closing handshake. + self.close_deadline: float | None = None + + # Lock preventing concurrent calls to send_all or send_eof. + self.send_lock = trio.Lock() + + # Protect sending fragmented messages. + self.send_in_progress: trio.Event | None = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pending_pings: dict[bytes, tuple[trio.Event, float, bool]] = {} + + self.latency: float = 0.0 + """ + Latency of the connection, in seconds. + + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. + """ + + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.stream_closed: trio.Event = trio.Event() + + # Start recv_events only after all attributes are initialized. + self.nursery.start_soon(self.recv_events) + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + if isinstance(self.stream, trio.SSLStream): # pragma: no cover + stream = self.stream.transport_stream + else: + stream = self.stream + if isinstance(stream, trio.SocketStream): + return stream.socket.getsockname() + else: # pragma: no cover + raise NotImplementedError(f"unsupported stream type: {stream}") + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + if isinstance(self.stream, trio.SSLStream): # pragma: no cover + stream = self.stream.transport_stream + else: + stream = self.stream + if isinstance(stream, trio.SocketStream): + return stream.socket.getpeername() + else: # pragma: no cover + raise NotImplementedError(f"unsupported stream type: {stream}") + + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. + + """ + return self.protocol.state + + @property + def subprotocol(self) -> Subprotocol | None: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + + # Public methods + + async def __aenter__(self) -> Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if exc_type is None: + await self.aclose() + else: + await self.aclose(CloseCode.INTERNAL_ERROR) + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages asynchronously in an + infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + @overload + async def recv(self, decode: Literal[True]) -> str: ... + + @overload + async def recv(self, decode: Literal[False]) -> bytes: ... + + @overload + async def recv(self, decode: bool | None = None) -> Data: ... + + async def recv(self, decode: bool | None = None) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing data. The next + invocation of :meth:`recv` will return the next message. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~trio.move_on_after` or :func:`~trio.fail_after`. + + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return await self.recv_messages.get(decode) + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv while another coroutine " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await self.stream_closed.wait() + raise self.protocol.close_exc from self.recv_exc + + @overload + def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + + This method is designed for receiving fragmented messages. It returns an + asynchronous iterator that yields each fragment as it is received. This + iterator must be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection + unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Canceling :meth:`recv_streaming` before receiving the first frame is + safe. Canceling it after receiving one or more frames leaves the + iterator in a partially consumed state, making the connection unusable. + Instead, you should close the connection with :meth:`aclose`. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + async for frame in self.recv_messages.get_iter(decode): + yield frame + return + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await self.stream_closed.wait() + raise self.protocol.close_exc from self.recv_exc + + async def send( + self, + message: DataLike | Iterable[DataLike] | AsyncIterable[DataLike], + text: bool | None = None, + ) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send an UTF-8 bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) in a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + + :meth:`send` also accepts an iterable or asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`aclose`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`aclose` has the same effect and is + more clear: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`aclose` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + TypeError: If ``message`` doesn't have a supported type. + + """ + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self.send_in_progress is not None: + await self.send_in_progress.wait() + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + async with self.send_context(): + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) + + elif isinstance(message, BytesLike): + async with self.send_context(): + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + assert self.send_in_progress is None + self.send_in_progress = trio.Event() + try: + # First fragment. + if isinstance(chunk, str): + async with self.send_context(): + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True + elif isinstance(chunk, BytesLike): + async with self.send_context(): + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False + else: + raise TypeError("iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and encode: + async with self.send_context(): + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: + async with self.send_context(): + self.protocol.send_continuation(chunk, fin=False) + else: + raise TypeError("iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) + raise + + finally: + self.send_in_progress.set() + self.send_in_progress = None + + # Fragmented message -- async iterator. + + elif isinstance(message, AsyncIterable): + achunks = aiter(message) + try: + chunk = await anext(achunks) + except StopAsyncIteration: + return + + assert self.send_in_progress is None + self.send_in_progress = trio.Event() + try: + # First fragment. + if isinstance(chunk, str): + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True + elif isinstance(chunk, BytesLike): + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False + else: + raise TypeError("async iterable must contain bytes or str") + + # Other fragments + async for chunk in achunks: + if isinstance(chunk, str) and encode: + async with self.send_context(): + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: + async with self.send_context(): + self.protocol.send_continuation(chunk, fin=False) + else: + raise TypeError("async iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) + raise + + finally: + self.send_in_progress.set() + self.send_in_progress = None + + else: + raise TypeError("data must be str, bytes, iterable, or async iterable") + + async def aclose( + self, + code: CloseCode | int = CloseCode.NORMAL_CLOSURE, + reason: str = "", + ) -> None: + """ + Perform the closing handshake. + + :meth:`aclose` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`aclose` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + async with self.send_context(): + if self.send_in_progress is not None: + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + # Safety net: enforce the semantics of trio.abc.AsyncResource.aclose(). + except BaseException: # pragma: no cover + await trio.aclose_forcefully(self.stream) + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + :meth:`wait_closed` waits for the closing handshake to complete and for + the TCP connection to terminate. + + """ + await self.stream_closed.wait() + + async def ping( + self, + data: DataLike | None = None, + ack_on_close: bool = False, + ) -> trio.Event: + """ + Send a Ping_. + + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + ack_on_close: when this option is :obj:`True`, the event will also + be set when the connection is closed. While this avoids getting + stuck waiting for a pong that will never arrive, it requires + checking that the state of the connection is still ``OPEN`` to + confirm that a pong was received, rather than the connection + being closed. + + Returns: + An event that will be set when the corresponding pong is received. + You can ignore it if you don't intend to wait. + + :: + + pong_received = await ws.ping() + # only if you want to wait for the corresponding pong + await pong_received.wait() + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + elif data is not None: + raise TypeError("data must be str or bytes-like") + + async with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pending_pings: + raise ConcurrencyError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pending_pings: + data = struct.pack("!I", random.getrandbits(32)) + + pong_received = trio.Event() + ping_timestamp = trio.current_time() + self.pending_pings[data] = (pong_received, ping_timestamp, ack_on_close) + self.protocol.send_ping(data) + return pong_received + + async def pong(self, data: DataLike = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + else: + raise TypeError("data must be str or bytes-like") + + async with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + # Ignore unsolicited pong. + if data not in self.pending_pings: + return + + pong_timestamp = trio.current_time() + + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, ( + pong_received, + ping_timestamp, + _ack_on_close, + ) in self.pending_pings.items(): + ping_ids.append(ping_id) + pong_received.set() + if ping_id == data: + self.latency = pong_timestamp - ping_timestamp + break + else: + raise AssertionError("solicited pong not found in pings") + + # Remove acknowledged pings from self.pending_pings. + for ping_id in ping_ids: + del self.pending_pings[ping_id] + + def acknowledge_pending_pings(self) -> None: + """ + Acknowledge pending pings when the connection is closed. + + """ + assert self.protocol.state is CLOSED + + for pong_received, _ping_timestamp, ack_on_close in self.pending_pings.values(): + if ack_on_close: + pong_received.set() + + self.pending_pings.clear() + + async def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + try: + while True: + # If self.ping_timeout > self.latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. + with trio.move_on_after(self.ping_interval - self.latency): + await self.stream_closed.wait() + break + + try: + pong_received = await self.ping(ack_on_close=True) + except ConnectionClosed: + break + if self.debug: + self.logger.debug("% sent keepalive ping") + + if self.ping_timeout is not None: + with trio.move_on_after(self.ping_timeout) as cancel_scope: + await pong_received.wait() + if self.debug: + self.logger.debug("% received keepalive pong") + if cancel_scope.cancelled_caught: + if self.debug: + self.logger.debug("- timed out waiting for keepalive pong") + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a task, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + self.nursery.start_soon(self.keepalive) + + async def recv_events(self) -> None: + """ + Read incoming data from the stream and process events. + + Run this method in a task as long as the connection is alive. + + ``recv_events()`` exits immediately when ``self.stream`` is closed. + + """ + try: + while True: + try: + # If the assembler buffer is full, block until it drains. + async with self.recv_flow_control: + pass + data = await self.stream.receive_some() + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while receiving data", + exc_info=True, + ) + # When the closing handshake is initiated by our side, + # recv() may block until send_context() closes the stream. + # In that case, send_context() already set recv_exc. + # Calling set_recv_exc() avoids overwriting it. + self.set_recv_exc(exc) + break + + if data == b"": + break + + # Feed incoming data to the protocol. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the stream. + try: + await self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while sending data", + exc_info=True, + ) + # Similarly to the above, avoid overriding an exception + # set by send_context(), in case of a race condition + # i.e. send_context() closes the transport after recv() + # returns above but before send_data() calls send(). + self.set_recv_exc(exc) + break + + # If needed, set the close deadline based on the close timeout. + if self.protocol.close_expected(): + if self.close_deadline is None and self.close_timeout is not None: + self.close_deadline = trio.current_time() + self.close_timeout + + # If self.send_data raised an exception, then events are lost. + # Given that automatic responses write small amounts of data, + # this should be uncommon, so we don't handle the edge case. + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + # Breaking out of the while True: ... loop means that we believe + # that the stream doesn't work anymore. + + # Feed the end of the data stream to the protocol. + self.protocol.receive_eof() + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream here and it handles errors itself. + await self.send_data() + + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + + except Exception as exc: + # This branch should never run. It's a safety net in case of bugs. + self.logger.error("unexpected internal error", exc_info=True) + self.set_recv_exc(exc) + finally: + # This isn't expected to raise an exception. + await self.close_stream() + + @contextlib.asynccontextmanager + async def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> AsyncIterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` checks that the connection is open; on + exit, it writes outgoing data to the socket:: + + async with self.send_context(): + self.protocol.send_text(message.encode()) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the stream and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: BaseException | None = None + + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, ConcurrencyError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # Set the close deadline based on the close timeout. + # Since we tested earlier that protocol.state is OPEN + # (or CONNECTING), self.close_deadline is still None. + assert self.close_deadline is None + if self.close_timeout is not None: + self.close_deadline = trio.current_time() + self.close_timeout + # Write outgoing data to the socket with flow control. + try: + await self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("! error while sending data", exc_info=True) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_deadline is None and self.close_timeout is not None: + self.close_deadline = trio.current_time() + self.close_timeout + raise_close_exc = True + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + if self.close_deadline is not None: + with trio.move_on_at(self.close_deadline) as cancel_scope: + await self.stream_closed.wait() + if cancel_scope.cancelled_caught: + # There's no risk to overwrite another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_exc before closing the transport in order to get + # proper exception reporting. + raise_close_exc = True + self.set_recv_exc(original_exc) + else: + await self.stream_closed.wait() + + # If an error occurred, close the transport to terminate the connection and + # raise an exception. + if raise_close_exc: + await self.close_stream() + raise self.protocol.close_exc from original_exc + + async def send_data(self) -> None: + """ + Send outgoing data. + + """ + # Serialize calls to send_all(). + async with self.send_lock: + for data in self.protocol.data_to_send(): + if data: + await self.stream.send_all(data) + else: + # Half-close the TCP connection when possible i.e. no TLS. + if isinstance(self.stream, trio.abc.HalfCloseableStream): + if self.debug: + self.logger.debug("x half-closing TCP connection") + try: + await self.stream.send_eof() + except Exception: # pragma: no cover + pass + # Else, close the TCP connection. + else: # pragma: no cover + if self.debug: + self.logger.debug("x closing TCP connection") + await self.stream.aclose() + + def set_recv_exc(self, exc: BaseException | None) -> None: + """ + Set recv_exc, if not set yet. + + """ + if self.recv_exc is None: + self.recv_exc = exc + + async def close_stream(self) -> None: + """ + Shutdown and close stream. Close message assembler. + + Calling close_stream() guarantees that recv_events() terminates. Indeed, + recv_events() may block only on stream.recv() or on recv_messages.put(). + + """ + # Close the stream. + await self.stream.aclose() + + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + self.protocol.receive_eof() + assert self.protocol.state is CLOSED + + # Abort recv() with a ConnectionClosed exception. + self.recv_messages.close() + + # Acknowledge pings sent with the ack_on_close option. + self.acknowledge_pending_pings() + + # Unblock coroutines waiting on self.stream_closed. + self.stream_closed.set() diff --git a/tests/trio/connection.py b/tests/trio/connection.py new file mode 100644 index 000000000..226f74a3a --- /dev/null +++ b/tests/trio/connection.py @@ -0,0 +1,116 @@ +import contextlib + +import trio + +from websockets.trio.connection import Connection + + +class InterceptingConnection(Connection): + """ + Connection subclass that can intercept outgoing packets. + + By interfacing with this connection, we simulate network conditions + affecting what the component being tested receives during a test. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.stream = InterceptingStream(self.stream) + + @contextlib.contextmanager + def delay_frames_sent(self, delay): + """ + Add a delay before sending frames. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.stream.delay_send_all is None + self.stream.delay_send_all = delay + try: + yield + finally: + self.stream.delay_send_all = None + + @contextlib.contextmanager + def delay_eof_sent(self, delay): + """ + Add a delay before sending EOF. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.stream.delay_send_eof is None + self.stream.delay_send_eof = delay + try: + yield + finally: + self.stream.delay_send_eof = None + + @contextlib.contextmanager + def drop_frames_sent(self): + """ + Prevent frames from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.stream.drop_send_all + self.stream.drop_send_all = True + try: + yield + finally: + self.stream.drop_send_all = False + + @contextlib.contextmanager + def drop_eof_sent(self): + """ + Prevent EOF from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.stream.drop_send_eof + self.stream.drop_send_eof = True + try: + yield + finally: + self.stream.drop_send_eof = False + + +class InterceptingStream: + """ + Stream wrapper that intercepts calls to ``send_all()`` and ``send_eof()``. + + This is coupled to the implementation, which relies on these two methods. + + """ + + # We cannot delay EOF with trio's virtual streams because close_hook is + # synchronous. We adopt the same approach as the other implementations. + + def __init__(self, stream): + self.stream = stream + self.delay_send_all = None + self.delay_send_eof = None + self.drop_send_all = False + self.drop_send_eof = False + + def __getattr__(self, name): + return getattr(self.stream, name) + + async def send_all(self, data): + if self.delay_send_all is not None: + await trio.sleep(self.delay_send_all) + if not self.drop_send_all: + await self.stream.send_all(data) + + async def send_eof(self): + if self.delay_send_eof is not None: + await trio.sleep(self.delay_send_eof) + if not self.drop_send_eof: + await self.stream.send_eof() + + +trio.abc.HalfCloseableStream.register(InterceptingStream) diff --git a/tests/trio/test_connection.py b/tests/trio/test_connection.py new file mode 100644 index 000000000..98ce2a25a --- /dev/null +++ b/tests/trio/test_connection.py @@ -0,0 +1,1265 @@ +import contextlib +import itertools +import logging +import uuid +from unittest.mock import patch + +import trio.testing + +from websockets.asyncio.compatibility import TimeoutError, aiter, anext +from websockets.exceptions import ( + ConcurrencyError, + ConnectionClosedError, + ConnectionClosedOK, +) +from websockets.frames import CloseCode, Frame, Opcode +from websockets.protocol import CLIENT, SERVER, Protocol, State +from websockets.trio.connection import * + +from ..protocol import RecordingProtocol +from ..utils import MS, alist +from .connection import InterceptingConnection +from .utils import IsolatedTrioTestCase + + +# Connection implements symmetrical behavior between clients and servers. +# All tests run on the client side and the server side to validate this. + + +class ClientConnectionTests(IsolatedTrioTestCase): + LOCAL = CLIENT + REMOTE = SERVER + + async def asyncSetUp(self): + stream, remote_stream = trio.testing.memory_stream_pair() + protocol = Protocol(self.LOCAL) + remote_protocol = RecordingProtocol(self.REMOTE) + self.connection = Connection( + self.nursery, + stream, + protocol, + close_timeout=2 * MS, + ) + self.remote_connection = InterceptingConnection( + self.nursery, + remote_stream, + remote_protocol, + ) + + async def asyncTearDown(self): + await self.remote_connection.aclose() + await self.connection.aclose() + + # Test helpers built upon RecordingProtocol and InterceptingConnection. + + async def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) + + async def assertFramesSent(self, frames): + """Check that several frames were sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) + + async def assertNoFrameSent(self): + """Check that no frame was sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) + + @contextlib.asynccontextmanager + async def delay_frames_rcvd(self, delay): + """Delay frames before they're received by the connection.""" + with self.remote_connection.delay_frames_sent(delay): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def delay_eof_rcvd(self, delay): + """Delay EOF before it's received by the connection.""" + with self.remote_connection.delay_eof_sent(delay): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def drop_frames_rcvd(self): + """Drop frames before they're received by the connection.""" + with self.remote_connection.drop_frames_sent(): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def drop_eof_rcvd(self): + """Drop EOF before it's received by the connection.""" + with self.remote_connection.drop_eof_sent(): + yield + await trio.testing.wait_all_tasks_blocked() + + # Test __aenter__ and __aexit__. + + async def test_aenter(self): + """__aenter__ returns the connection itself.""" + async with self.connection as connection: + self.assertIs(connection, self.connection) + + async def test_aexit(self): + """__aexit__ closes the connection with code 1000.""" + async with self.connection: + await self.assertNoFrameSent() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_aexit_with_exception(self): + """__aexit__ with an exception closes the connection with code 1011.""" + with self.assertRaises(RuntimeError): + async with self.connection: + raise RuntimeError + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) + + # Test __aiter__. + + async def test_aiter_text(self): + """__aiter__ yields text messages.""" + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + + async def test_aiter_binary(self): + """__aiter__ yields binary messages.""" + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + + async def test_aiter_mixed(self): + """__aiter__ yields a mix of text and binary messages.""" + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + + async def test_aiter_connection_closed_ok(self): + """__aiter__ terminates after a normal closure.""" + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.aclose() + with self.assertRaises(StopAsyncIteration): + await anext(iterator) + + async def test_aiter_connection_closed_error(self): + """__aiter__ raises ConnectionClosedError after an error.""" + iterator = aiter(self.connection) + async with contextlib.aclosing(iterator): + await self.remote_connection.aclose(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await anext(iterator) + + # Test recv. + + async def test_recv_text(self): + """recv receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_binary(self): + """recv receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) + + async def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual(await self.connection.recv(decode=True), "😀") + + async def test_recv_fragmented_text(self): + """recv receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual(await self.connection.recv(), "😀😀") + + async def test_recv_fragmented_binary(self): + """recv receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_connection_closed_ok(self): + """recv raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.aclose() + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_recv_connection_closed_error(self): + """recv raises ConnectionClosedError after an error.""" + await self.remote_connection.aclose(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + + async def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + self.assertEqual(raised.exception.sent.code, CloseCode.INVALID_DATA) + + async def test_recv_during_recv(self): + """recv raises ConcurrencyError when called concurrently.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + try: + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.recv() + finally: + nursery.cancel_scope.cancel() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_during_recv_streaming(self): + """recv raises ConcurrencyError when called concurrently with recv_streaming.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(alist, self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + try: + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.recv() + finally: + nursery.cancel_scope.cancel() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_cancellation_before_receiving(self): + """recv can be canceled before receiving a message.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + # Running recv again receives the next message. + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_cancellation_while_receiving(self): + """recv can be canceled while receiving a fragmented message.""" + gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + self.nursery.start_soon(self.remote_connection.send, fragments()) + await trio.testing.wait_all_tasks_blocked() + + async with trio.open_nursery() as nursery: + nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + gate.set() + + # Running recv again receives the complete message. + self.assertEqual(await self.connection.recv(), "⏳⌛️") + + # Test recv_streaming. + + async def test_recv_streaming_text(self): + """recv_streaming receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀"], + ) + + async def test_recv_streaming_binary(self): + """recv_streaming receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02\xfe\xff"], + ) + + async def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + async def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual( + await alist(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + + async def test_recv_streaming_fragmented_text(self): + """recv_streaming receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_fragmented_binary(self): + """recv_streaming receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_recv_streaming_connection_closed_ok(self): + """recv_streaming raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.aclose() + with self.assertRaises(ConnectionClosedOK): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_connection_closed_error(self): + """recv_streaming raises ConnectionClosedError after an error.""" + await self.remote_connection.aclose(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError) as raised: + await alist(self.connection.recv_streaming()) + self.assertEqual(raised.exception.sent.code, CloseCode.INVALID_DATA) + + async def test_recv_streaming_during_recv(self): + """recv_streaming raises ConcurrencyError when called concurrently with recv.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + try: + with self.assertRaises(ConcurrencyError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + finally: + nursery.cancel_scope.cancel() + self.assertEqual( + str(raised.exception), + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming", + ) + + async def test_recv_streaming_during_recv_streaming(self): + """recv_streaming raises ConcurrencyError when called concurrently.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(alist, self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + try: + with self.assertRaises(ConcurrencyError) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + finally: + nursery.cancel_scope.cancel() + self.assertEqual( + str(raised.exception), + r"cannot call recv_streaming while another coroutine " + r"is already running recv or recv_streaming", + ) + + async def test_recv_streaming_cancellation_before_receiving(self): + """recv_streaming can be canceled before receiving a message.""" + async with trio.open_nursery() as nursery: + nursery.start_soon(alist, self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + # Running recv_streaming again receives the next message. + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_cancellation_while_receiving(self): + """recv_streaming cannot be canceled while receiving a fragmented message.""" + gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + self.nursery.start_soon(self.remote_connection.send, fragments()) + await trio.testing.wait_all_tasks_blocked() + + async with trio.open_nursery() as nursery: + nursery.start_soon(alist, self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + gate.set() + await trio.testing.wait_all_tasks_blocked() + + # Running recv_streaming again fails. + with self.assertRaises(ConcurrencyError): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + # Test send. + + async def test_send_text(self): + """send sends a text message.""" + await self.connection.send("😀") + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_binary(self): + """send sends a binary message.""" + await self.connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + + async def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + await self.connection.send("😀", text=False) + self.assertEqual(await self.remote_connection.recv(), "😀".encode()) + + async def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + await self.connection.send("😀".encode(), text=True) + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_fragmented_text(self): + """send sends a fragmented text message.""" + await self.connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_fragmented_binary(self): + """send sends a fragmented binary message.""" + await self.connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + await self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + await self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_text(self): + """send sends a fragmented text message asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_binary(self): + """send sends a fragmented binary message asynchronously.""" + + async def fragments(): + yield b"\x01\x02" + yield b"\xfe\xff" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_async_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments(), text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_async_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes asynchronously.""" + + async def fragments(): + yield "😀".encode() + yield "😀".encode() + + await self.connection.send(fragments(), text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_connection_closed_ok(self): + """send raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.aclose() + with self.assertRaises(ConnectionClosedOK): + await self.connection.send("😀") + + async def test_send_connection_closed_error(self): + """send raises ConnectionClosedError after an error.""" + await self.remote_connection.aclose(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.send("😀") + + async def test_send_during_send(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with send_in_progress is removed + # from send() in the case when message is an AsyncIterable. + gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + self.nursery.start_soon(self.connection.send, fragments()) + await trio.testing.wait_all_tasks_blocked() + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + self.nursery.start_soon(self.connection.send, "✅") + await trio.testing.wait_all_tasks_blocked() + await self.assertNoFrameSent() + + gate.set() + await trio.testing.wait_all_tasks_blocked() + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + # test_send_while_send_blocked and test_send_while_send_async_blocked aren't + # implemented because I don't know how to simulate backpressure on writes. + + async def test_send_empty_iterable(self): + """send does nothing when called with an empty iterable.""" + await self.connection.send([]) + await self.connection.aclose() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + with self.assertRaises(TypeError): + await self.connection.send(["😀", b"\xfe\xff"]) + + async def test_send_unsupported_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send([None]) + + async def test_send_empty_async_iterable(self): + """send does nothing when called with an empty async iterable.""" + + async def fragments(): + return + yield # pragma: no cover + + await self.connection.send(fragments()) + await self.connection.aclose() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_async_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + + async def fragments(): + yield "😀" + yield b"\xfe\xff" + + iterator = fragments() + async with contextlib.aclosing(iterator): + with self.assertRaises(TypeError): + await self.connection.send(iterator) + + async def test_send_unsupported_async_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + + async def fragments(): + yield None + + iterator = fragments() + async with contextlib.aclosing(iterator): + with self.assertRaises(TypeError): + await self.connection.send(iterator) + + async def test_send_dict(self): + """send raises TypeError when called with a dict.""" + with self.assertRaises(TypeError): + await self.connection.send({"type": "object"}) + + async def test_send_unsupported_type(self): + """send raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send(None) + + # Test aclose. + + async def test_aclose(self): + """aclose sends a close frame.""" + await self.connection.aclose() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_aclose_explicit_code_reason(self): + """aclose sends a close frame with a given code and reason.""" + await self.connection.aclose(CloseCode.GOING_AWAY, "bye!") + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) + + async def test_aclose_waits_for_close_frame(self): + """aclose waits for a close frame then EOF before returning.""" + t0 = trio.current_time() + async with self.delay_frames_rcvd(MS): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_aclose_waits_for_connection_closed(self): + """aclose waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + t0 = trio.current_time() + async with self.delay_eof_rcvd(MS): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_aclose_no_timeout_waits_for_close_frame(self): + """aclose without timeout waits for a close frame then EOF before returning.""" + self.connection.close_timeout = None + + t0 = trio.current_time() + async with self.delay_frames_rcvd(MS): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_aclose_no_timeout_waits_for_connection_closed(self): + """aclose without timeout waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + self.connection.close_timeout = None + + t0 = trio.current_time() + async with self.delay_eof_rcvd(MS): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_timeout_waiting_for_close_frame(self): + """aclose times out if no close frame is received.""" + t0 = trio.current_time() + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) + + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_close_timeout_waiting_for_connection_closed(self): + """aclose times out if EOF isn't received.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + t0 = trio.current_time() + async with self.drop_eof_rcvd(): + await self.connection.aclose() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_aclose_preserves_queued_messages(self): + """aclose preserves messages buffered in the assembler.""" + await self.remote_connection.send("😀") + await self.connection.aclose() + + self.assertEqual(await self.connection.recv(), "😀") + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_aclose_idempotency(self): + """aclose does nothing if the connection is already closed.""" + await self.connection.aclose() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + await self.connection.aclose() + await self.assertNoFrameSent() + + async def test_aclose_during_recv(self): + """aclose aborts recv when called concurrently with recv.""" + + async def closer(): + await trio.sleep(MS) + await self.connection.aclose() + + self.nursery.start_soon(closer) + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_aclose_during_recv_streaming(self): + """aclose aborts recv_streaming when called concurrently with recv_streaming.""" + + async def closer(): + await trio.sleep(MS) + await self.connection.aclose() + + self.nursery.start_soon(closer) + with self.assertRaises(ConnectionClosedOK) as raised: + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_aclose_during_send(self): + """aclose fails the connection when called concurrently with send.""" + close_gate = trio.Event() + exit_gate = trio.Event() + + async def closer(): + await close_gate.wait() + await self.connection.aclose() + exit_gate.set() + + async def fragments(): + yield "⏳" + close_gate.set() + await exit_gate.wait() + yield "⌛️" + + self.nursery.start_soon(closer) + iterator = fragments() + async with contextlib.aclosing(iterator): + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send(iterator) + + exc = raised.exception + self.assertEqual( + str(exc), + "sent 1011 (internal error) close during fragmented message; " + "no close frame received", + ) + self.assertIsNone(exc.__cause__) + + # Test wait_closed. + + async def test_wait_closed(self): + """wait_closed waits for the connection to close.""" + closed = trio.Event() + + async def closer(): + await self.connection.wait_closed() + closed.set() + + self.nursery.start_soon(closer) + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(closed.is_set()) + + await self.connection.aclose() + await trio.testing.wait_all_tasks_blocked() + self.assertTrue(closed.is_set()) + + # Test ping. + + @patch("random.getrandbits") + async def test_ping(self, getrandbits): + """ping sends a ping frame with a random payload.""" + getrandbits.side_effect = itertools.count(1918987876) + await self.connection.ping() + getrandbits.assert_called_once_with(32) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + async def test_ping_explicit_text(self): + """ping sends a ping frame with a payload provided as text.""" + await self.connection.ping("ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_ping_explicit_binary(self): + """ping sends a ping frame with a payload provided as binary.""" + await self.connection.ping(b"ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_acknowledge_ping(self): + """ping is acknowledged by a pong with the same payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.remote_connection.pong("this") + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_ping_non_matching_pong(self): + """ping isn't acknowledged by a pong with a different payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.remote_connection.pong("that") + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_previous_ping(self): + """ping is acknowledged by a pong for a later ping.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.connection.ping("that") + await self.remote_connection.pong("that") + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_ping_on_close(self): + """ping with ack_on_close is acknowledged when the connection is closed.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received_aoc = await self.connection.ping("this", ack_on_close=True) + pong_received = await self.connection.ping("that") + await self.connection.aclose() + with trio.fail_after(MS): + await pong_received_aoc.wait() + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(MS): + await pong_received.wait() + + async def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("idem") + + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + await self.remote_connection.pong("idem") + with trio.fail_after(MS): + await pong_received.wait() + + await self.connection.ping("idem") # doesn't raise an exception + + async def test_ping_unsupported_type(self): + """ping raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.ping([]) + + # Test pong. + + async def test_pong(self): + """pong sends a pong frame.""" + await self.connection.pong() + await self.assertFrameSent(Frame(Opcode.PONG, b"")) + + async def test_pong_explicit_text(self): + """pong sends a pong frame with a payload provided as text.""" + await self.connection.pong("pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_explicit_binary(self): + """pong sends a pong frame with a payload provided as binary.""" + await self.connection.pong(b"pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_unsupported_type(self): + """pong raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.pong([]) + + # Test keepalive. + + def keepalive_task_is_running(self): + return any( + task.name == "websockets.trio.connection.Connection.keepalive" + for task in self.nursery.child_tasks + ) + + @patch("random.getrandbits") + async def test_keepalive(self, getrandbits): + """keepalive sends pings at ping_interval and measures latency.""" + getrandbits.side_effect = itertools.count(1918987876) + self.connection.ping_interval = 3 * MS + self.connection.start_keepalive() + self.assertTrue(self.keepalive_task_is_running()) + self.assertEqual(self.connection.latency, 0) + # 3 ms: keepalive() sends a ping frame. + # 3.x ms: a pong frame is received. + await trio.sleep(4 * MS) + # 4 ms: check that the ping frame was sent. + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + self.assertGreater(self.connection.latency, 0) + self.assertLess(self.connection.latency, MS) + + async def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + self.assertFalse(self.keepalive_task_is_running()) + + @patch("random.getrandbits") + async def test_keepalive_times_out(self, getrandbits): + """keepalive closes the connection if ping_timeout elapses.""" + getrandbits.side_effect = itertools.count(1918987876) + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = 2 * MS + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + # 4.x ms: a pong frame is dropped. + await trio.sleep(5 * MS) + # 6 ms: no pong frame is received; the connection is closed. + await trio.sleep(3 * MS) + # 8 ms: check that the connection is closed. + self.assertEqual(self.connection.state, State.CLOSED) + + @patch("random.getrandbits") + async def test_keepalive_ignores_timeout(self, getrandbits): + """keepalive ignores timeouts if ping_timeout isn't set.""" + getrandbits.side_effect = itertools.count(1918987876) + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = None + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + # 4.x ms: a pong frame is dropped. + await trio.sleep(5 * MS) + # 6 ms: no pong frame is received; the connection remains open. + await trio.sleep(3 * MS) + # 8 ms: check that the connection is still open. + self.assertEqual(self.connection.state, State.OPEN) + + async def test_keepalive_terminates_while_sleeping(self): + """keepalive task terminates while waiting to send a ping.""" + self.connection.ping_interval = 3 * MS + self.connection.start_keepalive() + await trio.testing.wait_all_tasks_blocked() + self.assertTrue(self.keepalive_task_is_running()) + await self.connection.aclose() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_terminates_when_sending_ping_fails(self): + """keepalive task terminates when sending a ping fails.""" + self.connection.ping_interval = MS + self.connection.start_keepalive() + self.assertTrue(self.keepalive_task_is_running()) + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.aclose() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_terminates_while_waiting_for_pong(self): + """keepalive task terminates while waiting to receive a pong.""" + self.connection.ping_interval = MS + self.connection.ping_timeout = 4 * MS + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 1 ms: keepalive() sends a ping frame. + # 1.x ms: a pong frame is dropped. + await trio.sleep(2 * MS) + # 2 ms: close the connection before ping_timeout elapses. + await self.connection.aclose() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_reports_errors(self): + """keepalive reports unexpected errors in logs.""" + self.connection.ping_interval = 2 * MS + self.connection.start_keepalive() + # Inject a fault when waiting to receive a pong. + with self.assertLogs("websockets", logging.ERROR) as logs: + with patch("trio.Event.wait", side_effect=Exception("BOOM")): + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is dropped. + await trio.sleep(3 * MS) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["keepalive ping failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + # Test parameters. + + async def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + stream, remote_stream = trio.testing.memory_stream_pair() + async with contextlib.aclosing(remote_stream): + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) + self.assertEqual(connection.close_timeout, 42 * MS) + + async def test_max_queue(self): + """max_queue configures high-water mark of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + async with contextlib.aclosing(remote_stream): + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=4, + ) + self.assertEqual(connection.recv_messages.high, 4) + + async def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + async with contextlib.aclosing(remote_stream): + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=None, + ) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.low, None) + + async def test_max_queue_tuple(self): + """max_queue configures high-water and low-water marks of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + async with contextlib.aclosing(remote_stream): + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=(4, 2), + ) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + + # Test attributes. + + async def test_id(self): + """Connection has an id attribute.""" + self.assertIsInstance(self.connection.id, uuid.UUID) + + async def test_logger(self): + """Connection has a logger attribute.""" + self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) + + @contextlib.asynccontextmanager + async def get_server_and_client_streams(self): + listeners = await trio.open_tcp_listeners(0, host="127.0.0.1") + assert len(listeners) == 1 + listener = listeners[0] + client_stream = await trio.testing.open_stream_to_socket_listener(listener) + client_port = client_stream.socket.getsockname()[1] + server_stream = await listener.accept() + server_port = listener.socket.getsockname()[1] + try: + yield client_stream, server_stream, client_port, server_port + finally: + await server_stream.aclose() + await client_stream.aclose() + await listener.aclose() + + async def test_local_address(self): + """Connection has a local_address attribute.""" + async with self.get_server_and_client_streams() as ( + client_stream, + server_stream, + client_port, + server_port, + ): + stream = {CLIENT: client_stream, SERVER: server_stream}[self.LOCAL] + port = {CLIENT: client_port, SERVER: server_port}[self.LOCAL] + connection = Connection(self.nursery, stream, Protocol(self.LOCAL)) + self.assertEqual(connection.local_address, ("127.0.0.1", port)) + + async def test_remote_address(self): + """Connection has a remote_address attribute.""" + async with self.get_server_and_client_streams() as ( + client_stream, + server_stream, + client_port, + server_port, + ): + stream = {CLIENT: client_stream, SERVER: server_stream}[self.LOCAL] + remote_port = {CLIENT: server_port, SERVER: client_port}[self.LOCAL] + connection = Connection(self.nursery, stream, Protocol(self.LOCAL)) + self.assertEqual(connection.remote_address, ("127.0.0.1", remote_port)) + + async def test_state(self): + """Connection has a state attribute.""" + self.assertIs(self.connection.state, State.OPEN) + + async def test_request(self): + """Connection has a request attribute.""" + self.assertIsNone(self.connection.request) + + async def test_response(self): + """Connection has a response attribute.""" + self.assertIsNone(self.connection.response) + + async def test_subprotocol(self): + """Connection has a subprotocol attribute.""" + self.assertIsNone(self.connection.subprotocol) + + async def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + async def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + + # Test reporting of network errors. + + async def test_writing_in_recv_events_fails(self): + """Error when responding to incoming frames is correctly reported.""" + # Inject a fault by shutting down the stream for writing — but not the + # stream for reading because that would terminate the connection. + self.connection.stream.send_stream.close() + # Receive a ping. Responding with a pong will fail. + await self.remote_connection.ping() + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, trio.ClosedResourceError) + + async def test_writing_in_send_context_fails(self): + """Error when sending outgoing frame is correctly reported.""" + # Inject a fault by shutting down the stream for writing — but not the + # stream for reading because that would terminate the connection. + self.connection.stream.send_stream.close() + # Sending a pong will fail. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.pong() + self.assertIsInstance(raised.exception.__cause__, trio.ClosedResourceError) + + # Test safety nets — catching all exceptions in case of bugs. + + @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) + async def test_unexpected_failure_in_recv_events(self, events_received): + """Unexpected internal error in recv_events() is correctly reported.""" + await self.remote_connection.send("😀") + # Reading the message will trigger the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, AssertionError) + + @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) + async def test_unexpected_failure_in_send_context(self, send_text): + """Unexpected internal error in send_context() is correctly reported.""" + # Sending a message will trigger the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send("😀") + self.assertIsInstance(raised.exception.__cause__, AssertionError) + + +class ServerConnectionTests(ClientConnectionTests): + LOCAL = SERVER + REMOTE = CLIENT