Skip to content

Commit 6148040

Browse files
committed
Fix: AlephClient class could not use unix sockets
1 parent cd49ef8 commit 6148040

File tree

2 files changed

+67
-8
lines changed

2 files changed

+67
-8
lines changed

src/aleph/sdk/client.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
MultipleMessagesError,
5555
)
5656
from .models import MessagesResponse
57-
from .utils import get_message_type_value
57+
from .utils import check_unix_socket_valid, get_message_type_value
5858

5959
logger = logging.getLogger(__name__)
6060

@@ -94,14 +94,14 @@ def func_caller(*args, **kwargs):
9494

9595

9696
async def run_async_watcher(
97-
*args, output_queue: queue.Queue, api_server: str, **kwargs
97+
*args, output_queue: queue.Queue, api_server: Optional[str], **kwargs
9898
):
9999
async with AlephClient(api_server=api_server) as session:
100100
async for message in session.watch_messages(*args, **kwargs):
101101
output_queue.put(message)
102102

103103

104-
def watcher_thread(output_queue: queue.Queue, api_server: str, args, kwargs):
104+
def watcher_thread(output_queue: queue.Queue, api_server: Optional[str], args, kwargs):
105105
asyncio.run(
106106
run_async_watcher(
107107
output_queue=output_queue, api_server=api_server, *args, **kwargs
@@ -443,9 +443,39 @@ class AlephClient:
443443
api_server: str
444444
http_session: aiohttp.ClientSession
445445

446-
def __init__(self, api_server: str):
447-
self.api_server = api_server
448-
self.http_session = aiohttp.ClientSession(base_url=api_server)
446+
def __init__(
447+
self,
448+
api_server: Optional[str],
449+
api_unix_socket: Optional[str] = None,
450+
allow_unix_sockets: bool = True,
451+
timeout: Optional[aiohttp.ClientTimeout] = None,
452+
):
453+
"""AlephClient can use HTTP(S) or HTTP over Unix sockets.
454+
Unix sockets are used when running inside a virtual machine,
455+
and can be shared across containers in a more secure way than TCP ports.
456+
"""
457+
self.api_server = api_server or settings.API_HOST
458+
if not self.api_server:
459+
raise ValueError("Missing API host")
460+
461+
unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET
462+
if unix_socket_path and allow_unix_sockets:
463+
check_unix_socket_valid(unix_socket_path)
464+
connector = aiohttp.UnixConnector(path=unix_socket_path)
465+
else:
466+
connector = None
467+
468+
# ClientSession timeout defaults to a private sentinel object and may not be None.
469+
self.http_session = (
470+
aiohttp.ClientSession(
471+
base_url=self.api_server, connector=connector, timeout=timeout
472+
)
473+
if timeout
474+
else aiohttp.ClientSession(
475+
base_url=self.api_server,
476+
connector=connector,
477+
)
478+
)
449479

450480
def __enter__(self) -> UserSessionSync:
451481
return UserSessionSync(async_session=self)
@@ -825,8 +855,20 @@ class AuthenticatedAlephClient(AlephClient):
825855
"channel",
826856
}
827857

828-
def __init__(self, account: Account, api_server: str):
829-
super().__init__(api_server=api_server)
858+
def __init__(
859+
self,
860+
account: Account,
861+
api_server: Optional[str],
862+
api_unix_socket: Optional[str] = None,
863+
allow_unix_sockets: bool = True,
864+
timeout: Optional[aiohttp.ClientTimeout] = None,
865+
):
866+
super().__init__(
867+
api_server=api_server,
868+
api_unix_socket=api_unix_socket,
869+
allow_unix_sockets=allow_unix_sockets,
870+
timeout=timeout,
871+
)
830872
self.account = account
831873

832874
def __enter__(self) -> "AuthenticatedUserSessionSync":

src/aleph/sdk/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import errno
12
import logging
23
import os
34
from pathlib import Path
@@ -59,3 +60,19 @@ def get_message_type_value(message_type: Type[GenericMessage]) -> MessageType:
5960
"""Returns the value of the 'type' field of a message type class."""
6061
type_literal = message_type.__annotations__["type"]
6162
return type_literal.__args__[0] # Get the value from a Literal
63+
64+
65+
def check_unix_socket_valid(unix_socket_path: str) -> bool:
66+
"""Check that a unix socket exists at the given path, or raise a FileNotFoundError."""
67+
path = Path(unix_socket_path)
68+
if not path.exists():
69+
raise FileNotFoundError(
70+
errno.ENOENT, os.strerror(errno.ENOENT), unix_socket_path
71+
)
72+
if not path.is_socket():
73+
raise FileNotFoundError(
74+
errno.ENOTSOCK,
75+
os.strerror(errno.ENOENT),
76+
unix_socket_path,
77+
)
78+
return True

0 commit comments

Comments
 (0)