Skip to content

Commit e461f0a

Browse files
committed
[Feature] kl_mask_threshold
ghstack-source-id: 492beab Pull-Request: #3208
1 parent 9b5ea9e commit e461f0a

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

torchrl/objectives/llm/grpo.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ProbabilisticTensorDictSequential,
2525
set_composite_lp_aggregate,
2626
)
27+
from tensordict.utils import expand_as_right
2728
from torch import distributions as d
2829
from torchrl._utils import logger as torchrl_logger, VERBOSE
2930
from torchrl.envs.transforms.transforms import Transform
@@ -81,6 +82,10 @@ class GRPOLoss(LossModule):
8182
- float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
8283
- tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher
8384
recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
85+
kl_mask_threshold (float | None, optional): enables token-wise trust-region filtering (KL-Mask).
86+
When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out from the loss.
87+
This stabilizes updates by skipping tokens that drifted too far from the reference distribution
88+
(i.e., it enforces a per-token trust region: tokens whose approximate KL to the reference exceeds the threshold are excluded from the loss).
8489
entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
8590
loss to favour exploratory policies.
8691
samples_mc_entropy (int, optional): if the distribution retrieved from the policy
@@ -144,6 +149,7 @@ def __init__(
144149
actor_network: LLMWrapperBase | None = None,
145150
*,
146151
clip_epsilon: float | tuple[float, float] = 0.2,
152+
kl_mask_threshold: float | None = None,
147153
entropy_bonus: bool = True,
148154
samples_mc_entropy: int = 1,
149155
entropy_coeff: float = 0.01,
@@ -163,6 +169,7 @@ def __init__(
163169
self.samples_mc_entropy = samples_mc_entropy
164170
self.entropy_coeff = entropy_coeff
165171
self.reduction = reduction
172+
self.kl_mask_threshold = kl_mask_threshold
166173

167174
# Determine device and register clip epsilon as buffer
168175
if device is None:
@@ -333,6 +340,32 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
333340
tensordict, adv_shape=advantage.shape[:-1]
334341
)
335342
mask = dist.mask
343+
344+
# Optional per-token trust-region filtering (KL-Mask) vs reference policy
345+
if self.kl_mask_threshold is not None and self.kl_mask_threshold > 0:
346+
try:
347+
ref_log_prob = tensordict.get(
348+
self.tensor_keys.ref_log_probs,
349+
as_padded_tensor=True,
350+
padding_side="left",
351+
padding_value=0.0,
352+
)
353+
except KeyError:
354+
ref_log_prob = None
355+
cur_log_prob = tensordict.get("_cur_log_prob", None)
356+
if (ref_log_prob is not None) and (cur_log_prob is not None):
357+
# Align to valid tokens only (safety)
358+
cur_log_prob_masked = torch.where(
359+
expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
360+
)
361+
ref_log_prob_masked = torch.where(
362+
expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
363+
)
364+
log_is_ref = cur_log_prob_masked - ref_log_prob_masked
365+
kl_token = 0.5 * (log_is_ref**2)
366+
tr_mask = kl_token <= self.kl_mask_threshold
367+
# Combine with attention mask
368+
mask = mask & tr_mask
336369
# ESS for logging
337370
with torch.no_grad():
338371
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according

0 commit comments

Comments
 (0)