@@ -86,6 +86,10 @@ class GRPOLoss(LossModule):
             When set, tokens with 0.5 * (log(pi_theta/pi_ref))^2 > kl_mask_threshold are masked out from the loss.
             This stabilizes updates by skipping tokens that drifted too far from the reference distribution
             (see table and description; enables a per-token trust region).
+        aggregation (str, optional): loss aggregation strategy for the policy objective.
+            - "token_mean": global masked token mean (weights long sequences more heavily). Default.
+            - "prompt_mean": per-sample masked mean over tokens, then mean across samples (equal sample weight).
+            - "none": return the per-token loss (mask applied, no aggregation). Useful for custom downstream reductions.
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
             loss to favour exploratory policies.
         samples_mc_entropy (int, optional): if the distribution retrieved from the policy
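To make the strategies concrete: a minimal, self-contained sketch in plain PyTorch (independent of this class; the tensors are illustrative) of what each option computes over a padded batch:

```python
import torch

# Two samples: sample 0 has 4 valid tokens, sample 1 has 2 (padded out).
loss = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                     [4.0, 4.0, 0.0, 0.0]])
mask = torch.tensor([[True, True, True, True],
                     [True, True, False, False]])

# "token_mean": one global mean over all valid tokens -> long samples dominate.
token_mean = loss[mask].mean()                     # (4*1.0 + 2*4.0) / 6 = 2.0

# "prompt_mean": per-sample mean first, then mean across samples -> equal weight.
per_sample = (loss * mask).sum(-1) / mask.sum(-1)  # tensor([1.0, 4.0])
prompt_mean = per_sample.mean()                    # 2.5

# "none": keep per-token values, zeroing masked positions.
per_token = torch.where(mask, loss, torch.zeros_like(loss))
```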
@@ -150,6 +154,7 @@ def __init__(
         *,
         clip_epsilon: float | tuple[float, float] = 0.2,
         kl_mask_threshold: float | None = None,
+        aggregation: str | None = "token_mean",
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coeff: float = 0.01,
@@ -170,6 +175,7 @@ def __init__(
         self.entropy_coeff = entropy_coeff
         self.reduction = reduction
         self.kl_mask_threshold = kl_mask_threshold
+        self.aggregation = aggregation or "token_mean"
 
         # Determine device and register clip epsilon as buffer
         if device is None:
@@ -396,13 +402,13 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             td_out.set("loss_entropy", -self.entropy_coeff * entropy)
 
         td_out.set("ESS", _reduce(ess / batch, self.reduction))
-        td_out = td_out.named_apply(
-            lambda name, value: _reduce(
-                value, reduction=self.reduction, mask=mask
-            ).squeeze(-1)
-            if name.startswith("loss_")
-            else value,
-        )
+        # Aggregate loss terms according to the configured aggregation strategy
+        for key in list(td_out.keys()):
+            if isinstance(key, tuple) or not isinstance(key, str):
+                continue
+            if key.startswith("loss_"):
+                val = td_out.get(key)
+                td_out.set(key, self._aggregate_loss_value(val, mask))
         if self.kl_to_ref_coeff is not None and self.kl_to_ref_coeff > 0:
             # FIXME: parameterize this
             loss_kl, kl_penalty = self._kl_to_ref(
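With `aggregation="none"` the `loss_*` entries stay per-token and the caller owns the reduction. A hedged sketch of one possible downstream reduction (the shapes follow the method's contract below; the weighting scheme itself is illustrative, not part of this change):

```python
import torch

# Stand-in for the per-token objective returned with aggregation="none",
# shape [B, T, 1]; masked positions carry no weight after the mask multiply.
loss_objective = torch.randn(2, 4, 1, requires_grad=True)
mask = torch.tensor([[True, True, True, True],
                     [True, True, False, False]])

# Custom reduction: normalize each sample by its own length, then average.
lengths = mask.sum(-1, keepdim=True).clamp_min(1)  # [B, 1]
per_sample = (loss_objective.squeeze(-1) * mask).sum(-1, keepdim=True) / lengths
total = per_sample.mean()                          # scalar, ready for backward()
total.backward()
```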
@@ -446,6 +452,34 @@ def _compute_policy_objective(
         gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
         return -gain, clip_fraction
 
+    def _aggregate_loss_value(
+        self, value: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Aggregate a per-token loss tensor using the configured strategy.
+
+        Supports:
+            - token_mean: masked mean across all tokens (default)
+            - prompt_mean: per-sample masked mean over tokens, then mean across the batch
+            - none: return the per-token loss with masked-out tokens set to 0
+
+        The input ``value`` is expected to have shape [..., T, 1], where T is the
+        token dimension, and ``mask`` has shape [..., T].
+        """
+        if self.aggregation == "none" or self.reduction == "none":
+            mask_exp = expand_as_right(mask, value)
+            return torch.where(mask_exp, value, value.new_zeros(()).expand_as(value))
+
+        if self.aggregation == "prompt_mean":
+            # Mean over valid tokens per sample, then mean across the batch
+            mask_exp = expand_as_right(mask, value).to(value.dtype)
+            token_sum = (value * mask_exp).sum(dim=-2, keepdim=False)
+            token_count = mask_exp.sum(dim=-2, keepdim=False).clamp_min(1.0)
+            sample_mean = token_sum / token_count
+            return sample_mean.mean(dim=0, keepdim=False)
+
+        # token_mean (global masked mean)
+        return _reduce(value, reduction="mean", mask=mask).squeeze(-1)
+
     def _get_entropy(
         self, dist: d.Distribution, adv_shape: torch.Size
     ) -> torch.Tensor | TensorDict:
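A quick standalone check of the `prompt_mean` branch (a free function mirroring the method, with `expand_as_right` replaced by an explicit unsqueeze; names are illustrative):

```python
import torch

def prompt_mean(value: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # value: [B, T, 1]; mask: [B, T] -- the shape contract from the docstring.
    mask_exp = mask.unsqueeze(-1).to(value.dtype)      # [B, T, 1]
    token_sum = (value * mask_exp).sum(dim=-2)         # [B, 1]
    token_count = mask_exp.sum(dim=-2).clamp_min(1.0)  # [B, 1]; avoids 0-division
    return (token_sum / token_count).mean(dim=0)       # [1]

value = torch.tensor([[[1.0], [1.0], [1.0], [1.0]],
                      [[4.0], [4.0], [0.0], [0.0]]])
mask = torch.tensor([[True, True, True, True],
                     [True, True, False, False]])
assert prompt_mean(value, mask).item() == 2.5  # equal per-sample weight
# A fully masked sample contributes 0 (not NaN) thanks to clamp_min(1.0).
```

The `token_mean` default instead reduces to a single global masked mean, matching the existing `_reduce(value, reduction="mean", mask=mask)` path.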