From 554499999ea40205ee3342351083d551b2eb6d35 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Wed, 29 Oct 2025 12:13:31 -0400 Subject: [PATCH] feat: Support dynamic query parameters on reconnect The `ConnectStrategy` can be created with a `query_params` callable. This callable will return a set of parameters that should be used to update any static query parameters initially configured. This functionality enables FDv2 selector behavior where we want to resume from our last known checkpoint. --- ld_eventsource/config/connect_strategy.py | 9 +++- ld_eventsource/http.py | 26 ++++++++++- ld_eventsource/testing/http_util.py | 2 +- ...t_http_connect_strategy_with_sse_client.py | 44 +++++++++++++++++++ 4 files changed, 76 insertions(+), 5 deletions(-) diff --git a/ld_eventsource/config/connect_strategy.py b/ld_eventsource/config/connect_strategy.py index 1b59f67..4770831 100644 --- a/ld_eventsource/config/connect_strategy.py +++ b/ld_eventsource/config/connect_strategy.py @@ -1,11 +1,13 @@ from __future__ import annotations +from dataclasses import dataclass from logging import Logger from typing import Callable, Iterator, Optional, Union from urllib3 import PoolManager -from ld_eventsource.http import _HttpClientImpl, _HttpConnectParams +from ld_eventsource.http import (DynamicQueryParams, _HttpClientImpl, + _HttpConnectParams) class ConnectStrategy: @@ -38,6 +40,7 @@ def http( headers: Optional[dict] = None, pool: Optional[PoolManager] = None, urllib3_request_options: Optional[dict] = None, + query_params: Optional[DynamicQueryParams] = None ) -> ConnectStrategy: """ Creates the default HTTP implementation, specifying request parameters. @@ -47,9 +50,11 @@ def http( :param pool: optional urllib3 ``PoolManager`` to provide an HTTP client :param urllib3_request_options: optional ``kwargs`` to add to the ``request`` call; these can include any parameters supported by ``urllib3``, such as ``timeout`` + :param query_params: optional callable that can be used to affect query parameters + dynamically for each connection attempt """ return _HttpConnectStrategy( - _HttpConnectParams(url, headers, pool, urllib3_request_options) + _HttpConnectParams(url, headers, pool, urllib3_request_options, query_params) ) diff --git a/ld_eventsource/http.py b/ld_eventsource/http.py index 8d1096b..c97ed6d 100644 --- a/ld_eventsource/http.py +++ b/ld_eventsource/http.py @@ -1,5 +1,6 @@ from logging import Logger from typing import Callable, Iterator, Optional, Tuple +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit from urllib3 import PoolManager from urllib3.exceptions import MaxRetryError @@ -9,6 +10,12 @@ _CHUNK_SIZE = 10000 +DynamicQueryParams = Callable[[], dict[str, str]] +""" +A callable that returns a dictionary of query parameters to add to the URL. +This can be used to modify query parameters dynamically for each connection attempt. +""" + class _HttpConnectParams: def __init__( @@ -17,16 +24,22 @@ def __init__( headers: Optional[dict] = None, pool: Optional[PoolManager] = None, urllib3_request_options: Optional[dict] = None, + query_params: Optional[DynamicQueryParams] = None ): self.__url = url self.__headers = headers self.__pool = pool self.__urllib3_request_options = urllib3_request_options + self.__query_params = query_params @property def url(self) -> str: return self.__url + @property + def query_params(self) -> Optional[DynamicQueryParams]: + return self.__query_params + @property def headers(self) -> Optional[dict]: return self.__headers @@ -48,7 +61,16 @@ def __init__(self, params: _HttpConnectParams, logger: Logger): self.__logger = logger def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callable]: - self.__logger.info("Connecting to stream at %s" % self.__params.url) + url = self.__params.url + if self.__params.query_params is not None: + qp = self.__params.query_params() + if qp: + url_parts = list(urlsplit(url)) + query = dict(parse_qsl(url_parts[3])) + query.update(qp) + url_parts[3] = urlencode(query) + url = urlunsplit(url_parts) + self.__logger.info("Connecting to stream at %s" % url) headers = self.__params.headers.copy() if self.__params.headers else {} headers['Cache-Control'] = 'no-cache' @@ -67,7 +89,7 @@ def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callab try: resp = self.__pool.request( 'GET', - self.__params.url, + url, preload_content=False, retries=Retry( total=None, read=0, connect=0, status=0, other=0, redirect=3 diff --git a/ld_eventsource/testing/http_util.py b/ld_eventsource/testing/http_util.py index 80e61dd..03b348a 100644 --- a/ld_eventsource/testing/http_util.py +++ b/ld_eventsource/testing/http_util.py @@ -113,7 +113,7 @@ def do_POST(self): def _do_request(self): server_wrapper = self.server.server_wrapper server_wrapper.requests.put(MockServerRequest(self)) - handler = server_wrapper.matchers.get(self.path) + handler = server_wrapper.matchers.get(self.path.split("?")[0], None) if handler: handler.write(self) else: diff --git a/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py b/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py index 5b6bbdf..2502fe7 100644 --- a/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py +++ b/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py @@ -1,3 +1,5 @@ +from urllib.parse import parse_qsl + from ld_eventsource import * from ld_eventsource.config import * from ld_eventsource.testing.helpers import * @@ -56,6 +58,48 @@ def test_sse_client_reconnects_after_socket_closed(): assert event2.data == 'data2' +def test_sse_client_allows_modifying_query_params_dynamically(): + count = 0 + + def dynamic_query_params() -> dict[str, str]: + nonlocal count + count += 1 + params = {'count': str(count)} + if count > 1: + params['option'] = 'updated' + + return params + + with start_server() as server: + with make_stream() as stream1: + with make_stream() as stream2: + server.for_path('/', SequentialHandler(stream1, stream2)) + stream1.push("event: a\ndata: data1\nid: id123\n\n") + stream2.push("event: b\ndata: data2\n\n") + with SSEClient( + connect=ConnectStrategy.http(f"{server.uri}?basis=unchanging&option=initial", query_params=dynamic_query_params), + error_strategy=ErrorStrategy.always_continue(), + initial_retry_delay=0, + ) as client: + client.start() + next(client.events) + stream1.close() + next(client.events) + r1 = server.await_request() + r1_query_params = dict(parse_qsl(r1.path.split('?', 1)[1])) + + # Ensure we can add, retain, and modify query parameters + assert r1_query_params.get('count') == '1' + assert r1_query_params.get('basis') == 'unchanging' + assert r1_query_params.get('option') == 'initial' + + r2 = server.await_request() + r2_query_params = dict(parse_qsl(r2.path.split('?', 1)[1])) + assert r2_query_params.get('count') == '2' + assert r2_query_params.get('basis') == 'unchanging' + assert r2_query_params.get('option') == 'updated' + + def test_sse_client_sends_last_event_id_on_reconnect(): with start_server() as server: with make_stream() as stream1: