from __future__ import annotations import re from dataclasses import dataclass from typing import Iterable import torch import torch.nn as nn from transformers import AutoTokenizer, PreTrainedModel from transformers.modeling_outputs import ModelOutput from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Model try: from .configuration_leg import LEGConfig except ImportError: # pragma: no cover from configuration_leg import LEGConfig _PROMPT_CLEAN_PATTERN = r"[^\w\s-]|(? torch.Tensor: attn_scores = self.attn(embeddings).squeeze(-1) if attention_mask is not None: neg_inf = torch.finfo(attn_scores.dtype).min attn_scores = attn_scores.masked_fill(attention_mask == 0, neg_inf) attn_weights = torch.softmax(attn_scores, dim=-1) return torch.sum(embeddings * attn_weights.unsqueeze(-1), dim=1) class LEGForSafetyExplanation(PreTrainedModel): config_class = LEGConfig base_model_prefix = "bert" def __init__(self, config: LEGConfig): super().__init__(config) self.bert = DebertaV2Model(config) self.attention_pooling = AttentionPooling(config.hidden_size) self.prompt_classifier = nn.Linear(config.hidden_size, 2) self.token_classifier = nn.Linear(config.hidden_size, 2) # Kept only because these parameters exist in the source checkpoint. self.log_sigma_prompt = nn.Parameter(torch.zeros(())) self.log_sigma_token = nn.Parameter(torch.zeros(())) self._cached_tokenizer = None self._inference_tokenizer_source = None self.post_init() @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) model._inference_tokenizer_source = str(pretrained_model_name_or_path) return model def _get_tokenizer(self, tokenizer=None): if tokenizer is not None: return tokenizer if self._cached_tokenizer is not None: return self._cached_tokenizer source = ( self._inference_tokenizer_source or getattr(self, "name_or_path", None) or getattr(self.config, "_name_or_path", None) ) if not source: raise ValueError( "Tokenizer source could not be resolved automatically. " "Pass tokenizer=AutoTokenizer.from_pretrained(...) explicitly." ) self._cached_tokenizer = AutoTokenizer.from_pretrained(source, use_fast=True) return self._cached_tokenizer @staticmethod def _clean_and_split_prompt(prompt: str) -> list[str]: cleaned_text = re.sub(_PROMPT_CLEAN_PATTERN, "", prompt) return cleaned_text.split() @staticmethod def _normalize_prompts(prompts: str | Iterable[str]) -> tuple[list[str], bool]: if isinstance(prompts, str): return [prompts], True prompt_list = list(prompts) return prompt_list, False def _predict_from_tokenized( self, encodings, words_batch: list[list[str]], prompt_threshold: float, word_threshold: float, ) -> list[dict]: device = next(self.parameters()).device model_inputs = { "input_ids": encodings["input_ids"].to(device), "attention_mask": encodings["attention_mask"].to(device), } with torch.inference_mode(): outputs = self.forward(**model_inputs) prompt_probs = torch.softmax(outputs.prompt_logits, dim=1).cpu() token_probs = torch.softmax(outputs.token_logits, dim=2).cpu() formatted_outputs = [] for batch_index, words in enumerate(words_batch): prompt_safe = prompt_probs[batch_index, 0].item() prompt_unsafe = prompt_probs[batch_index, 1].item() safety_label = int( prompt_unsafe > prompt_safe and prompt_unsafe > prompt_threshold ) token_safe = token_probs[batch_index, :, 0].tolist() token_unsafe = token_probs[batch_index, :, 1].tolist() word_ids = encodings.word_ids(batch_index=batch_index) word_id_to_label = {} for token_index, word_id in enumerate(word_ids): if word_id is None or token_index >= len(token_unsafe): continue predicted_label = int( token_unsafe[token_index] > token_safe[token_index] and token_unsafe[token_index] > word_threshold ) if word_id not in word_id_to_label: word_id_to_label[word_id] = predicted_label explanation = [ (word, word_id_to_label.get(word_index, 0)) for word_index, word in enumerate(words) ] formatted_outputs.append( { "safety_label": safety_label, "explanation": explanation, } ) return formatted_outputs def predict_safety( self, prompts: str | Iterable[str], tokenizer=None, prompt_threshold: float | None = None, word_threshold: float | None = None, max_length: int | None = None, batch_size: int | None = None, ): prompt_list, single_input = self._normalize_prompts(prompts) tokenizer = self._get_tokenizer(tokenizer=tokenizer) if not prompt_list: return [] if not single_input else { "safety_label": 0, "explanation": [], } prompt_threshold = ( self.config.prompt_threshold if prompt_threshold is None else prompt_threshold ) word_threshold = ( self.config.word_threshold if word_threshold is None else word_threshold ) max_length = ( self.config.inference_max_length if max_length is None else max_length ) effective_batch_size = len(prompt_list) if batch_size is not None: if batch_size <= 0: raise ValueError("batch_size must be a positive integer when provided.") effective_batch_size = batch_size formatted_outputs = [] for start_idx in range(0, len(prompt_list), effective_batch_size): prompt_chunk = prompt_list[start_idx : start_idx + effective_batch_size] words_batch = [ self._clean_and_split_prompt(prompt_text or "") for prompt_text in prompt_chunk ] encodings = tokenizer( words_batch, is_split_into_words=True, max_length=max_length, truncation=True, padding="max_length", return_tensors="pt", ) formatted_outputs.extend( self._predict_from_tokenized( encodings=encodings, words_batch=words_batch, prompt_threshold=prompt_threshold, word_threshold=word_threshold, ) ) return formatted_outputs[0] if single_input else formatted_outputs def forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, token_type_ids: torch.Tensor | None = None, prompts: str | Iterable[str] | None = None, prompt: str | None = None, tokenizer=None, prompt_threshold: float | None = None, word_threshold: float | None = None, max_length: int | None = None, batch_size: int | None = None, **kwargs, ): if prompts is None and prompt is not None: prompts = prompt if prompts is not None and input_ids is None: return self.predict_safety( prompts=prompts, tokenizer=tokenizer, prompt_threshold=prompt_threshold, word_threshold=word_threshold, max_length=max_length, batch_size=batch_size, ) if input_ids is None: raise ValueError( "Provide either tokenized inputs (`input_ids`, `attention_mask`) or " "raw `prompts`/`prompt` strings." ) encoder_kwargs = { "input_ids": input_ids, "attention_mask": attention_mask, "return_dict": True, } if token_type_ids is not None: encoder_kwargs["token_type_ids"] = token_type_ids encoder_outputs = self.bert(**encoder_kwargs) hidden_states = encoder_outputs.last_hidden_state pooled_output = self.attention_pooling(hidden_states, attention_mask) return LEGModelOutput( prompt_logits=self.prompt_classifier(pooled_output), token_logits=self.token_classifier(hidden_states), )