from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput

# Default task instruction inserted into the decoder prompt.
DEFAULT_INSTRUCTION = "Given a query, retrieve documents that answer the query."
# Default leading instruction of the decoder prompt.
DEFAULT_SYSTEM_INSTRUCTION = (
    "Judge whether the Document meets the requirements based on the Query and "
    'the Instruct provided. Note that the answer can only be "yes" or "no".'
)


class KaLMReranker:
    """Score query-document relevance with a KaLM encoder-decoder reranker.

    The document is fed to the encoder and a prompt containing the system
    instruction, the task instruction and the query is fed to the decoder.
    The returned score is ``P(yes)`` after applying a two-class softmax to the
    model's ``yes`` and ``no`` logits at the last decoder prompt position.
    """

    def __init__(
        self,
        model_name_or_path: str,
        *,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[Union[str, torch.dtype]] = None,
        batch_size: int = 32,
        query_max_length: int = 512,
        max_length: int = 1024,
        chunk_size: Optional[int] = 4,
        instruction: str = DEFAULT_INSTRUCTION,
        system_instruction: str = DEFAULT_SYSTEM_INSTRUCTION,
        **model_kwargs: Any,
    ) -> None:
        """Load the tokenizer and the model and prepare them for inference.

        Args:
            model_name_or_path: Identifier passed to both
                ``AutoTokenizer.from_pretrained`` and
                ``AutoModelForSeq2SeqLM.from_pretrained``.
            device: Target device. ``None`` selects ``"cuda"`` when available,
                otherwise ``"cpu"``.
            dtype: Model dtype as a ``torch.dtype`` or a string such as
                ``"bfloat16"``. ``None`` selects bfloat16 on CUDA and float32
                elsewhere.
            batch_size: Default number of pairs per forward pass.
            query_max_length: Maximum number of query tokens kept in the
                decoder prompt.
            max_length: Maximum number of tokens of the encoder (document)
                input.
            chunk_size: Number of consecutive encoder positions that are
                mean-pooled into one position before decoding. ``None``
                disables pooling.
            instruction: Default task instruction for the decoder prompt.
            system_instruction: Text placed at the start of the decoder prompt.
            **model_kwargs: Extra keyword arguments forwarded to
                ``AutoModelForSeq2SeqLM.from_pretrained``.

        Raises:
            ValueError: If an argument has an invalid value, or the tokenizer
                defines neither a pad token nor an EOS token.
            TypeError: If ``instruction``, ``system_instruction`` or ``dtype``
                has an unsupported type.
            RuntimeError: If a CUDA device is requested but unavailable.
        """
        if not isinstance(model_name_or_path, str) or not model_name_or_path:
            raise ValueError("model_name_or_path must be a non-empty string.")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive.")
        if query_max_length <= 0 or max_length <= 0:
            raise ValueError("query_max_length and max_length must be positive.")
        if chunk_size is not None and chunk_size <= 0:
            raise ValueError("chunk_size must be positive or None.")
        if not isinstance(instruction, str) or not isinstance(system_instruction, str):
            raise TypeError("instruction and system_instruction must be strings.")

        self.device = self._resolve_device(device)
        self.dtype = self._resolve_dtype(dtype, self.device)
        self.batch_size = batch_size
        self.query_max_length = query_max_length
        self.max_length = max_length
        self.chunk_size = chunk_size
        self.instruction = instruction
        self.system_instruction = system_instruction

        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token_id is None:
                raise ValueError("The tokenizer must define a pad token or an EOS token.")
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # _predict_batch reads the logits at the last unmasked decoder position,
        # which is only correct when padding is appended on the right.
        self.tokenizer.padding_side = "right"

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name_or_path,
            dtype=self.dtype,
            **model_kwargs,
        )
        # Cast any floating-point parameter that did not end up in the
        # requested dtype after loading.
        for parameter in self.model.parameters():
            if parameter.is_floating_point() and parameter.dtype != self.dtype:
                parameter.data = parameter.data.to(dtype=self.dtype)
        self.model.to(device=self.device)
        self.model.eval()

        self.yes_token_id = self._answer_token_id("yes")
        self.no_token_id = self._answer_token_id("no")

    @staticmethod
    def _resolve_device(device: Optional[Union[str, torch.device]]) -> torch.device:
        """Return the ``torch.device`` to use.

        ``None`` selects CUDA when it is available and the CPU otherwise.

        Raises:
            RuntimeError: If a CUDA device is requested but CUDA is unavailable.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        resolved = torch.device(device)
        if resolved.type == "cuda" and not torch.cuda.is_available():
            raise RuntimeError("CUDA was requested, but no CUDA device is available.")
        return resolved

    @staticmethod
    def _resolve_dtype(
        dtype: Optional[Union[str, torch.dtype]], device: torch.device
    ) -> torch.dtype:
        """Return the ``torch.dtype`` to use.

        ``None`` selects bfloat16 on CUDA devices and float32 elsewhere.
        Strings are matched case-insensitively, with an optional ``torch.``
        prefix, against bfloat16/bf16, float16/fp16 and float32/fp32.

        Raises:
            TypeError: If ``dtype`` is neither ``None``, a ``torch.dtype`` nor
                a string.
            ValueError: If the string does not name a supported dtype.
        """
        if dtype is None:
            return torch.bfloat16 if device.type == "cuda" else torch.float32
        if isinstance(dtype, torch.dtype):
            return dtype
        if not isinstance(dtype, str):
            raise TypeError("dtype must be a torch.dtype or a string such as 'bfloat16'.")
        normalized = dtype.lower().removeprefix("torch.")
        supported = {
            "bfloat16": torch.bfloat16,
            "bf16": torch.bfloat16,
            "float16": torch.float16,
            "fp16": torch.float16,
            "float32": torch.float32,
            "fp32": torch.float32,
        }
        if normalized not in supported:
            raise ValueError(f"Unsupported dtype: {dtype!r}.")
        return supported[normalized]

    def _answer_token_id(self, answer: str) -> int:
        """Return the id of the last token produced by tokenizing ``answer``.

        Raises:
            ValueError: If the tokenizer produces no tokens for ``answer``.
        """
        token_ids = self.tokenizer(answer, add_special_tokens=False)["input_ids"]
        if not token_ids:
            raise ValueError(f"Failed to tokenize the answer {answer!r}.")
        return token_ids[-1]

    def _get_encoder(self):
        """Return the encoder module of the loaded model.

        Raises:
            AttributeError: If the model exposes neither ``get_encoder()`` nor
                an ``encoder`` attribute.
        """
        if hasattr(self.model, "get_encoder"):
            return self.model.get_encoder()
        if hasattr(self.model, "encoder"):
            return self.model.encoder
        raise AttributeError(f"Cannot find the encoder on {type(self.model).__name__}.")

    @staticmethod
    def _pool_encoder_chunks(
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        chunk_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mean-pool consecutive groups of ``chunk_size`` encoder positions.

        Args:
            hidden_states: Tensor of shape ``(batch, sequence, hidden)``.
            attention_mask: Tensor of shape ``(batch, sequence)``; zero marks
                positions to ignore.
            chunk_size: Number of positions merged into one pooled position.

        Returns:
            ``(pooled_hidden, pooled_mask)`` with shapes
            ``(batch, num_chunks, hidden)`` and ``(batch, num_chunks)``, where
            ``num_chunks = ceil(sequence / chunk_size)``. Masked positions are
            excluded from the mean. A chunk without any unmasked position
            yields a zero vector and a mask value of 0.
        """
        batch_size, sequence_length, hidden_size = hidden_states.shape
        num_chunks = (sequence_length + chunk_size - 1) // chunk_size
        padded_length = num_chunks * chunk_size
        pad_length = padded_length - sequence_length
        if pad_length:
            # Extend the sequence axis with zeros (masked out) so that it
            # splits evenly into chunks.
            hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length))
            attention_mask = F.pad(attention_mask, (0, pad_length))

        hidden_states = hidden_states.view(
            batch_size, num_chunks, chunk_size, hidden_size
        )
        chunk_mask = attention_mask.view(batch_size, num_chunks, chunk_size)
        expanded_mask = chunk_mask.unsqueeze(-1).to(hidden_states.dtype)

        pooled_hidden = (hidden_states * expanded_mask).sum(dim=2)
        # clamp(min=1) avoids a division by zero for fully masked chunks.
        pooled_hidden = pooled_hidden / chunk_mask.sum(dim=2).clamp(min=1).unsqueeze(-1)
        pooled_mask = (chunk_mask.sum(dim=2) > 0).to(attention_mask.dtype)
        return pooled_hidden, pooled_mask

    def _decoder_text(self, query: str, instruction: str) -> str:
        """Build the decoder prompt for one query.

        The query is first truncated to ``query_max_length`` tokens by
        tokenizing it and decoding the kept ids back to text.
        """
        query_ids = self.tokenizer(
            query,
            add_special_tokens=False,
            truncation=True,
            max_length=self.query_max_length,
        )["input_ids"]
        truncated_query = self.tokenizer.decode(
            query_ids,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
        )
        # NOTE(review): the instruction and query lines start with a bare ": ".
        # If the model's prompt template expects field labels or role markers
        # here, confirm this text against the model card.
        return (
            "user\n"
            f"{self.system_instruction}\n\n"
            f": {instruction}\n"
            f": {truncated_query}\n"
            "model\n\n\n\n"
        )

    @staticmethod
    def _validate_pairs(
        pairs: Sequence[Tuple[str, str]],
    ) -> List[Tuple[str, str]]:
        """Check ``pairs`` and return it as a list of ``(query, document)`` tuples.

        Raises:
            TypeError: If ``pairs`` is not a non-string sequence, or an element
                of a pair is not a string.
            ValueError: If an item is not a non-string sequence of length two.
        """
        if isinstance(pairs, (str, bytes)) or not isinstance(pairs, Sequence):
            raise TypeError("pairs must be a sequence of (query, document) pairs.")
        validated: List[Tuple[str, str]] = []
        for index, pair in enumerate(pairs):
            if (
                isinstance(pair, (str, bytes))
                or not isinstance(pair, Sequence)
                or len(pair) != 2
            ):
                raise ValueError(f"pairs[{index}] must contain exactly two strings.")
            query, document = pair
            if not isinstance(query, str) or not isinstance(document, str):
                raise TypeError(f"pairs[{index}] must contain exactly two strings.")
            validated.append((query, document))
        return validated

    @torch.inference_mode()
    def _predict_batch(
        self, pairs: Sequence[Tuple[str, str]], instruction: str
    ) -> List[float]:
        """Score one batch of pairs with a single forward pass.

        Args:
            pairs: ``(query, document)`` pairs; expected to be non-empty.
            instruction: Task instruction inserted into every decoder prompt.

        Returns:
            One ``P(yes)`` score per pair, in input order.

        Raises:
            RuntimeError: If a ``yes`` or ``no`` logit is not finite.
        """
        # NOTE(review): the document text is prefixed with a bare ": " — see
        # the matching note in _decoder_text.
        encoder_texts = [f": {document}" for _, document in pairs]
        decoder_texts = [self._decoder_text(query, instruction) for query, _ in pairs]

        encoder_batch = self.tokenizer(
            encoder_texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)
        decoder_batch = self.tokenizer(
            decoder_texts,
            padding=True,
            pad_to_multiple_of=8,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)

        if self.chunk_size is None:
            outputs = self.model(
                input_ids=encoder_batch["input_ids"],
                attention_mask=encoder_batch["attention_mask"],
                decoder_input_ids=decoder_batch["input_ids"],
                decoder_attention_mask=decoder_batch["attention_mask"],
                return_dict=True,
            )
        else:
            # Run the encoder separately so that its output can be pooled to a
            # shorter sequence before the decoder attends to it.
            encoder_outputs = self._get_encoder()(
                input_ids=encoder_batch["input_ids"],
                attention_mask=encoder_batch["attention_mask"],
                return_dict=True,
            )
            pooled_hidden, pooled_mask = self._pool_encoder_chunks(
                encoder_outputs.last_hidden_state,
                encoder_batch["attention_mask"],
                self.chunk_size,
            )
            outputs = self.model(
                encoder_outputs=BaseModelOutput(last_hidden_state=pooled_hidden),
                attention_mask=pooled_mask,
                decoder_input_ids=decoder_batch["input_ids"],
                decoder_attention_mask=decoder_batch["attention_mask"],
                return_dict=True,
            )

        # Index of the last unmasked decoder position of every row (the
        # tokenizer pads on the right).
        sequence_lengths = decoder_batch["attention_mask"].sum(dim=1) - 1
        batch_indices = torch.arange(outputs.logits.shape[0], device=self.device)
        last_logits = outputs.logits[batch_indices, sequence_lengths]
        yes_no_logits = torch.stack(
            (
                last_logits[:, self.yes_token_id],
                last_logits[:, self.no_token_id],
            ),
            dim=-1,
        ).float()

        if not torch.isfinite(yes_no_logits).all():
            bad_count = (~torch.isfinite(yes_no_logits).all(dim=-1)).sum().item()
            raise RuntimeError(
                f"The model produced non-finite yes/no logits for {bad_count} input(s). "
                "Use bfloat16 or float32 instead of float16."
            )
        return torch.softmax(yes_no_logits, dim=-1)[:, 0].cpu().tolist()

    def predict(
        self,
        pairs: Sequence[Tuple[str, str]],
        *,
        instruction: Optional[str] = None,
        batch_size: Optional[int] = None,
    ) -> List[float]:
        """Return ``P(yes)`` scores in the same order as ``pairs``.

        Pairs are processed from longest to shortest (by character count). The
        first and therefore longest batch doubles as a memory probe: when it
        raises a CUDA out-of-memory error the batch size is reduced to 3/4
        (down to 1) and the batch is retried.

        Args:
            pairs: Sequence of ``(query, document)`` string pairs.
            instruction: Task instruction; ``None`` uses the instance default.
            batch_size: Pairs per forward pass; ``None`` uses the instance
                default.

        Returns:
            One score per pair; an empty list for empty input.

        Raises:
            TypeError: If ``pairs`` or ``instruction`` has an invalid type.
            ValueError: If a pair is malformed or ``batch_size`` is not a
                positive integer.
            RuntimeError: If CUDA runs out of memory after the batch size has
                been settled.
        """
        validated_pairs = self._validate_pairs(pairs)
        if not validated_pairs:
            return []

        effective_instruction = self.instruction if instruction is None else instruction
        if not isinstance(effective_instruction, str):
            raise TypeError("instruction must be a string or None.")
        effective_batch_size = self.batch_size if batch_size is None else batch_size
        if not isinstance(effective_batch_size, int) or effective_batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")

        # Longest pairs first, so that the first batch is the most expensive one.
        length_sorted_indices = np.argsort(
            [-(len(query) + len(document)) for query, document in validated_pairs]
        )
        sorted_pairs = [validated_pairs[index] for index in length_sorted_indices]

        tested_batch_size = effective_batch_size
        sorted_scores: List[float] = []
        while tested_batch_size > 1:
            try:
                # Keep the scores of the successful probe: they are exactly
                # the scores of the first batch, so it need not be run again.
                sorted_scores = list(
                    self._predict_batch(
                        sorted_pairs[:tested_batch_size],
                        effective_instruction,
                    )
                )
                break
            except torch.cuda.OutOfMemoryError:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                tested_batch_size = max(1, tested_batch_size * 3 // 4)

        try:
            # Resume after the pairs already scored by the probe (if any); the
            # batch boundaries are the same as when starting from index 0.
            for start in range(len(sorted_scores), len(sorted_pairs), tested_batch_size):
                sorted_scores.extend(
                    self._predict_batch(
                        sorted_pairs[start : start + tested_batch_size],
                        effective_instruction,
                    )
                )
        except torch.cuda.OutOfMemoryError as error:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise RuntimeError(
                "CUDA ran out of memory during reranking. Retry with a smaller batch_size "
                "or shorter max_length."
            ) from error

        # Undo the length sort so that scores line up with the input order.
        inverse_indices = np.argsort(length_sorted_indices)
        return [sorted_scores[index] for index in inverse_indices]

    def rank(
        self,
        query: str,
        documents: Sequence[str],
        *,
        instruction: Optional[str] = None,
        top_k: Optional[int] = None,
        batch_size: Optional[int] = None,
    ) -> List[Dict[str, Union[int, float]]]:
        """Rank documents and return ``corpus_id``/``score`` dictionaries.

        Args:
            query: Query string scored against every document.
            documents: Sequence of document strings.
            instruction: Task instruction; ``None`` uses the instance default.
            top_k: Maximum number of results; ``None`` returns all of them.
            batch_size: Pairs per forward pass; ``None`` uses the instance
                default.

        Returns:
            Dictionaries with the keys ``"corpus_id"`` (index into
            ``documents``) and ``"score"``, sorted by descending score.

        Raises:
            TypeError: If ``query`` or ``documents`` has an invalid type.
            ValueError: If ``top_k`` is not a non-negative integer or ``None``.
        """
        if not isinstance(query, str):
            raise TypeError("query must be a string.")
        if isinstance(documents, (str, bytes)) or not isinstance(documents, Sequence):
            raise TypeError("documents must be a sequence of strings.")
        if any(not isinstance(document, str) for document in documents):
            raise TypeError("every document must be a string.")
        if top_k is not None and (not isinstance(top_k, int) or top_k < 0):
            raise ValueError("top_k must be a non-negative integer or None.")

        scores = self.predict(
            [(query, document) for document in documents],
            instruction=instruction,
            batch_size=batch_size,
        )
        rankings: List[Dict[str, Union[int, float]]] = [
            {"corpus_id": corpus_id, "score": score}
            for corpus_id, score in enumerate(scores)
        ]
        rankings.sort(key=lambda item: item["score"], reverse=True)
        return rankings if top_k is None else rankings[:top_k]


__all__ = ["KaLMReranker"]