diff --git a/durabletask/client.py b/durabletask/client.py
index 7a72e1a..b155bd6 100644
--- a/durabletask/client.py
+++ b/durabletask/client.py
@@ -127,10 +127,14 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu

         name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)

+        input_pb = (
+            wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None
+        )
+
         req = pb.CreateInstanceRequest(
             name=name,
             instanceId=instance_id if instance_id else uuid.uuid4().hex,
-            input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
+            input=input_pb,
             scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
             version=wrappers_pb2.StringValue(value=""),
             orchestrationIdReusePolicy=reuse_id_policy,
diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py
index 48ab14b..682ab89 100644
--- a/durabletask/internal/helpers.py
+++ b/durabletask/internal/helpers.py
@@ -158,18 +158,25 @@ def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]:


 def new_complete_orchestration_action(
-        id: int,
-        status: pb.OrchestrationStatus,
-        result: Optional[str] = None,
-        failure_details: Optional[pb.TaskFailureDetails] = None,
-        carryover_events: Optional[list[pb.HistoryEvent]] = None) -> pb.OrchestratorAction:
+    id: int,
+    status: pb.OrchestrationStatus,
+    result: Optional[str] = None,
+    failure_details: Optional[pb.TaskFailureDetails] = None,
+    carryover_events: Optional[list[pb.HistoryEvent]] = None,
+    router: Optional[pb.TaskRouter] = None,
+) -> pb.OrchestratorAction:
     completeOrchestrationAction = pb.CompleteOrchestrationAction(
         orchestrationStatus=status,
         result=get_string_value(result),
         failureDetails=failure_details,
-        carryoverEvents=carryover_events)
+        carryoverEvents=carryover_events,
+    )

-    return pb.OrchestratorAction(id=id, completeOrchestration=completeOrchestrationAction)
+    return pb.OrchestratorAction(
+        id=id,
+        completeOrchestration=completeOrchestrationAction,
+        router=router,
+    )


 def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction:
diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py
index c0fbe74..d15141f 100644
--- a/durabletask/internal/shared.py
+++ b/durabletask/internal/shared.py
@@ -4,6 +4,7 @@
 import dataclasses
 import json
 import logging
+import os
 from types import SimpleNamespace
 from typing import Any, Optional, Sequence, Union

@@ -13,7 +14,7 @@
     grpc.UnaryUnaryClientInterceptor,
     grpc.UnaryStreamClientInterceptor,
     grpc.StreamUnaryClientInterceptor,
-    grpc.StreamStreamClientInterceptor
+    grpc.StreamStreamClientInterceptor,
 ]

 # Field name used to indicate that an object was automatically serialized
@@ -25,6 +26,27 @@


 def get_default_host_address() -> str:
+    """Resolve the default Durable Task sidecar address.
+
+    Honors environment variables if present; otherwise defaults to localhost:4001.
+
+    Supported environment variables (checked in order):
+    - DAPR_GRPC_ENDPOINT (e.g., "localhost:4001", "grpcs://host:443")
+    - DAPR_GRPC_HOST (or DAPR_RUNTIME_HOST) and DAPR_GRPC_PORT (defaults to 4001)
+    """
+
+    # Full endpoint overrides
+    endpoint = os.environ.get("DAPR_GRPC_ENDPOINT")
+    if endpoint:
+        return endpoint
+
+    # Host/port split overrides
+    host = os.environ.get("DAPR_GRPC_HOST") or os.environ.get("DAPR_RUNTIME_HOST")
+    if host:
+        port = os.environ.get("DAPR_GRPC_PORT", "4001")
+        return f"{host}:{port}"
+
+    # Default to durabletask-go default port
     return "localhost:4001"


diff --git a/durabletask/worker.py b/durabletask/worker.py
index e8e1fa9..695dc44 100644
--- a/durabletask/worker.py
+++ b/durabletask/worker.py
@@ -643,7 +643,10 @@ def set_complete(
         if result is not None:
             result_json = result if is_result_encoded else shared.to_json(result)
         action = ph.new_complete_orchestration_action(
-            self.next_sequence_number(), status, result_json
+            self.next_sequence_number(),
+            status,
+            result_json,
+            router=pb.TaskRouter(sourceAppID=self._app_id) if self._app_id else None,
         )
         self._pending_actions[action.id] = action

@@ -660,6 +663,7 @@ def set_failed(self, ex: Exception):
             pb.ORCHESTRATION_STATUS_FAILED,
             None,
             ph.new_failure_details(ex),
+            router=pb.TaskRouter(sourceAppID=self._app_id) if self._app_id else None,
         )
         self._pending_actions[action.id] = action

@@ -692,11 +696,10 @@ def get_actions(self) -> list[pb.OrchestratorAction]:
             action = ph.new_complete_orchestration_action(
                 self.next_sequence_number(),
                 pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW,
-                result=shared.to_json(self._new_input)
-                if self._new_input is not None
-                else None,
+                result=shared.to_json(self._new_input) if self._new_input is not None else None,
                 failure_details=None,
                 carryover_events=carryover_events,
+                router=pb.TaskRouter(sourceAppID=self._app_id) if self._app_id else None,
             )
             return [action]
         else:
diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py
index 76ec355..f5651ff 100644
--- a/tests/durabletask/test_orchestration_e2e.py
+++ b/tests/durabletask/test_orchestration_e2e.py
@@ -366,6 +366,45 @@ def orchestrator(ctx: task.OrchestrationContext, input: int):
     assert all_results == [1, 2, 3, 4, 5]


+def test_continue_as_new_with_activity_e2e():
+    """E2E test for continue_as_new with activities (generator-based)."""
+    activity_results = []
+
+    def double_activity(ctx: task.ActivityContext, value: int) -> int:
+        """Activity that doubles the value."""
+        result = value * 2
+        activity_results.append(result)
+        return result
+
+    def orchestrator(ctx: task.OrchestrationContext, counter: int):
+        # Call the activity to process the counter
+        processed = yield ctx.call_activity(double_activity, input=counter)
+
+        # Continue as new up to 3 times
+        if counter < 3:
+            ctx.continue_as_new(counter + 1, save_events=False)
+        else:
+            return {"counter": counter, "processed": processed, "all_results": activity_results}
+
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_activity(double_activity)
+        w.add_orchestrator(orchestrator)
+        w.start()
+
+        task_hub_client = client.TaskHubGrpcClient()
+        id = task_hub_client.schedule_new_orchestration(orchestrator, input=1)
+
+        state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
+        assert state is not None
+        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+
+        output = json.loads(state.serialized_output)
+        # Should have called the activity 3 times with input values 1, 2, 3
+        assert activity_results == [2, 4, 6]
+        assert output["counter"] == 3
+ assert output["processed"] == 6 + + # NOTE: This test fails when running against durabletask-go with sqlite because the sqlite backend does not yet # support orchestration ID reuse. This gap is being tracked here: # https://github.com/microsoft/durabletask-go/issues/42