Skip to content

Commit d9f513e

Browse files
committed
Update
[ghstack-poisoned]
2 parents 601668c + cf89ad2 commit d9f513e

File tree

10 files changed

+303
-979
lines changed

10 files changed

+303
-979
lines changed

docs/source/reference/llms.rst

Lines changed: 160 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ Collectors
633633
.. _Collectors:
634634

635635
TorchRL offers specialized collector classes (:class:`~torchrl.collectors.llm.LLMCollector` and :class:`~torchrl.collectors.llm.RayLLMCollector`)
636-
that are tailored for LLM use cases. We also provide dedicated updaters for some inference engines.
636+
that are tailored for LLM use cases. We also provide weight synchronization schemes for vLLM inference engines.
637637

638638
See :ref:`ref_collectors` for more details on the collector API. In brief, the idea of a collector is to isolate the inference part of the pipeline
639639
in a dedicated class.
@@ -649,8 +649,126 @@ Collectors are defined by the following parameters and features:
649649
In other cases, the collector can be iterated over to collect data.
650650
- **Steps**: A collector is built with a certain number of steps budget, as well as a number of steps to be
651651
included in each batch yield during collection.
652-
- **Weight Updater**: Weight updaters are the classes that update the policy weights. Isolating the weight update
653-
in a dedicated class allows to easily implement different weight update strategies depending on the policy specification.
652+
- **Weight Synchronization Schemes**: Weight sync schemes handle the synchronization of weights between the training model
653+
and the inference engine. The new scheme-based approach provides flexible, high-performance weight updates for vLLM and
654+
other inference backends.
655+
656+
vLLM Weight Synchronization Schemes
657+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
658+
659+
TorchRL provides two weight synchronization schemes for vLLM engines, offering different trade-offs between
660+
performance and simplicity:
661+
662+
**1. NCCL-Based Synchronization** (:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme`)
663+
664+
Uses NCCL collectives for high-bandwidth GPU-to-GPU weight transfers. Best for:
665+
666+
- High-frequency weight updates
667+
- Large models where transfer speed is critical
668+
- Setups with GPU interconnect (NVLink, InfiniBand)
669+
670+
**2. Double-Buffer Synchronization** (:class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`)
671+
672+
Uses memory-mapped file storage for asynchronous weight transfers. Best for:
673+
674+
- Simpler setup without NCCL coordination
675+
- Distributed setups with shared filesystems (NFS)
676+
- Cases where update frequency is lower
677+
678+
**Usage Example with NCCL:**
679+
680+
.. code-block:: python
681+
682+
from torchrl.collectors.llm import RayLLMCollector
683+
from torchrl.weight_update.llm import VLLMWeightSyncScheme
684+
from torchrl.modules.llm import AsyncVLLM, vLLMWrapper
685+
686+
# Create vLLM engine
687+
vllm_engine = AsyncVLLM.from_pretrained(
688+
"Qwen/Qwen2.5-7B",
689+
num_devices=2,
690+
num_replicas=2,
691+
)
692+
policy = vLLMWrapper(vllm_engine, input_mode="history")
693+
694+
# Create NCCL weight sync scheme
695+
weight_sync_scheme = VLLMWeightSyncScheme(
696+
master_address="localhost",
697+
master_port=29500,
698+
gpus_per_replica=2, # tp_size × dp_size × pp_size
699+
num_replicas=2,
700+
strategy="state_dict"
701+
)
702+
703+
# Create collector with weight sync scheme
704+
collector = RayLLMCollector(
705+
env=make_env,
706+
policy=policy,
707+
dialog_turns_per_batch=256,
708+
total_dialog_turns=10000,
709+
weight_sync_schemes={"policy": weight_sync_scheme},
710+
track_policy_version=True,
711+
)
712+
713+
# During training, get the sender and update weights
714+
sender = collector._weight_senders["policy"]
715+
sender.register_model(training_model)
716+
717+
# Initialize collective group (must be called before first update)
718+
metadata = get_model_metadata(training_model)
719+
sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)
720+
721+
# Update weights during training
722+
for i, data in enumerate(collector):
723+
# ... training step ...
724+
if i % 10 == 0:
725+
sender.update_weights() # Broadcasts via NCCL
726+
727+
**Usage Example with Double-Buffer:**
728+
729+
.. code-block:: python
730+
731+
from torchrl.collectors.llm import RayLLMCollector
732+
from torchrl.weight_update.llm import VLLMDoubleBufferSyncScheme
733+
from torchrl.modules.llm import AsyncVLLM, vLLMWrapper
734+
735+
# Create vLLM engine
736+
vllm_engine = AsyncVLLM.from_pretrained(
737+
"Qwen/Qwen2.5-7B",
738+
num_devices=2,
739+
num_replicas=1,
740+
)
741+
policy = vLLMWrapper(vllm_engine, input_mode="history")
742+
743+
# Create double-buffer weight sync scheme
744+
weight_sync_scheme = VLLMDoubleBufferSyncScheme(
745+
remote_addr="/tmp/weights", # Or "/mnt/shared/weights" for NFS
746+
num_threads=128,
747+
strategy="state_dict"
748+
)
749+
750+
# Create collector with weight sync scheme
751+
collector = RayLLMCollector(
752+
env=make_env,
753+
policy=policy,
754+
dialog_turns_per_batch=256,
755+
total_dialog_turns=10000,
756+
weight_sync_schemes={"policy": weight_sync_scheme},
757+
track_policy_version=True,
758+
)
759+
760+
# During training, get the sender and receiver
761+
sender = collector._weight_senders["policy"]
762+
sender.register_model(training_model)
763+
764+
# No initialization needed for double-buffer scheme!
765+
766+
# Update weights during training
767+
for i, data in enumerate(collector):
768+
# ... training step ...
769+
if i % 10 == 0:
770+
sender.update_weights() # Writes to shared storage
771+
# vLLM workers can poll and apply: receiver.poll_and_apply()
654772
655773
Policy Version Tracking
656774
~~~~~~~~~~~~~~~~~~~~~~~
@@ -662,19 +780,52 @@ transform, or a boolean to the collector constructor.
662780

663781
>>> from torchrl.envs.llm.transforms import PolicyVersion
664782
>>> from torchrl.collectors.llm import LLMCollector
665-
>>> from torchrl.collectors.llm.weight_update import vLLMUpdater
783+
>>> from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata
666784
>>> env = make_env() # place your code here
667785
>>> policy = make_policy() # place your code here
668-
>>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
669-
>>> # init the updater
670-
>>> collector.weight_updater.init(...)
671-
>>> # the version is incremented after each weight update
672-
>>> collector.update_policy_weights_(state_dict=...)
786+
>>> scheme = VLLMWeightSyncScheme(master_port=29500, gpus_per_replica=1, num_replicas=1)
787+
>>> collector = LLMCollector(env, policy=policy, weight_sync_schemes={"policy": scheme}, track_policy_version=True)
788+
>>> # Get the sender and register model
789+
>>> sender = collector._weight_senders["policy"]
790+
>>> sender.register_model(training_model)
791+
>>> # Initialize the collective group
792+
>>> metadata = get_model_metadata(training_model)
793+
>>> sender.init_all_workers_group(metadata, vllm_engine=policy.model)
794+
>>> # Update weights
795+
>>> sender.update_weights()
673796
>>> print(collector.policy_version_tracker.version)
674797
>>> # the policy version is written in the data
675798
>>> for data in collector:
676799
... print(data["policy_version"])
677800

801+
.. currentmodule:: torchrl.weight_update.llm
802+
803+
.. autosummary::
804+
:toctree: generated/
805+
:template: rl_template.rst
806+
807+
VLLMWeightSyncScheme
808+
VLLMWeightSender
809+
VLLMWeightReceiver
810+
VLLMCollectiveTransport
811+
VLLMDoubleBufferSyncScheme
812+
VLLMDoubleBufferWeightSender
813+
VLLMDoubleBufferWeightReceiver
814+
VLLMDoubleBufferTransport
815+
get_model_metadata
816+
817+
Legacy Weight Updaters (Deprecated)
818+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
819+
820+
.. deprecated:: 0.11
821+
The `vLLMUpdater` and `vLLMUpdaterV2` classes are deprecated in favor of the new weight synchronization schemes
822+
(:class:`~torchrl.weight_update.llm.VLLMWeightSyncScheme` and :class:`~torchrl.weight_update.llm.VLLMDoubleBufferSyncScheme`).
823+
These schemes provide better performance, more flexibility, and cleaner integration with collectors.
824+
The legacy updaters will be removed in a future release.
825+
826+
The legacy weight updaters (`vLLMUpdater` and `vLLMUpdaterV2`) are still available but are no longer recommended.
827+
Please migrate to the new weight synchronization schemes shown above.
828+
678829
.. currentmodule:: torchrl.collectors.llm
679830

680831
.. autosummary::

examples/collectors/weight_sync_standalone.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def example_multiprocess_sync():
129129
print(
130130
f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}"
131131
)
132-
print("Weight synchronization successful!")
132+
print("Weight synchronization successful!")
133133

134134

135135
def example_shared_memory_sync():
@@ -179,7 +179,7 @@ def example_shared_memory_sync():
179179
print(
180180
f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}"
181181
)
182-
print("Shared memory synchronization successful!")
182+
print("Shared memory synchronization successful!")
183183

184184

185185
def main():

sota-implementations/expert-iteration/ei_utils.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
from torch import device as torch_device, dtype as torch_dtype
1616

1717
from torchrl._utils import logger as torchrl_logger
18-
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
1918
from torchrl.envs.llm import RetrieveLogProb
2019
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
2120
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
21+
from torchrl.weight_update.llm import VLLMWeightSyncScheme
2222
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
2323
from transformers.tokenization_utils import PreTrainedTokenizer
2424

