     ProbabilisticTensorDictSequential,
     set_composite_lp_aggregate,
 )
+from tensordict.utils import expand_as_right
 from torch import distributions as d
 from torchrl._utils import logger as torchrl_logger, VERBOSE
 from torchrl.envs.transforms.transforms import Transform
@@ -81,6 +82,10 @@ class GRPOLoss(LossModule):
             - float x: symmetric clipping [1 - x, 1 + x] (default: 0.2)
             - tuple (eps_low, eps_high): asymmetric clipping [1 - eps_low, 1 + eps_high] as in DAPO Clip-Higher;
               recommended defaults from DAPO: (0.20, 0.28); see Eq. (10) in the paper.
+        kl_mask_threshold (float | None, optional): enables token-wise trust-region filtering (KL-Mask).
+            When set, tokens with ``0.5 * (log(pi_theta / pi_ref))**2 > kl_mask_threshold`` are masked out of the loss.
+            This stabilizes updates by skipping tokens that have drifted too far from the reference distribution,
+            effectively enforcing a per-token trust region. Defaults to ``None`` (disabled).
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
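
> A minimal sketch of the KL-Mask criterion documented above, computed on toy tensors with plain PyTorch. The names `cur_log_prob`/`ref_log_prob` mirror the diff; the tensor values and the threshold are illustrative assumptions, not defaults.

```python
import torch

cur_log_prob = torch.tensor([[-0.9, -1.2, -2.5], [-0.3, -1.1, -0.7]])  # log pi_theta per token
ref_log_prob = torch.tensor([[-1.0, -1.2, -1.0], [-0.3, -1.6, -0.7]])  # log pi_ref per token
kl_mask_threshold = 0.1  # illustrative value

log_ratio = cur_log_prob - ref_log_prob  # log(pi_theta / pi_ref)
kl_token = 0.5 * log_ratio.pow(2)        # quadratic per-token divergence proxy
keep = kl_token <= kl_mask_threshold     # False -> token is dropped from the loss
print(keep)
# tensor([[ True,  True, False],
#         [ True, False,  True]])
```
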
@@ -144,6 +149,7 @@ def __init__(
         actor_network: LLMWrapperBase | None = None,
         *,
         clip_epsilon: float | tuple[float, float] = 0.2,
+        kl_mask_threshold: float | None = None,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -163,6 +169,7 @@ def __init__(
         self.samples_mc_entropy = samples_mc_entropy
         self.entropy_coeff = entropy_coeff
         self.reduction = reduction
+        self.kl_mask_threshold = kl_mask_threshold
 
         # Determine device and register clip epsilon as buffer
         if device is None:
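
> To help pick a value for `kl_mask_threshold`, the criterion can be inverted: a token is kept while `0.5 * log_ratio**2 <= t`, i.e. while the probability ratio `pi_theta / pi_ref` stays within `exp(±sqrt(2 * t))`. A small runnable sketch of that mapping (threshold values purely illustrative):

```python
import math

for t in (0.01, 0.05, 0.1):
    max_abs_log_ratio = math.sqrt(2 * t)
    lo, hi = math.exp(-max_abs_log_ratio), math.exp(max_abs_log_ratio)
    print(f"t={t}: |log ratio| <= {max_abs_log_ratio:.3f}, ratio kept within [{lo:.3f}, {hi:.3f}]")
```
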
@@ -333,6 +340,32 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             tensordict, adv_shape=advantage.shape[:-1]
         )
         mask = dist.mask
+
+        # Optional per-token trust-region filtering (KL-Mask) vs reference policy
+        if self.kl_mask_threshold is not None and self.kl_mask_threshold > 0:
+            try:
+                ref_log_prob = tensordict.get(
+                    self.tensor_keys.ref_log_probs,
+                    as_padded_tensor=True,
+                    padding_side="left",
+                    padding_value=0.0,
+                )
+            except KeyError:
+                ref_log_prob = None
+            cur_log_prob = tensordict.get("_cur_log_prob", None)
+            if (ref_log_prob is not None) and (cur_log_prob is not None):
+                # Align to valid tokens only (safety)
+                cur_log_prob_masked = torch.where(
+                    expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
+                )
+                ref_log_prob_masked = torch.where(
+                    expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
+                )
+                log_is_ref = cur_log_prob_masked - ref_log_prob_masked
+                kl_token = 0.5 * (log_is_ref**2)
+                tr_mask = kl_token <= self.kl_mask_threshold
+                # Combine with attention mask
+                mask = mask & tr_mask
         # ESS for logging
         with torch.no_grad():
             # In theory, ESS should be computed on particles sampled from the same source. Here we sample according
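
> The forward hunk above also has to respect padding: log-probs are zeroed outside the validity mask before the ratio is formed, and the resulting trust-region mask is ANDed back into the original mask. A toy sketch of that flow, with names mirroring the diff; shapes, values, and the threshold are assumptions for illustration.

```python
import torch

mask = torch.tensor([[True, True, True, False]])         # validity mask; last position is padding
cur_log_prob = torch.tensor([[-0.7, -2.4, -1.0, 0.0]])   # log pi_theta (padding position already zeroed)
ref_log_prob = torch.tensor([[-0.7, -1.1, -1.0, 0.0]])   # log pi_ref, padded with 0.0 as in the diff
kl_mask_threshold = 0.2  # illustrative value

log_is_ref = torch.where(mask, cur_log_prob, 0.0) - torch.where(mask, ref_log_prob, 0.0)
kl_token = 0.5 * log_is_ref.pow(2)
tr_mask = kl_token <= kl_mask_threshold
mask = mask & tr_mask
print(mask)  # tensor([[ True, False,  True, False]]) -> token 1 dropped, padding stays masked
```
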