diff --git a/ld_eventsource/config/connect_strategy.py b/ld_eventsource/config/connect_strategy.py index 1b59f67..4770831 100644 --- a/ld_eventsource/config/connect_strategy.py +++ b/ld_eventsource/config/connect_strategy.py @@ -1,11 +1,13 @@ from __future__ import annotations +from dataclasses import dataclass from logging import Logger from typing import Callable, Iterator, Optional, Union from urllib3 import PoolManager -from ld_eventsource.http import _HttpClientImpl, _HttpConnectParams +from ld_eventsource.http import (DynamicQueryParams, _HttpClientImpl, + _HttpConnectParams) class ConnectStrategy: @@ -38,6 +40,7 @@ def http( headers: Optional[dict] = None, pool: Optional[PoolManager] = None, urllib3_request_options: Optional[dict] = None, + query_params: Optional[DynamicQueryParams] = None ) -> ConnectStrategy: """ Creates the default HTTP implementation, specifying request parameters. @@ -47,9 +50,11 @@ def http( :param pool: optional urllib3 ``PoolManager`` to provide an HTTP client :param urllib3_request_options: optional ``kwargs`` to add to the ``request`` call; these can include any parameters supported by ``urllib3``, such as ``timeout`` + :param query_params: optional callable that can be used to affect query parameters + dynamically for each connection attempt """ return _HttpConnectStrategy( - _HttpConnectParams(url, headers, pool, urllib3_request_options) + _HttpConnectParams(url, headers, pool, urllib3_request_options, query_params) ) diff --git a/ld_eventsource/http.py b/ld_eventsource/http.py index 8d1096b..c97ed6d 100644 --- a/ld_eventsource/http.py +++ b/ld_eventsource/http.py @@ -1,5 +1,6 @@ from logging import Logger from typing import Callable, Iterator, Optional, Tuple +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit from urllib3 import PoolManager from urllib3.exceptions import MaxRetryError @@ -9,6 +10,12 @@ _CHUNK_SIZE = 10000 +DynamicQueryParams = Callable[[], dict[str, str]] +""" +A callable that returns a dictionary of query parameters to add to the URL. +This can be used to modify query parameters dynamically for each connection attempt. +""" + class _HttpConnectParams: def __init__( @@ -17,16 +24,22 @@ def __init__( headers: Optional[dict] = None, pool: Optional[PoolManager] = None, urllib3_request_options: Optional[dict] = None, + query_params: Optional[DynamicQueryParams] = None ): self.__url = url self.__headers = headers self.__pool = pool self.__urllib3_request_options = urllib3_request_options + self.__query_params = query_params @property def url(self) -> str: return self.__url + @property + def query_params(self) -> Optional[DynamicQueryParams]: + return self.__query_params + @property def headers(self) -> Optional[dict]: return self.__headers @@ -48,7 +61,16 @@ def __init__(self, params: _HttpConnectParams, logger: Logger): self.__logger = logger def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callable]: - self.__logger.info("Connecting to stream at %s" % self.__params.url) + url = self.__params.url + if self.__params.query_params is not None: + qp = self.__params.query_params() + if qp: + url_parts = list(urlsplit(url)) + query = dict(parse_qsl(url_parts[3])) + query.update(qp) + url_parts[3] = urlencode(query) + url = urlunsplit(url_parts) + self.__logger.info("Connecting to stream at %s" % url) headers = self.__params.headers.copy() if self.__params.headers else {} headers['Cache-Control'] = 'no-cache' @@ -67,7 +89,7 @@ def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callab try: resp = self.__pool.request( 'GET', - self.__params.url, + url, preload_content=False, retries=Retry( total=None, read=0, connect=0, status=0, other=0, redirect=3 diff --git a/ld_eventsource/testing/http_util.py b/ld_eventsource/testing/http_util.py index 80e61dd..03b348a 100644 --- a/ld_eventsource/testing/http_util.py +++ b/ld_eventsource/testing/http_util.py @@ -113,7 +113,7 @@ def do_POST(self): def _do_request(self): server_wrapper = self.server.server_wrapper server_wrapper.requests.put(MockServerRequest(self)) - handler = server_wrapper.matchers.get(self.path) + handler = server_wrapper.matchers.get(self.path.split("?")[0], None) if handler: handler.write(self) else: diff --git a/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py b/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py index 5b6bbdf..2502fe7 100644 --- a/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py +++ b/ld_eventsource/testing/test_http_connect_strategy_with_sse_client.py @@ -1,3 +1,5 @@ +from urllib.parse import parse_qsl + from ld_eventsource import * from ld_eventsource.config import * from ld_eventsource.testing.helpers import * @@ -56,6 +58,48 @@ def test_sse_client_reconnects_after_socket_closed(): assert event2.data == 'data2' +def test_sse_client_allows_modifying_query_params_dynamically(): + count = 0 + + def dynamic_query_params() -> dict[str, str]: + nonlocal count + count += 1 + params = {'count': str(count)} + if count > 1: + params['option'] = 'updated' + + return params + + with start_server() as server: + with make_stream() as stream1: + with make_stream() as stream2: + server.for_path('/', SequentialHandler(stream1, stream2)) + stream1.push("event: a\ndata: data1\nid: id123\n\n") + stream2.push("event: b\ndata: data2\n\n") + with SSEClient( + connect=ConnectStrategy.http(f"{server.uri}?basis=unchanging&option=initial", query_params=dynamic_query_params), + error_strategy=ErrorStrategy.always_continue(), + initial_retry_delay=0, + ) as client: + client.start() + next(client.events) + stream1.close() + next(client.events) + r1 = server.await_request() + r1_query_params = dict(parse_qsl(r1.path.split('?', 1)[1])) + + # Ensure we can add, retain, and modify query parameters + assert r1_query_params.get('count') == '1' + assert r1_query_params.get('basis') == 'unchanging' + assert r1_query_params.get('option') == 'initial' + + r2 = server.await_request() + r2_query_params = dict(parse_qsl(r2.path.split('?', 1)[1])) + assert r2_query_params.get('count') == '2' + assert r2_query_params.get('basis') == 'unchanging' + assert r2_query_params.get('option') == 'updated' + + def test_sse_client_sends_last_event_id_on_reconnect(): with start_server() as server: with make_stream() as stream1: