diff --git a/torchrl/trainers/algorithms/configs/__init__.py b/torchrl/trainers/algorithms/configs/__init__.py
index f76b83b1387..ca322c84938 100644
--- a/torchrl/trainers/algorithms/configs/__init__.py
+++ b/torchrl/trainers/algorithms/configs/__init__.py
@@ -85,6 +85,7 @@
     ValueModelConfig,
 )
 from torchrl.trainers.algorithms.configs.objectives import (
+    GAEConfig,
     HardUpdateConfig,
     LossConfig,
     PPOLossConfig,
@@ -126,6 +127,7 @@
     InitTrackerConfig,
     KLRewardTransformConfig,
     LineariseRewardsConfig,
+    ModuleTransformConfig,
     MultiActionConfig,
     MultiStepTransformConfig,
     NoopResetEnvConfig,
@@ -179,6 +181,18 @@
     SGDConfig,
     SparseAdamConfig,
 )
+from torchrl.trainers.algorithms.configs.weight_sync_schemes import (
+    DistributedWeightSyncSchemeConfig,
+    MultiProcessWeightSyncSchemeConfig,
+    NoWeightSyncSchemeConfig,
+    RayModuleTransformSchemeConfig,
+    RayWeightSyncSchemeConfig,
+    RPCWeightSyncSchemeConfig,
+    SharedMemWeightSyncSchemeConfig,
+    VLLMDoubleBufferSyncSchemeConfig,
+    VLLMWeightSyncSchemeConfig,
+    WeightSyncSchemeConfig,
+)
 from torchrl.trainers.algorithms.configs.weight_update import (
     DistributedWeightUpdaterConfig,
     MultiProcessedWeightUpdaterConfig,
@@ -273,6 +287,7 @@
     "InitTrackerConfig",
     "KLRewardTransformConfig",
     "LineariseRewardsConfig",
+    "ModuleTransformConfig",
     "MultiActionConfig",
     "MultiStepTransformConfig",
     "NoopResetEnvConfig",
@@ -330,6 +345,8 @@
     "LossConfig",
     "PPOLossConfig",
     "SACLossConfig",
+    # Value functions
+    "GAEConfig",
     # Trainers
     "PPOTrainerConfig",
     "SACTrainerConfig",
@@ -348,6 +365,17 @@
     "RPCWeightUpdaterConfig",
     "DistributedWeightUpdaterConfig",
     "vLLMUpdaterConfig",
+    # Weight Sync Schemes
+    "WeightSyncSchemeConfig",
+    "MultiProcessWeightSyncSchemeConfig",
+    "SharedMemWeightSyncSchemeConfig",
+    "NoWeightSyncSchemeConfig",
+    "RayWeightSyncSchemeConfig",
+    "RayModuleTransformSchemeConfig",
+    "RPCWeightSyncSchemeConfig",
+    "DistributedWeightSyncSchemeConfig",
+    "VLLMWeightSyncSchemeConfig",
+    "VLLMDoubleBufferSyncSchemeConfig",
 ]
 
 
@@ -356,6 +384,10 @@ def _register_configs():
 
     This function is called lazily to avoid GlobalHydra initialization issues
     during testing. It should be called explicitly when needed.
+
+    To add a new config:
+    - Write the config class in the appropriate file (e.g. torchrl/trainers/algorithms/configs/transforms.py) and add it to the __all__ list in torchrl/trainers/algorithms/configs/__init__.py
+    - Register the config in the appropriate group, e.g.
+      cs.store(group="transform", name="new_transform", node=NewTransformConfig)
     """
     cs = ConfigStore.instance()
@@ -461,6 +493,7 @@ def _register_configs():
     cs.store(group="transform", name="action_discretizer", node=ActionDiscretizerConfig)
     cs.store(group="transform", name="traj_counter", node=TrajCounterConfig)
     cs.store(group="transform", name="linearise_rewards", node=LineariseRewardsConfig)
+    cs.store(group="transform", name="module", node=ModuleTransformConfig)
     cs.store(group="transform", name="conditional_skip", node=ConditionalSkipConfig)
     cs.store(group="transform", name="multi_action", node=MultiActionConfig)
     cs.store(group="transform", name="timer", node=TimerConfig)
@@ -496,6 +529,16 @@ def _register_configs():
     cs.store(group="loss", name="ppo", node=PPOLossConfig)
     cs.store(group="loss", name="sac", node=SACLossConfig)
+    # =============================================================================
+    # Value Function Configurations
+    # =============================================================================
+
+    cs.store(group="value", name="gae", node=GAEConfig)
+
+    # =============================================================================
+    # Target Net Updater Configurations
+    # =============================================================================
+
     cs.store(group="target_net_updater", name="soft", node=SoftUpdateConfig)
     cs.store(group="target_net_updater", name="hard", node=HardUpdateConfig)
@@ -595,6 +638,41 @@ def _register_configs():
     )
     cs.store(group="weight_updater", name="vllm", node=vLLMUpdaterConfig)
+
+    # =============================================================================
+    # Weight Sync Scheme Configurations
+    # =============================================================================
+
+    cs.store(group="weight_sync_scheme", name="base", node=WeightSyncSchemeConfig)
+    cs.store(
+        group="weight_sync_scheme",
+        name="multiprocess",
+        node=MultiProcessWeightSyncSchemeConfig,
+    )
+    cs.store(
+        group="weight_sync_scheme",
+        name="shared_mem",
+        node=SharedMemWeightSyncSchemeConfig,
+    )
+    cs.store(group="weight_sync_scheme", name="no_sync", node=NoWeightSyncSchemeConfig)
+    cs.store(group="weight_sync_scheme", name="ray", node=RayWeightSyncSchemeConfig)
+    cs.store(
+        group="weight_sync_scheme",
+        name="ray_module_transform",
+        node=RayModuleTransformSchemeConfig,
+    )
+    cs.store(group="weight_sync_scheme", name="rpc", node=RPCWeightSyncSchemeConfig)
+    cs.store(
+        group="weight_sync_scheme",
+        name="distributed",
+        node=DistributedWeightSyncSchemeConfig,
+    )
+    cs.store(group="weight_sync_scheme", name="vllm", node=VLLMWeightSyncSchemeConfig)
+    cs.store(
+        group="weight_sync_scheme",
+        name="vllm_double_buffer",
+        node=VLLMDoubleBufferSyncSchemeConfig,
+    )
+
 
 if not sys.version_info < (3, 10):  # type: ignore # noqa
     _register_configs()
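
Note: a minimal sketch of the registration pattern the new docstring describes. MyTransformConfig and my_package.MyTransform are invented placeholders, not part of this PR:

    from dataclasses import dataclass

    from hydra.core.config_store import ConfigStore

    from torchrl.trainers.algorithms.configs.transforms import TransformConfig


    @dataclass
    class MyTransformConfig(TransformConfig):
        """Config for a hypothetical my_package.MyTransform."""

        scale: float = 1.0
        _target_: str = "my_package.MyTransform"


    cs = ConfigStore.instance()
    cs.store(group="transform", name="my_transform", node=MyTransformConfig)

Once stored, the node should be selectable like any other entry in its group (e.g. transform=my_transform on the command line).

diff --git a/torchrl/trainers/algorithms/configs/collectors.py b/torchrl/trainers/algorithms/configs/collectors.py
index 9e57b7c19fa..f90f5b6f6b9 100644
--- a/torchrl/trainers/algorithms/configs/collectors.py
+++ b/torchrl/trainers/algorithms/configs/collectors.py
@@ -51,7 +51,9 @@ class SyncDataCollectorConfig(DataCollectorConfig):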
     cudagraph_policy: Any = None
     no_cuda_sync: bool = False
     weight_updater: Any = None
+    weight_sync_schemes: Any = None
     track_policy_version: bool = False
+    local_init_rb: bool = False
     _target_: str = "torchrl.collectors.SyncDataCollector"
     _partial_: bool = False
@@ -94,7 +96,9 @@ class AsyncDataCollectorConfig(DataCollectorConfig):
     cudagraph_policy: Any = None
     no_cuda_sync: bool = False
     weight_updater: Any = None
+    weight_sync_schemes: Any = None
     track_policy_version: bool = False
+    local_init_rb: bool = False
     _target_: str = "torchrl.collectors.aSyncDataCollector"
     _partial_: bool = False
@@ -136,7 +140,9 @@ class MultiSyncDataCollectorConfig(DataCollectorConfig):
     cudagraph_policy: Any = None
     no_cuda_sync: bool = False
     weight_updater: Any = None
+    weight_sync_schemes: Any = None
     track_policy_version: bool = False
+    local_init_rb: bool = False
     _target_: str = "torchrl.collectors.MultiSyncDataCollector"
     _partial_: bool = False
@@ -179,7 +185,9 @@ class MultiaSyncDataCollectorConfig(DataCollectorConfig):
     cudagraph_policy: Any = None
     no_cuda_sync: bool = False
     weight_updater: Any = None
+    weight_sync_schemes: Any = None
     track_policy_version: bool = False
+    local_init_rb: bool = False
     _target_: str = "torchrl.collectors.MultiaSyncDataCollector"
     _partial_: bool = False
diff --git a/torchrl/trainers/algorithms/configs/data.py b/torchrl/trainers/algorithms/configs/data.py
index 08a8eb44cc3..f840c52145b 100644
--- a/torchrl/trainers/algorithms/configs/data.py
+++ b/torchrl/trainers/algorithms/configs/data.py
@@ -254,6 +254,7 @@ class LazyMemmapStorageConfig(StorageConfig):
     device: Any = None
     ndim: int = 1
     compilable: bool = False
+    shared_init: bool = False
 
 
 @dataclass
@@ -265,6 +266,7 @@ class LazyTensorStorageConfig(StorageConfig):
     device: Any = None
     ndim: int = 1
     compilable: bool = False
+    shared_init: bool = False
 
 
 @dataclass
diff --git a/torchrl/trainers/algorithms/configs/modules.py b/torchrl/trainers/algorithms/configs/modules.py
index 8ec1a4df984..189d47c2561 100644
--- a/torchrl/trainers/algorithms/configs/modules.py
+++ b/torchrl/trainers/algorithms/configs/modules.py
@@ -202,6 +202,7 @@ class ModelConfig(ConfigBase):
     _partial_: bool = False
     in_keys: Any = None
     out_keys: Any = None
+    shared: bool = False
 
     def __post_init__(self) -> None:
         """Post-initialization hook for model configurations."""
@@ -226,7 +227,7 @@ class TensorDictModuleConfig(ModelConfig):
 
     def __post_init__(self) -> None:
         """Post-initialization hook for TensorDict module configurations."""
-        super().__post_init__()
+        return super().__post_init__()
 
 
 @dataclass
@@ -312,6 +313,7 @@ def _make_tanh_normal_model(*args, **kwargs):
     return_log_prob = kwargs.pop("return_log_prob", False)
     eval_mode = kwargs.pop("eval_mode", False)
     exploration_type = kwargs.pop("exploration_type", "RANDOM")
+    shared = kwargs.pop("shared", False)
 
     # Now instantiate the network
     if hasattr(network, "_target_"):
@@ -328,6 +330,8 @@ def _make_tanh_normal_model(*args, **kwargs):
     )
 
     module = TensorDictModule(network, in_keys=in_keys, out_keys=param_keys)
+    if shared:
+        module = module.share_memory()
 
     # Create ProbabilisticTensorDictModule
     prob_module = ProbabilisticTensorDictModule(
@@ -350,4 +354,7 @@ def _make_value_model(*args, **kwargs):
     from torchrl.modules import ValueOperator
 
     network = kwargs.pop("network")
+    shared = kwargs.pop("shared", False)
+    if shared:
+        network = network.share_memory()
     return ValueOperator(network, **kwargs)
diff --git a/torchrl/trainers/algorithms/configs/objectives.py b/torchrl/trainers/algorithms/configs/objectives.py
index a0d1c8fb0d3..6be1bc845b9 100644
--- a/torchrl/trainers/algorithms/configs/objectives.py
+++ b/torchrl/trainers/algorithms/configs/objectives.py
@@ -148,3 +148,30 @@ class HardUpdateConfig(TargetNetUpdaterConfig):
 
     _target_: str = "torchrl.objectives.utils.HardUpdate"
     value_network_update_interval: int = 1000
+
+
+@dataclass
+class GAEConfig(LossConfig):
+    """A class to configure the GAE (Generalized Advantage Estimation) value estimator."""
+
+    gamma: float | None = None
+    lmbda: float | None = None
+    value_network: Any = None
+    average_gae: bool = True
+    differentiable: bool = False
+    vectorized: bool | None = None
+    skip_existing: bool | None = None
+    advantage_key: str | None = None
+    value_target_key: str | None = None
+    value_key: str | None = None
+    shifted: bool = False
+    device: Any = None
+    time_dim: int | None = None
+    auto_reset_env: bool = False
+    deactivate_vmap: bool = False
+    _target_: str = "torchrl.objectives.value.GAE"
+    _partial_: bool = False
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for GAE configurations."""
+        super().__post_init__()
diff --git a/torchrl/trainers/algorithms/configs/trainers.py b/torchrl/trainers/algorithms/configs/trainers.py
index a35e686d124..fce985f4557 100644
--- a/torchrl/trainers/algorithms/configs/trainers.py
+++ b/torchrl/trainers/algorithms/configs/trainers.py
@@ -9,10 +9,12 @@
 from typing import Any
 
 import torch
+from tensordict.nn import TensorDictModuleBase
 
 from torchrl.collectors import DataCollectorBase
 from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import TargetNetUpdater
+from torchrl.objectives.value.advantages import GAE
 from torchrl.trainers.algorithms.configs.common import ConfigBase
 from torchrl.trainers.algorithms.ppo import PPOTrainer
 from torchrl.trainers.algorithms.sac import SACTrainer
@@ -54,6 +56,7 @@ class SACTrainerConfig(TrainerConfig):
     critic_network: Any = None
     target_net_updater: Any = None
     async_collection: bool = False
+    log_timings: bool = False
 
     _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_sac_trainer"
@@ -87,6 +90,7 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer:
     kwargs.pop("create_env_fn")
     target_net_updater = kwargs.pop("target_net_updater")
     async_collection = kwargs.pop("async_collection", False)
+    log_timings = kwargs.pop("log_timings", False)
 
     # Instantiate networks first
     if actor_network is not None:
@@ -152,6 +156,7 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer:
         replay_buffer=replay_buffer,
         target_net_updater=target_net_updater,
         async_collection=async_collection,
+        log_timings=log_timings,
     )
 
 
@@ -161,6 +166,37 @@ class PPOTrainerConfig(TrainerConfig):
 
     This class defines the configuration parameters for creating a PPO trainer,
     including both required and optional fields with sensible defaults.
+
+    Args:
+        collector: The data collector for gathering training data.
+        total_frames: Total number of frames to train for.
+        optim_steps_per_batch: Number of optimization steps per batch.
+        loss_module: The loss module for computing policy and value losses.
+        optimizer: The optimizer for training.
+        logger: Logger for tracking training metrics.
+        save_trainer_file: File path for saving trainer state.
+        replay_buffer: Replay buffer for storing data.
+        frame_skip: Frame skip value for the environment. Default: 1.
+        clip_grad_norm: Whether to clip gradient norms. Default: True.
+        clip_norm: Maximum gradient norm value.
+        progress_bar: Whether to show a progress bar. Default: True.
+        seed: Random seed for reproducibility.
+        save_trainer_interval: Interval for saving trainer state. Default: 10000.
+        log_interval: Interval for logging metrics. Default: 10000.
+        create_env_fn: Environment creation function.
+        actor_network: Actor network configuration.
+        critic_network: Critic network configuration.
+        num_epochs: Number of epochs per batch. Default: 4.
+        async_collection: Whether to use async collection. Default: False.
+        add_gae: Whether to add GAE computation. Default: True.
+        gae: Custom GAE module configuration.
+        weight_update_map: Mapping from collector destination paths to trainer source paths.
+            Required if the collector has weight_sync_schemes configured.
+            Example: {"policy": "loss_module.actor_network",
+                      "replay_buffer.transforms[0]": "loss_module.critic_network"}
+        log_timings: Whether to automatically log timing information for all hooks.
+            If True, timing metrics will be logged to the logger (e.g., wandb, tensorboard)
+            with the prefix "time/" (e.g., "time/hook/UpdateWeights"). Default: False.
     """
 
     collector: Any
@@ -183,6 +219,10 @@ class PPOTrainerConfig(TrainerConfig):
     critic_network: Any = None
     num_epochs: int = 4
     async_collection: bool = False
+    add_gae: bool = True
+    gae: Any = None
+    weight_update_map: dict[str, str] | None = None
+    log_timings: bool = False
 
     _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_ppo_trainer"
@@ -213,7 +253,12 @@ def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
     seed = kwargs.pop("seed")
     actor_network = kwargs.pop("actor_network")
     critic_network = kwargs.pop("critic_network")
+    add_gae = kwargs.pop("add_gae", True)
+    gae = kwargs.pop("gae")
     create_env_fn = kwargs.pop("create_env_fn")
+    weight_update_map = kwargs.pop("weight_update_map", None)
+    log_timings = kwargs.pop("log_timings", False)
+
     if create_env_fn is not None:
         # could be referenced somewhere else, no need to raise an error
         pass
@@ -225,6 +270,19 @@ def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
         actor_network = actor_network()
     if critic_network is not None:
         critic_network = critic_network()
+    else:
+        critic_network = loss_module.critic_network
+
+    # Ensure the GAE transform in the replay buffer uses the same value network instance as the loss module.
+    # This fixes the issue where Hydra instantiates separate copies of the value model.
+    if (
+        replay_buffer is not None
+        and hasattr(replay_buffer, "_transform")
+        and len(replay_buffer._transform) > 1
+        and hasattr(replay_buffer._transform[1], "module")
+        and hasattr(replay_buffer._transform[1].module, "value_network")
+    ):
+        replay_buffer._transform[1].module.value_network = critic_network
 
     if not isinstance(collector, DataCollectorBase):
         # then it's a partial config
@@ -258,6 +316,9 @@ def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
         )
     if not isinstance(logger, Logger) and logger is not None:
         raise ValueError(f"logger must be a Logger, got {type(logger)}")
+    # instantiate the GAE module if it is a partial config
+    if not isinstance(gae, (GAE, TensorDictModuleBase)) and gae is not None:
+        gae = gae()
 
     return PPOTrainer(
         collector=collector,
@@ -277,4 +338,8 @@ def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
         replay_buffer=replay_buffer,
         num_epochs=num_epochs,
         async_collection=async_collection,
+        add_gae=add_gae,
+        gae=gae,
+        weight_update_map=weight_update_map,
+        log_timings=log_timings,
     )
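
Note: a hedged sketch of how weight_update_map is meant to be read, based solely on the PPOTrainerConfig docstring above — keys are destination paths inside the collector, values are attribute paths resolved on the trainer:

    # Both entries are taken from the docstring's example.
    weight_update_map = {
        # the collector's policy receives weights from trainer.loss_module.actor_network
        "policy": "loss_module.actor_network",
        # the ModuleTransform at index 0 of the replay buffer's transform stack
        # receives weights from trainer.loss_module.critic_network
        "replay_buffer.transforms[0]": "loss_module.critic_network",
    }

diff --git a/torchrl/trainers/algorithms/configs/transforms.py b/torchrl/trainers/algorithms/configs/transforms.py
index 4a60e2da9b0..be221ec7934 100644
--- a/torchrl/trainers/algorithms/configs/transforms.py
+++ b/torchrl/trainers/algorithms/configs/transforms.py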
@@ -938,3 +938,19 @@ class FlattenTensorDictConfig(TransformConfig):
     def __post_init__(self) -> None:
         """Post-initialization hook for FlattenTensorDict configuration."""
         super().__post_init__()
+
+
+@dataclass
+class ModuleTransformConfig(TransformConfig):
+    """Configuration for ModuleTransform."""
+
+    module: Any = None
+    device: Any = None
+    no_grad: bool = False
+    inverse: bool = False
+    _target_: str = "torchrl.envs.transforms.module.ModuleTransform"
+    _partial_: bool = False
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for ModuleTransform configuration."""
+        super().__post_init__()
diff --git a/torchrl/trainers/algorithms/configs/weight_sync_schemes.py b/torchrl/trainers/algorithms/configs/weight_sync_schemes.py
new file mode 100644
index 00000000000..4417e5c2cb3
--- /dev/null
+++ b/torchrl/trainers/algorithms/configs/weight_sync_schemes.py
@@ -0,0 +1,196 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any
+
+from torchrl.trainers.algorithms.configs.common import ConfigBase
+
+
+@dataclass
+class WeightSyncSchemeConfig(ConfigBase):
+    """Base configuration for weight synchronization schemes."""
+
+    _target_: str = "torchrl.weight_update.WeightSyncScheme"
+    _partial_: bool = False
+
+    # Common argument for all schemes
+    strategy: str = "tensordict"  # "tensordict" or "state_dict"
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for weight sync scheme configurations."""
+
+
+@dataclass
+class MultiProcessWeightSyncSchemeConfig(ConfigBase):
+    """Configuration for MultiProcessWeightSyncScheme.
+
+    Weight synchronization for multiprocess operations using pipes.
+    This scheme creates transports that communicate via multiprocessing pipes.
+    """
+
+    _target_: str = "torchrl.weight_update.MultiProcessWeightSyncScheme"
+    _partial_: bool = False
+
+    strategy: str = "tensordict"  # "tensordict" or "state_dict"
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for multiprocess weight sync scheme configurations."""
+
+
+@dataclass
+class SharedMemWeightSyncSchemeConfig(ConfigBase):
+    """Configuration for SharedMemWeightSyncScheme.
+
+    Weight synchronization using shared memory for in-place weight updates.
+    Workers automatically see weight updates without explicit message passing.
+
+    By default, uses lazy registration (auto_register=True), which makes it seamless
+    to use with Hydra configs - models are automatically registered on the first weight send.
+    """
+
+    _target_: str = "torchrl.weight_update.SharedMemWeightSyncScheme"
+    _partial_: bool = False
+
+    policy_weights: Any = None  # dict[str, TensorDictBase] | None
+    strategy: str = "tensordict"  # "tensordict" or "state_dict"
+    auto_register: bool = True  # Enable lazy registration by default
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for shared memory weight sync scheme configurations."""
+
+
+@dataclass
+class NoWeightSyncSchemeConfig(ConfigBase):
+    """Configuration for NoWeightSyncScheme.
+
+    No-op scheme that disables weight synchronization entirely.
+    """
+
+    _target_: str = "torchrl.weight_update.NoWeightSyncScheme"
+    _partial_: bool = False
+
+    strategy: str = "tensordict"  # Not really used, but kept for consistency
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for no weight sync scheme configurations."""
+
+
+@dataclass
+class RayWeightSyncSchemeConfig(ConfigBase):
+    """Configuration for RayWeightSyncScheme.
+
+    Weight synchronization for Ray distributed computing. Uses Ray's object store
+    and remote calls to synchronize weights across distributed workers (Ray actors).
+    """
+
+    _target_: str = "torchrl.weight_update.RayWeightSyncScheme"
+    _partial_: bool = False
+
+    strategy: str = "tensordict"  # "tensordict" or "state_dict"
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for Ray weight sync scheme configurations."""
+
+
+@dataclass
+class RayModuleTransformSchemeConfig(ConfigBase):
+    """Configuration for RayModuleTransformScheme.
+
+    Weight synchronization for RayModuleTransform actors. This scheme is designed
+    specifically for updating models hosted within Ray actors.
+    """
+
+    _target_: str = "torchrl.weight_update.RayModuleTransformScheme"
+    _partial_: bool = False
+
+    strategy: str = "tensordict"  # "tensordict" or "state_dict"
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for Ray module transform scheme configurations."""
+
+
+@dataclass
+class RPCWeightSyncSchemeConfig(ConfigBase):
+    """Configuration for RPCWeightSyncScheme.
+
+    Weight synchronization for torch.distributed.rpc. Uses RPC calls to synchronize
+    weights across distributed workers.
+    """
+
+    _target_: str = "torchrl.weight_update.RPCWeightSyncScheme"
+    _partial_: bool = False
+
+    strategy: str = "tensordict"  # "tensordict" or "state_dict"
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for RPC weight sync scheme configurations."""
+
+
+@dataclass
+class DistributedWeightSyncSchemeConfig(ConfigBase):
+    """Configuration for DistributedWeightSyncScheme.
+
+    Weight synchronization for torch.distributed. Uses torch.distributed primitives
+    (send/recv) to synchronize weights across distributed workers.
+    """
+
+    _target_: str = "torchrl.weight_update.DistributedWeightSyncScheme"
+    _partial_: bool = False
+
+    backend: str = "gloo"  # "gloo", "nccl", etc.
+    sync: bool = True
+    strategy: str = "tensordict"  # "tensordict" or "state_dict"
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for distributed weight sync scheme configurations."""
+
+
+@dataclass
+class VLLMWeightSyncSchemeConfig(ConfigBase):
+    """Configuration for VLLMWeightSyncScheme.
+
+    Weight synchronization scheme for vLLM engines using collective communication (NCCL).
+    Broadcasts weights from a trainer to vLLM inference workers with parallelism support.
+    """
+
+    _target_: str = "torchrl.weight_update.llm.VLLMWeightSyncScheme"
+    _partial_: bool = False
+
+    master_address: str | None = None  # Defaults to "localhost"
+    master_port: int | None = None  # Auto-assigned if None
+    gpus_per_replica: int = 1  # tp_size × dp_size × pp_size
+    num_replicas: int = 1
+    strategy: str = "tensordict"  # "tensordict" or "state_dict"
+    device: Any = 0  # torch.device | str | int
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for vLLM weight sync scheme configurations."""
+
+
+@dataclass
+class VLLMDoubleBufferSyncSchemeConfig(ConfigBase):
+    """Configuration for VLLMDoubleBufferSyncScheme.
+
+    Weight synchronization scheme for vLLM using double-buffered memory-mapped storage.
+    Uses TensorDict's memory-mapping capabilities to transfer weights via the filesystem.
+    """
+
+    _target_: str = "torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme"
+    _partial_: bool = False
+
+    remote_addr: str | None = None  # Directory path where the sender writes weights
+    local_addr: str | None = None  # Directory path where the receiver reads weights
+    num_threads: int = 1  # Number of threads for memmap operations
+    strategy: str = "tensordict"  # "tensordict" or "state_dict"
+
+    def __post_init__(self) -> None:
+        """Post-initialization hook for vLLM double buffer sync scheme configurations."""
+        if self.remote_addr is None:
+            raise ValueError("remote_addr is required for VLLMDoubleBufferSyncScheme")
+        if self.local_addr is None:
+            self.local_addr = self.remote_addr
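
Note: a usage sketch (not part of the diff) showing how these config nodes resolve; it assumes only what the _target_ strings and __post_init__ logic above state:

    from hydra.utils import instantiate

    from torchrl.trainers.algorithms.configs.weight_sync_schemes import (
        MultiProcessWeightSyncSchemeConfig,
        VLLMDoubleBufferSyncSchemeConfig,
    )

    # instantiate() resolves _target_ and builds a
    # torchrl.weight_update.MultiProcessWeightSyncScheme
    scheme = instantiate(MultiProcessWeightSyncSchemeConfig(strategy="state_dict"))

    # local_addr falls back to remote_addr in __post_init__; omitting
    # remote_addr raises a ValueError ("/tmp/weights" is a placeholder path)
    cfg = VLLMDoubleBufferSyncSchemeConfig(remote_addr="/tmp/weights")
    assert cfg.local_addr == "/tmp/weights"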