79 changes: 79 additions & 0 deletions torchrl/trainers/algorithms/configs/__init__.py
@@ -85,6 +85,7 @@
ValueModelConfig,
)
from torchrl.trainers.algorithms.configs.objectives import (
GAEConfig,
HardUpdateConfig,
LossConfig,
PPOLossConfig,
@@ -126,6 +127,7 @@
InitTrackerConfig,
KLRewardTransformConfig,
LineariseRewardsConfig,
ModuleTransformConfig,
MultiActionConfig,
MultiStepTransformConfig,
NoopResetEnvConfig,
@@ -179,6 +181,18 @@
SGDConfig,
SparseAdamConfig,
)
from torchrl.trainers.algorithms.configs.weight_sync_schemes import (
DistributedWeightSyncSchemeConfig,
MultiProcessWeightSyncSchemeConfig,
NoWeightSyncSchemeConfig,
RayModuleTransformSchemeConfig,
RayWeightSyncSchemeConfig,
RPCWeightSyncSchemeConfig,
SharedMemWeightSyncSchemeConfig,
VLLMDoubleBufferSyncSchemeConfig,
VLLMWeightSyncSchemeConfig,
WeightSyncSchemeConfig,
)
from torchrl.trainers.algorithms.configs.weight_update import (
DistributedWeightUpdaterConfig,
MultiProcessedWeightUpdaterConfig,
@@ -273,6 +287,7 @@
"InitTrackerConfig",
"KLRewardTransformConfig",
"LineariseRewardsConfig",
"ModuleTransformConfig",
"MultiActionConfig",
"MultiStepTransformConfig",
"NoopResetEnvConfig",
@@ -330,6 +345,8 @@
"LossConfig",
"PPOLossConfig",
"SACLossConfig",
# Value functions
"GAEConfig",
# Trainers
"PPOTrainerConfig",
"SACTrainerConfig",
@@ -348,6 +365,17 @@
"RPCWeightUpdaterConfig",
"DistributedWeightUpdaterConfig",
"vLLMUpdaterConfig",
# Weight Sync Schemes
"WeightSyncSchemeConfig",
"MultiProcessWeightSyncSchemeConfig",
"SharedMemWeightSyncSchemeConfig",
"NoWeightSyncSchemeConfig",
"RayWeightSyncSchemeConfig",
"RayModuleTransformSchemeConfig",
"RPCWeightSyncSchemeConfig",
"DistributedWeightSyncSchemeConfig",
"VLLMWeightSyncSchemeConfig",
"VLLMDoubleBufferSyncSchemeConfig",
]


@@ -356,6 +384,10 @@ def _register_configs():

This function is called lazily to avoid GlobalHydra initialization issues
during testing. It should be called explicitly when needed.

To add a new config:
- Write the config class in the appropriate file (e.g. torchrl/trainers/algorithms/configs/transforms.py) and add it to the __all__ list in torchrl/trainers/algorithms/configs/__init__.py
- Register the config in the appropriate group, e.g. cs.store(group="transform", name="new_transform", node=NewTransformConfig)
"""
cs = ConfigStore.instance()
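
For reference, the two steps described in the docstring look like this in practice (a minimal sketch; `NewTransformConfig` and its `_target_` are hypothetical placeholders, not part of this PR):

```python
from dataclasses import dataclass

from hydra.core.config_store import ConfigStore


@dataclass
class NewTransformConfig:
    """Hypothetical config used only to illustrate the registration pattern."""

    _target_: str = "torchrl.envs.transforms.NewTransform"  # hypothetical target
    some_param: int = 0


cs = ConfigStore.instance()
cs.store(group="transform", name="new_transform", node=NewTransformConfig)
```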

@@ -461,6 +493,7 @@ def _register_configs():
cs.store(group="transform", name="action_discretizer", node=ActionDiscretizerConfig)
cs.store(group="transform", name="traj_counter", node=TrajCounterConfig)
cs.store(group="transform", name="linearise_rewards", node=LineariseRewardsConfig)
cs.store(group="transform", name="module", node=ModuleTransformConfig)
cs.store(group="transform", name="conditional_skip", node=ConditionalSkipConfig)
cs.store(group="transform", name="multi_action", node=MultiActionConfig)
cs.store(group="transform", name="timer", node=TimerConfig)
@@ -487,6 +520,7 @@ def _register_configs():
cs.store(group="transform", name="vip", node=VIPTransformConfig)
cs.store(group="transform", name="vip_reward", node=VIPRewardTransformConfig)
cs.store(group="transform", name="vec_norm_v2", node=VecNormV2Config)
cs.store(group="transform", name="module", node=ModuleTransformConfig)

# =============================================================================
# Loss Configurations
@@ -496,6 +530,16 @@ def _register_configs():
cs.store(group="loss", name="ppo", node=PPOLossConfig)
cs.store(group="loss", name="sac", node=SACLossConfig)

# =============================================================================
# Value Function Configurations
# =============================================================================

cs.store(group="value", name="gae", node=GAEConfig)

# =============================================================================
# Target Net Updater Configurations
# =============================================================================

cs.store(group="target_net_updater", name="soft", node=SoftUpdateConfig)
cs.store(group="target_net_updater", name="hard", node=HardUpdateConfig)

@@ -595,6 +639,41 @@ def _register_configs():
)
cs.store(group="weight_updater", name="vllm", node=vLLMUpdaterConfig)

# =============================================================================
# Weight Sync Scheme Configurations
# =============================================================================

cs.store(group="weight_sync_scheme", name="base", node=WeightSyncSchemeConfig)
cs.store(
group="weight_sync_scheme",
name="multiprocess",
node=MultiProcessWeightSyncSchemeConfig,
)
cs.store(
group="weight_sync_scheme",
name="shared_mem",
node=SharedMemWeightSyncSchemeConfig,
)
cs.store(group="weight_sync_scheme", name="no_sync", node=NoWeightSyncSchemeConfig)
cs.store(group="weight_sync_scheme", name="ray", node=RayWeightSyncSchemeConfig)
cs.store(
group="weight_sync_scheme",
name="ray_module_transform",
node=RayModuleTransformSchemeConfig,
)
cs.store(group="weight_sync_scheme", name="rpc", node=RPCWeightSyncSchemeConfig)
cs.store(
group="weight_sync_scheme",
name="distributed",
node=DistributedWeightSyncSchemeConfig,
)
cs.store(group="weight_sync_scheme", name="vllm", node=VLLMWeightSyncSchemeConfig)
cs.store(
group="weight_sync_scheme",
name="vllm_double_buffer",
node=VLLMDoubleBufferSyncSchemeConfig,
)
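
Any of these nodes can also be instantiated directly, bypassing Hydra composition; a minimal sketch (equivalent to selecting `weight_sync_scheme=shared_mem` on the command line):

```python
from hydra.utils import instantiate

from torchrl.trainers.algorithms.configs.weight_sync_schemes import (
    SharedMemWeightSyncSchemeConfig,
)

# Resolves the node's _target_ to the corresponding scheme class.
scheme = instantiate(SharedMemWeightSyncSchemeConfig())
```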


if sys.version_info >= (3, 10):  # type: ignore # noqa
_register_configs()
8 changes: 8 additions & 0 deletions torchrl/trainers/algorithms/configs/collectors.py
@@ -51,7 +51,9 @@ class SyncDataCollectorConfig(DataCollectorConfig):
cudagraph_policy: Any = None
no_cuda_sync: bool = False
weight_updater: Any = None
weight_sync_schemes: Any = None
track_policy_version: bool = False
local_init_rb: bool = False
_target_: str = "torchrl.collectors.SyncDataCollector"
_partial_: bool = False

@@ -94,7 +96,9 @@ class AsyncDataCollectorConfig(DataCollectorConfig):
cudagraph_policy: Any = None
no_cuda_sync: bool = False
weight_updater: Any = None
weight_sync_schemes: Any = None
track_policy_version: bool = False
local_init_rb: bool = False
_target_: str = "torchrl.collectors.aSyncDataCollector"
_partial_: bool = False

@@ -136,7 +140,9 @@ class MultiSyncDataCollectorConfig(DataCollectorConfig):
cudagraph_policy: Any = None
no_cuda_sync: bool = False
weight_updater: Any = None
weight_sync_schemes: Any = None
track_policy_version: bool = False
local_init_rb: bool = False
_target_: str = "torchrl.collectors.MultiSyncDataCollector"
_partial_: bool = False

@@ -179,7 +185,9 @@ class MultiaSyncDataCollectorConfig(DataCollectorConfig):
cudagraph_policy: Any = None
no_cuda_sync: bool = False
weight_updater: Any = None
weight_sync_schemes: Any = None
track_policy_version: bool = False
local_init_rb: bool = False
_target_: str = "torchrl.collectors.MultiaSyncDataCollector"
_partial_: bool = False
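
A sketch of wiring the new fields on a collector config. The `{"policy": ...}` mapping shape is an assumption about what `weight_sync_schemes` expects, and the other collector fields are assumed to keep their defaults:

```python
from torchrl.trainers.algorithms.configs.collectors import SyncDataCollectorConfig
from torchrl.trainers.algorithms.configs.weight_sync_schemes import (
    MultiProcessWeightSyncSchemeConfig,
)

cfg = SyncDataCollectorConfig(
    # Assumed shape: model id -> scheme config.
    weight_sync_schemes={"policy": MultiProcessWeightSyncSchemeConfig()},
    local_init_rb=False,  # new flag; defaults to False
)
```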

2 changes: 2 additions & 0 deletions torchrl/trainers/algorithms/configs/data.py
@@ -254,6 +254,7 @@ class LazyMemmapStorageConfig(StorageConfig):
device: Any = None
ndim: int = 1
compilable: bool = False
shared_init: bool = False


@dataclass
@@ -265,6 +266,7 @@ class LazyTensorStorageConfig(StorageConfig):
device: Any = None
ndim: int = 1
compilable: bool = False
shared_init: bool = False
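
Both lazy-storage configs gain the same flag; a minimal sketch of setting it (only fields visible in this diff are used, and the other fields are assumed to have defaults; the flag's runtime effect is defined by the storage changes elsewhere in this PR):

```python
from torchrl.trainers.algorithms.configs.data import LazyTensorStorageConfig

cfg = LazyTensorStorageConfig(shared_init=True)
```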


@dataclass
9 changes: 8 additions & 1 deletion torchrl/trainers/algorithms/configs/modules.py
@@ -202,6 +202,7 @@ class ModelConfig(ConfigBase):
_partial_: bool = False
in_keys: Any = None
out_keys: Any = None
shared: bool = False

def __post_init__(self) -> None:
"""Post-initialization hook for model configurations."""
@@ -226,7 +227,7 @@ class TensorDictModuleConfig(ModelConfig):

def __post_init__(self) -> None:
"""Post-initialization hook for TensorDict module configurations."""
super().__post_init__()
return super().__post_init__()


@dataclass
@@ -312,6 +313,7 @@ def _make_tanh_normal_model(*args, **kwargs):
return_log_prob = kwargs.pop("return_log_prob", False)
eval_mode = kwargs.pop("eval_mode", False)
exploration_type = kwargs.pop("exploration_type", "RANDOM")
shared = kwargs.pop("shared", False)

# Now instantiate the network
if hasattr(network, "_target_"):
@@ -328,6 +330,8 @@
)

module = TensorDictModule(network, in_keys=in_keys, out_keys=param_keys)
if shared:
module = module.share_memory()

# Create ProbabilisticTensorDictModule
prob_module = ProbabilisticTensorDictModule(
@@ -350,4 +354,7 @@ def _make_value_model(*args, **kwargs):
from torchrl.modules import ValueOperator

network = kwargs.pop("network")
shared = kwargs.pop("shared", False)
if shared:
network = network.share_memory()
return ValueOperator(network, **kwargs)
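
For context, `shared=True` relies on standard `torch.nn.Module.share_memory()`: the module's parameters are moved to shared memory so worker processes see in-place updates. A minimal sketch:

```python
import torch
from tensordict.nn import TensorDictModule

net = torch.nn.Linear(4, 2)
module = TensorDictModule(net, in_keys=["observation"], out_keys=["action"])
module.share_memory()  # returns self; parameters now live in shared memory
assert next(module.parameters()).is_shared()
```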
27 changes: 27 additions & 0 deletions torchrl/trainers/algorithms/configs/objectives.py
@@ -148,3 +148,30 @@ class HardUpdateConfig(TargetNetUpdaterConfig):

_target_: str = "torchrl.objectives.utils.HardUpdate."
value_network_update_interval: int = 1000


@dataclass
class GAEConfig(LossConfig):
"""A class to configure a GAELoss."""

gamma: float | None = None
lmbda: float | None = None
value_network: Any = None
average_gae: bool = True
differentiable: bool = False
vectorized: bool | None = None
skip_existing: bool | None = None
advantage_key: str | None = None
value_target_key: str | None = None
value_key: str | None = None
shifted: bool = False
device: Any = None
time_dim: int | None = None
auto_reset_env: bool = False
deactivate_vmap: bool = False
_target_: str = "torchrl.objectives.value.GAE"
_partial_: bool = False

def __post_init__(self) -> None:
"""Post-initialization hook for GAELoss configurations."""
super().__post_init__()
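
A minimal sketch of what this config resolves to, constructed directly to show the field-to-argument mapping (the small value network is illustrative):

```python
import torch
from torchrl.modules import ValueOperator
from torchrl.objectives.value import GAE

# Reads "observation", writes "state_value" (the ValueOperator default).
value_net = ValueOperator(torch.nn.Linear(4, 1), in_keys=["observation"])
gae = GAE(gamma=0.99, lmbda=0.95, value_network=value_net, average_gae=True)
```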