@@ -479,42 +479,40 @@ def get_hf_model(
479479
torch.set_default_dtype(original_dtype)
480480

481481

482-
def make_weight_updater(
483-
policy_training=None,
482+
def make_weight_sync_scheme(
484483
master_address=None,
485484
master_port=None,
486-
model_metadata=None,
487-
vllm_tp_size=None,
488-
) -> vLLMUpdater:
489-
"""Creates a vLLM weight updater for the policy.
485+
vllm_tp_size=1,
486+
) -> VLLMWeightSyncScheme:
487+
"""Creates a vLLM weight synchronization scheme using NCCL collectives.
490488
491-
This function can be used in two ways:
492-
1. Synchronous mode (expert-iteration-sync.py): Pass policy_training to get an initialized updater with metadata
493-
2. Async mode (expert-iteration-async.py): Pass master_address, master_port, model_metadata, and remote_actor
489+
This function creates a weight sync scheme that uses NCCL for high-performance
490+
GPU-to-GPU weight transfers from the training model to vLLM inference workers.
494491
495492
Args:
496-
policy_training (Optional[TransformersWrapper]): The training policy model. Required for sync mode.
497-
master_address (Optional[str]): Ray master address for async mode.
498-
master_port (Optional[int]): Ray master port for async mode.
499-
model_metadata (Optional[dict]): Model metadata for async mode. If not provided but policy_training is,
500-
it will be extracted from the policy.
501-
vllm_tp_size (Optional[int]): vLLM tensor parallel size. If not provided, will be set to 1.
493+
master_address (Optional[str]): Address of the master node for distributed init.
494+
Defaults to "localhost".
495+
master_port (Optional[int]): Port of the master node for distributed init.
496+
If None, will auto-assign.
497+
vllm_tp_size (int): vLLM tensor parallel size (gpus_per_replica). Defaults to 1.
502498
503499
Returns:
504-
vLLMUpdater: An instance of the weight updater configured to update
505-
the vLLM worker's weights.
500+
VLLMWeightSyncScheme: A weight sync scheme configured for the vLLM engine.
506501
"""
507-
if model_metadata is None and policy_training is not None:
508-
# Extract metadata from training policy
509-
model_metadata = {
510-
k: (v.dtype, v.shape) for k, v in policy_training.model.state_dict().items()
511-
}
502+
if master_address is None:
503+
master_address = "localhost"
504+
505+
torchrl_logger.info(
506+
f"Creating VLLMWeightSyncScheme with tp_size={vllm_tp_size}, "
507+
f"master_address={master_address}, master_port={master_port}"
508+
)
512509

513-
return vLLMUpdater(
510+
return VLLMWeightSyncScheme(
514511
master_address=master_address,
515512
master_port=master_port,
516-
model_metadata=model_metadata,
517-
vllm_tp_size=vllm_tp_size,
513+
gpus_per_replica=vllm_tp_size,
514+
num_replicas=1, # For expert iteration, typically 1 replica
515+
strategy="state_dict",
518516
)
519517

520518

sota-implementations/expert-iteration/expert-iteration-async.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
import hydra
1414

1515
from torchrl import torchrl_logger
16-
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
1716
from torchrl.data.llm.history import History
1817
from torchrl.record.loggers.wandb import WandbLogger
18+
from torchrl.weight_update.llm import get_model_metadata
1919

2020
try:
2121
import ray
@@ -33,7 +33,7 @@
3333
get_train_model,
3434
log_training_metrics,
3535
make_env,
36-
make_weight_updater,
36+
make_weight_sync_scheme,
3737
RemoteDataLogger,
3838
)
3939
from omegaconf import DictConfig
@@ -115,26 +115,39 @@ def train(
115115
if cfg.model.compile:
116116
loss_fn = torch.compile(loss_fn)
117117

118-
# Get metadata
119-
model_metadata = vLLMUpdater.get_model_metadata(policy_training)
118+
# Get vLLM engine from the inference policy
119+
# Note: In expert iteration, the inference policy is typically created in get_inference_model
120+
# We need to get the vLLM engine from the collector's policy or create it
121+
# For now, we'll use the approach similar to GRPO with explicit scheme creation
120122

121-
# Create weight updater with remote LLM
122-
weight_updater: vLLMUpdater = make_weight_updater(
123+
# Create weight sync scheme
124+
weight_sync_scheme = make_weight_sync_scheme(
123125
master_address="localhost", # Since we're running locally
124126
master_port=None, # Will auto-assign an open port
125-
model_metadata=model_metadata,
126127
vllm_tp_size=cfg.inference_model.num_devices
127128
if cfg.inference_model.num_devices is not None
128129
else len(cfg.inference_model.get("devices", [1])),
129130
)
130-
collector.weight_updater = weight_updater
131131

132-
# Initialize the weight updater
133-
weight_updater.init(model_metadata=model_metadata)
132+
# Set up weight sender
133+
torchrl_logger.info("Setting up weight synchronization scheme...")
134+
sender = weight_sync_scheme.create_sender()
135+
sender.register_model(policy_training)
134136

135-
# First update the weights
137+
# Get vLLM engine reference from collector's policy
138+
# The collector has the policy which wraps the vLLM engine
139+
vllm_engine = collector.policy.model if hasattr(collector, "policy") else None
140+
if vllm_engine is None:
141+
raise RuntimeError("Could not get vLLM engine from collector policy")
142+
143+
# Initialize collective group
144+
torchrl_logger.info("Initializing collective group...")
145+
metadata = get_model_metadata(policy_training)
146+
sender.init_all_workers_group(metadata, vllm_engine=vllm_engine)
147+
148+
# First weight update
136149
with timeit("update_policy_weights"):
137-
weight_updater.push_weights(policy_training)
150+
sender.update_weights()
138151
timeit.print(prefix="First update_policy_weights_ time")
139152
timeit.reset()
140153

@@ -329,7 +342,7 @@ def train(
329342
if step % cfg.train.weight_update_frequency == 0:
330343
with timeit("update_policy_weights"):
331344
torchrl_logger.info("Updating policy weights...")
332-
weight_updater.push_weights(policy_training)
345+
sender.update_weights()
333346
# TODO: do we need this? Does it interfere with other processes?
334347
# torch.cuda.empty_cache()
335348
gc.collect()

0 commit comments

Comments
 (0)