Skip to content
6 changes: 5 additions & 1 deletion durabletask/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False,
interceptors: Optional[Sequence[ClientInterceptor]] = None,
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
):
if interceptors is not None:
interceptors = list(interceptors)
Expand All @@ -46,7 +47,10 @@ def __init__(
interceptors = None

channel = get_grpc_aio_channel(
host_address=host_address, secure_channel=secure_channel, interceptors=interceptors
host_address=host_address,
secure_channel=secure_channel,
interceptors=interceptors,
options=channel_options,
)
self._channel = channel
self._stub = stubs.TaskHubSidecarServiceStub(channel)
Expand Down
16 changes: 14 additions & 2 deletions durabletask/aio/internal/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import grpc
from grpc import aio as grpc_aio
from grpc.aio import ChannelArgumentType

from durabletask.internal.shared import (
INSECURE_PROTOCOLS,
Expand All @@ -24,7 +25,16 @@ def get_grpc_aio_channel(
host_address: Optional[str],
secure_channel: bool = False,
interceptors: Optional[Sequence[ClientInterceptor]] = None,
options: Optional[ChannelArgumentType] = None,
) -> grpc_aio.Channel:
"""create a grpc asyncio channel

Args:
host_address: The host address of the gRPC server. If None, uses the default address.
secure_channel: Whether to use a secure channel (TLS/SSL). Defaults to False.
interceptors: Optional sequence of client interceptors to apply to the channel.
options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
"""
if host_address is None:
host_address = get_default_host_address()

Expand All @@ -42,9 +52,11 @@ def get_grpc_aio_channel(

if secure_channel:
channel = grpc_aio.secure_channel(
host_address, grpc.ssl_channel_credentials(), interceptors=interceptors
host_address, grpc.ssl_channel_credentials(), interceptors=interceptors, options=options
)
else:
channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors)
channel = grpc_aio.insecure_channel(
host_address, interceptors=interceptors, options=options
)

return channel
6 changes: 5 additions & 1 deletion durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False,
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
):
# If the caller provided metadata, we need to create a new interceptor for it and
# add it to the list of interceptors.
Expand All @@ -121,7 +122,10 @@ def __init__(
interceptors = None

channel = shared.get_grpc_channel(
host_address=host_address, secure_channel=secure_channel, interceptors=interceptors
host_address=host_address,
secure_channel=secure_channel,
interceptors=interceptors,
options=channel_options,
)
self._stub = stubs.TaskHubSidecarServiceStub(channel)
self._logger = shared.get_logger("client", log_handler, log_formatter)
Expand Down
18 changes: 13 additions & 5 deletions durabletask/internal/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def get_default_host_address() -> str:
Honors environment variables if present; otherwise defaults to localhost:4001.

Supported environment variables (checked in order):
- DURABLETASK_GRPC_ENDPOINT (e.g., "localhost:4001", "grpcs://host:443")
- DURABLETASK_GRPC_HOST and DURABLETASK_GRPC_PORT
- DAPR_GRPC_ENDPOINT (e.g., "localhost:4001", "grpcs://host:443")
- DAPR_GRPC_HOST/DAPR_RUNTIME_HOST and DAPR_GRPC_PORT
"""

# Full endpoint overrides
Expand All @@ -54,7 +54,16 @@ def get_grpc_channel(
host_address: Optional[str],
secure_channel: bool = False,
interceptors: Optional[Sequence[ClientInterceptor]] = None,
options: Optional[Sequence[tuple[str, Any]]] = None,
) -> grpc.Channel:
"""create a grpc channel

Args:
host_address: The host address of the gRPC server. If None, uses the default address (as defined in get_default_host_address above).
secure_channel: Whether to use a secure channel (TLS/SSL). Defaults to False.
interceptors: Optional sequence of client interceptors to apply to the channel.
options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
"""
if host_address is None:
host_address = get_default_host_address()

Expand All @@ -72,11 +81,10 @@ def get_grpc_channel(
host_address = host_address[len(protocol) :]
break

# Create the base channel
if secure_channel:
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=options)
else:
channel = grpc.insecure_channel(host_address)
channel = grpc.insecure_channel(host_address, options=options)

# Apply interceptors ONLY if they exist
if interceptors:
Expand Down
7 changes: 6 additions & 1 deletion durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,15 @@ def __init__(
secure_channel: bool = False,
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
concurrency_options: Optional[ConcurrencyOptions] = None,
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
):
self._registry = _Registry()
self._host_address = host_address if host_address else shared.get_default_host_address()
self._logger = shared.get_logger("worker", log_handler, log_formatter)
self._shutdown = Event()
self._is_running = False
self._secure_channel = secure_channel
self._channel_options = channel_options

# Use provided concurrency options or create default ones
self._concurrency_options = (
Expand Down Expand Up @@ -306,7 +308,10 @@ def create_fresh_connection():
current_stub = None
try:
current_channel = shared.get_grpc_channel(
self._host_address, self._secure_channel, self._interceptors
self._host_address,
self._secure_channel,
self._interceptors,
options=self._channel_options,
)
current_stub = stubs.TaskHubSidecarServiceStub(current_channel)
current_stub.Hello(empty_pb2.Empty())
Expand Down
79 changes: 64 additions & 15 deletions tests/durabletask/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import ANY, patch
from unittest.mock import patch

from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
from durabletask.internal.shared import get_default_host_address, get_grpc_channel
Expand All @@ -11,7 +11,9 @@
def test_get_grpc_channel_insecure():
with patch("grpc.insecure_channel") as mock_channel:
get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS)
mock_channel.assert_called_once_with(HOST_ADDRESS)
args, kwargs = mock_channel.call_args
assert args[0] == HOST_ADDRESS
assert "options" in kwargs and kwargs["options"] is None


def test_get_grpc_channel_secure():
Expand All @@ -20,13 +22,18 @@ def test_get_grpc_channel_secure():
patch("grpc.ssl_channel_credentials") as mock_credentials,
):
get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS)
mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value)
args, kwargs = mock_channel.call_args
assert args[0] == HOST_ADDRESS
assert args[1] == mock_credentials.return_value
assert "options" in kwargs and kwargs["options"] is None


def test_get_grpc_channel_default_host_address():
with patch("grpc.insecure_channel") as mock_channel:
get_grpc_channel(None, False, interceptors=INTERCEPTORS)
mock_channel.assert_called_once_with(get_default_host_address())
args, kwargs = mock_channel.call_args
assert args[0] == get_default_host_address()
assert "options" in kwargs and kwargs["options"] is None


def test_get_grpc_channel_with_metadata():
Expand All @@ -35,7 +42,9 @@ def test_get_grpc_channel_with_metadata():
patch("grpc.intercept_channel") as mock_intercept_channel,
):
get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS)
mock_channel.assert_called_once_with(HOST_ADDRESS)
args, kwargs = mock_channel.call_args
assert args[0] == HOST_ADDRESS
assert "options" in kwargs and kwargs["options"] is None
mock_intercept_channel.assert_called_once()

# Capture and check the arguments passed to intercept_channel()
Expand All @@ -54,40 +63,80 @@ def test_grpc_channel_with_host_name_protocol_stripping():

prefix = "grpc://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_insecure_channel.assert_called_with(host_name)
args, kwargs = mock_insecure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "http://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_insecure_channel.assert_called_with(host_name)
args, kwargs = mock_insecure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "HTTP://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_insecure_channel.assert_called_with(host_name)
args, kwargs = mock_insecure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "GRPC://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_insecure_channel.assert_called_with(host_name)
args, kwargs = mock_insecure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = ""
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_insecure_channel.assert_called_with(host_name)
args, kwargs = mock_insecure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "grpcs://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_secure_channel.assert_called_with(host_name, ANY)
args, kwargs = mock_secure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "https://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_secure_channel.assert_called_with(host_name, ANY)
args, kwargs = mock_secure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "HTTPS://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_secure_channel.assert_called_with(host_name, ANY)
args, kwargs = mock_secure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = "GRPCS://"
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
mock_secure_channel.assert_called_with(host_name, ANY)
args, kwargs = mock_secure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None

prefix = ""
get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS)
mock_secure_channel.assert_called_with(host_name, ANY)
args, kwargs = mock_secure_channel.call_args
assert args[0] == host_name
assert "options" in kwargs and kwargs["options"] is None


def test_sync_channel_passes_base_options_and_max_lengths():
base_options = [
("grpc.max_send_message_length", 1234),
("grpc.max_receive_message_length", 5678),
("grpc.primary_user_agent", "durabletask-tests"),
]
with patch("grpc.insecure_channel") as mock_channel:
get_grpc_channel(HOST_ADDRESS, False, options=base_options)
# Ensure called with options kwarg
assert mock_channel.call_count == 1
args, kwargs = mock_channel.call_args
assert args[0] == HOST_ADDRESS
assert "options" in kwargs
opts = kwargs["options"]
# Check our base options made it through
assert ("grpc.max_send_message_length", 1234) in opts
assert ("grpc.max_receive_message_length", 5678) in opts
assert ("grpc.primary_user_agent", "durabletask-tests") in opts
Loading
Loading