# modeling_modernbert_reward.py
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ModernBertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.modernbert.modeling_modernbert import (
    ModernBertModel,
    ModernBertPredictionHead,
)


class ModernBertForOrdinalAndRegression(ModernBertPreTrainedModel):
    """Multi-objective reward model: CORAL (ordinal) + regression heads on ModernBERT.

    - ``config.num_labels = K`` ordinal bins (e.g. 51 -> steps of 0.2 on the
      default 0..10 score range).
    - Training loss, every term weighted by ``sample_weight``:
      ``L = L_ordinal + lambda_reg * L_regression + gamma * L_score``.
      ``L_ordinal`` requires ``labels_bin``; ``L_regression`` (enabled when
      ``lambda_reg > 0``) and ``L_score`` (enabled when ``gamma > 0``) require
      ``labels_cont``.
    - Inference: the returned score blends both heads,
      ``(1 - blend) * score_reg + blend * score_ord``.

    Optional config attributes (defaults in parentheses): ``lambda_reg`` (0.3),
    ``reg_temperature`` (1.0), ``reg_eps`` (1e-4), ``gamma`` (0.05),
    ``blend`` (0.5), ``score_min`` (0.0), ``score_max`` (10.0),
    ``classifier_pooling`` ("cls"; "mean" averages over unmasked tokens).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = nn.Dropout(config.classifier_dropout)

        self.num_bins = int(getattr(config, "num_labels", 51))
        self.lambda_reg = float(getattr(config, "lambda_reg", 0.3))
        self.reg_temperature = float(getattr(config, "reg_temperature", 1.0))
        self.reg_eps = float(getattr(config, "reg_eps", 1e-4))
        self.gamma = float(getattr(config, "gamma", 0.05))
        self.blend = float(getattr(config, "blend", 0.5))
        self.score_min = float(getattr(config, "score_min", 0.0))
        self.score_max = float(getattr(config, "score_max", 10.0))

        # CORAL: one shared projection + K-1 monotone thresholds.
        self.coral_fc = nn.Linear(config.hidden_size, 1, bias=False)
        self.coral_bias_raw = nn.Parameter(torch.zeros(self.num_bins - 1))

        # Regression head (a single logit, squashed to [score_min, score_max]).
        self.reg_head = nn.Linear(config.hidden_size, 1)

        self.config.problem_type = "regression"
        self.post_init()

    def _init_weights(self, module: nn.Module):
        """Initialize ``module``; the reward heads get a small truncated-normal init.

        Applies the ModernBERT default initialization, zeroes every
        ``nn.Linear`` bias, and for this top-level model re-initializes
        ``coral_fc`` / ``reg_head`` weights with std ``hidden_size ** -0.5`` and
        resets the raw CORAL threshold parameters to zero.
        """
        super()._init_weights(module)

        cutoff_factor = getattr(self.config, "initializer_cutoff_factor", None)
        if cutoff_factor is None:
            cutoff_factor = 3

        def init_weight(target: nn.Module, std: float):
            nn.init.trunc_normal_(
                target.weight,
                mean=0.0,
                std=std,
                a=-cutoff_factor * std,
                b=cutoff_factor * std,
            )

        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

        if isinstance(module, ModernBertForOrdinalAndRegression):
            final_out_std = self.config.hidden_size**-0.5
            init_weight(module.coral_fc, final_out_std)
            init_weight(module.reg_head, final_out_std)
            # nn.init.zeros_ runs under torch.no_grad(); a bare Parameter.zero_()
            # raises on a leaf tensor that requires grad unless the caller has
            # already disabled autograd.
            nn.init.zeros_(module.coral_bias_raw)

    def _thresholds(self) -> torch.Tensor:
        """Return the K-1 CORAL thresholds, strictly increasing by construction."""
        # softplus gives positive increments; their cumulative sum is monotone.
        return torch.cumsum(F.softplus(self.coral_bias_raw), dim=0)

    def _pool(self, last_hidden, attention_mask) -> torch.Tensor:
        """Pool token states into one vector per sequence.

        Assumes ``last_hidden`` is [B, T, H] (indexed along dim 1 below).
        ``classifier_pooling == "mean"`` averages over unmasked tokens, or over
        all tokens when ``attention_mask`` is None; otherwise the first token is
        used ("cls").
        """
        pooling = getattr(self.config, "classifier_pooling", "cls")
        if pooling == "mean":
            if attention_mask is None:
                return last_hidden.mean(dim=1)
            mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
            return (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-6)
        return last_hidden[:, 0]  # "cls"

    def _ordinal_head(self, pooled: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """CORAL head: return ``(logits_ord [B, K-1], score_ord [B])``.

        ``sigmoid(logits_ord[:, k])`` is P(bin > k); the score is the expected
        bin index mapped linearly onto ``[score_min, score_max]``.
        """
        z = self.coral_fc(pooled).squeeze(-1)  # [B]
        th = self._thresholds()  # [K-1]
        logits_ord = z.unsqueeze(-1) - th.unsqueeze(0)  # [B, K-1]
        p_gt = torch.sigmoid(logits_ord)

        # P(bin == k) = P(bin > k-1) - P(bin > k), with P(bin > -1) = 1 and
        # P(bin > K-1) = 0; the clamp only absorbs rounding noise.
        ones = torch.ones(p_gt.size(0), 1, device=p_gt.device, dtype=p_gt.dtype)
        zeros = torch.zeros(p_gt.size(0), 1, device=p_gt.device, dtype=p_gt.dtype)
        p_left = torch.cat([ones, p_gt], dim=1)
        p_right = torch.cat([p_gt, zeros], dim=1)
        p_cls = (p_left - p_right).clamp_min(0.0)  # [B, K]

        bins = torch.arange(self.num_bins, device=p_gt.device, dtype=p_gt.dtype).unsqueeze(0)
        expected_bin = (p_cls * bins).sum(dim=-1)  # [B]
        score_ord = self.score_min + (self.score_max - self.score_min) * (
            expected_bin / (self.num_bins - 1)
        )
        return logits_ord, score_ord

    def _regression_head(self, pooled: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Regression head: return ``(reg_raw [B], score_reg [B])``.

        ``score_reg = score_min + (score_max - score_min) * sigmoid(reg_raw / T)``
        with the sigmoid clamped to ``[reg_eps, 1 - reg_eps]``.
        """
        reg_raw = self.reg_head(pooled).squeeze(-1)  # [B]
        p = torch.sigmoid(reg_raw / self.reg_temperature)
        p = p.clamp(self.reg_eps, 1.0 - self.reg_eps)
        score_reg = self.score_min + (self.score_max - self.score_min) * p
        return reg_raw, score_reg

    @staticmethod
    def _weighted_mean(values: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """Return ``sum(values * weights) / sum(weights)``; 0 (not NaN) if all weights are 0."""
        return (values * weights).sum() / weights.sum().clamp_min(1e-12)

    def _reward_loss(
        self,
        logits_ord: torch.Tensor,
        reg_raw: torch.Tensor,
        score: torch.Tensor,
        labels_cont: Optional[torch.Tensor],
        labels_bin: Optional[torch.Tensor],
        sample_weight: Optional[torch.Tensor],
    ):
        """Combine the ordinal, regression and final-score losses (see class docstring).

        Predictions are cast to float32 before the loss functions so that
        bf16/fp16 activations are not mixed with float32 targets.
        """
        if sample_weight is None:
            sample_weight = torch.ones_like(score)
        sw = sample_weight.to(score.device).float()

        loss_total = 0.0

        if labels_bin is not None:
            # CORAL loss: K-1 binary targets [bin > k].
            y = labels_bin.to(logits_ord.device).long()
            thr = torch.arange(self.num_bins - 1, device=y.device).unsqueeze(0)
            target_ord = (y.unsqueeze(1) > thr).float()  # [B, K-1]
            bce = F.binary_cross_entropy_with_logits(
                logits_ord.float(), target_ord, reduction="none"
            ).mean(dim=-1)
            loss_total = loss_total + self._weighted_mean(bce, sw)

        if labels_cont is not None:
            y_cont = labels_cont.to(score.device).float().clamp(self.score_min, self.score_max)

            if self.lambda_reg > 0.0:
                # Huber loss in logit space: invert the head's squashing so the
                # target for reg_raw is T * logit(normalized label).
                pt = (y_cont - self.score_min) / (self.score_max - self.score_min)
                pt = pt.clamp(self.reg_eps, 1.0 - self.reg_eps)
                t = self.reg_temperature * (torch.log(pt) - torch.log1p(-pt))
                huber = F.smooth_l1_loss(reg_raw.float(), t, reduction="none")
                loss_total = loss_total + self.lambda_reg * self._weighted_mean(huber, sw)

            if self.gamma > 0:
                # Huber loss directly on the blended score; controlled by gamma
                # alone, independently of lambda_reg.
                score_huber = F.smooth_l1_loss(score.float(), y_cont, reduction="none")
                loss_total = loss_total + self.gamma * self._weighted_mean(score_huber, sw)

        return loss_total

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,  # unused
        labels_cont: Optional[torch.Tensor] = None,  # [B] continuous score in [score_min, score_max]
        labels_bin: Optional[torch.Tensor] = None,  # [B] bin index 0..K-1
        sample_weight: Optional[torch.Tensor] = None,  # [B]
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Score a batch; also return the training loss when labels are given.

        ``logits`` is the blended score with shape [B, 1]. ``loss`` is computed
        only when ``labels_cont`` or ``labels_bin`` is provided; ``labels`` and
        extra ``**kwargs`` are ignored.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        pooled = self.head(self._pool(outputs.last_hidden_state, attention_mask))
        pooled = self.drop(pooled)

        logits_ord, score_ord = self._ordinal_head(pooled)
        reg_raw, score_reg = self._regression_head(pooled)

        # Final score: blend of the regression and ordinal predictions.
        score = (1.0 - self.blend) * score_reg + self.blend * score_ord  # [B]
        logits = score.unsqueeze(-1)  # [B, 1]

        loss = None
        if (labels_cont is not None) or (labels_bin is not None):
            loss = self._reward_loss(
                logits_ord, reg_raw, score, labels_cont, labels_bin, sample_weight
            )

        if not return_dict:
            out = (logits,)
            return ((loss,) + out) if loss is not None else out

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )