"""
BGE-M3 Projection Model for Hugging Face Transformers

A lightweight projection head trained on top of frozen BGE-M3 encoder
for Vietnamese rental property search.
"""

import os
import warnings
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput


class BGEM3ProjectionConfig(PretrainedConfig):
    """
    Configuration class for BGEM3ProjectionModel

    Args:
        base_model (str): Base model identifier (default: "BAAI/bge-m3")
        d_in (int): Input dimension from base encoder (default: 1024)
        d_out (int): Output dimension after projection (default: 128)
        use_layernorm (bool): Whether to use LayerNorm in projection head
        freeze_encoder (bool): Whether to freeze the base encoder
        max_length (int): Maximum sequence length for tokenization
    """

    model_type = "bgem3_projection"

    def __init__(
        self,
        base_model: str = "BAAI/bge-m3",
        d_in: int = 1024,
        d_out: int = 128,
        use_layernorm: bool = False,
        freeze_encoder: bool = True,
        max_length: int = 512,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model = base_model
        self.d_in = d_in
        self.d_out = d_out
        self.use_layernorm = use_layernorm
        self.freeze_encoder = freeze_encoder
        self.max_length = max_length


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Mean pooling with attention mask

    Positions whose mask value is 0 contribute nothing to the average.

    Args:
        last_hidden_state: [batch_size, seq_len, hidden_size]
        attention_mask: [batch_size, seq_len]

    Returns:
        pooled: [batch_size, hidden_size]
    """
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # [B, T, 1]
    summed = (last_hidden_state * mask).sum(dim=1)  # [B, H]
    # Clamp so a fully masked row divides by a tiny number instead of zero.
    counts = mask.sum(dim=1).clamp(min=1e-6)  # [B, 1]
    return summed / counts


class ProjectionHead(nn.Module):
    """
    Projection head: Linear + Optional LayerNorm + L2 Normalization
    """

    def __init__(self, d_in: int, d_out: int, use_layernorm: bool = False):
        super().__init__()
        self.linear = nn.Linear(d_in, d_out, bias=False)
        self.ln = nn.LayerNorm(d_out) if use_layernorm else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch_size, d_in]

        Returns:
            [batch_size, d_out] L2-normalized
        """
        x = self.linear(x)
        if self.ln is not None:
            x = self.ln(x)
        # L2 normalize for cosine similarity
        return F.normalize(x, p=2, dim=-1)


class BGEM3ProjectionModel(PreTrainedModel):
    """
    BGE-M3 with trainable projection head

    This model combines:
    1. Frozen BGE-M3 encoder (1024-dim embeddings)
    2. Trainable projection head (1024 -> d_out, default 128)

    Usage:
        >>> from transformers import AutoModel, AutoTokenizer
        >>>
        >>> model = AutoModel.from_pretrained("your-username/bge-m3-vietnamese-rental-projection", trust_remote_code=True)
        >>> tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
        >>>
        >>> # Encode texts
        >>> texts = ["Phòng trọ Quận 10, 25m2, giá 5tr"]
        >>> inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
        >>> embeddings = model(**inputs).last_hidden_state
        >>>
        >>> # Or use the encode method
        >>> embeddings = model.encode(texts)
    """

    config_class = BGEM3ProjectionConfig
    base_model_prefix = "bgem3_projection"
    supports_gradient_checkpointing = False

    # Keyword arguments that select *which* files are fetched from the Hub.
    # They must reach both the config loader and the weight download so that
    # the two always come from the same revision / cache / credentials.
    _HUB_DOWNLOAD_KWARGS = (
        "cache_dir",
        "force_download",
        "local_files_only",
        "revision",
        "token",
    )

    def __init__(self, config: BGEM3ProjectionConfig):
        super().__init__(config)
        self.config = config

        # Load base encoder
        self.encoder = AutoModel.from_pretrained(config.base_model)

        # Tokenizer will be lazy-loaded when needed
        self._tokenizer = None

        # Freeze encoder if specified
        if config.freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        # Projection head (trainable)
        self.head = ProjectionHead(
            d_in=config.d_in,
            d_out=config.d_out,
            use_layernorm=config.use_layernorm,
        )

        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Override from_pretrained to only load projection head weights

        The base encoder should always be loaded from BAAI/bge-m3, NOT
        from our checkpoint (which only contains projection head weights).

        Args:
            pretrained_model_name_or_path: Local directory containing
                ``model.safetensors`` or a Hub repository id.
            *model_args: Accepted for signature compatibility; unused.
            **kwargs: ``config`` (optional pre-built config) is consumed here.
                Remaining kwargs are passed to the config loader, and the
                Hub-selection ones (``revision``, ``cache_dir``, ``token``,
                ``force_download``, ``local_files_only``) are also passed to
                the weight download when supplied.

        Returns:
            The model in evaluation mode with the projection head weights
            loaded from the checkpoint.
        """
        # Imported lazily, as in the original, so importing this module does
        # not itself require safetensors / huggingface_hub.
        from safetensors import safe_open

        # Load config
        config = kwargs.pop("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Initialize model (this loads encoder from config.base_model)
        model = cls(config)

        # Find the safetensors file
        if os.path.isdir(pretrained_model_name_or_path):
            # Local directory
            safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
        else:
            # Download from HF Hub, honouring the same revision/cache/token
            # the caller asked for. Only keys the caller actually supplied
            # are forwarded, so a plain call behaves exactly as before.
            from huggingface_hub import hf_hub_download

            hub_kwargs = {
                key: kwargs[key] for key in cls._HUB_DOWNLOAD_KWARGS if key in kwargs
            }
            safetensors_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="model.safetensors",
                **hub_kwargs,
            )

        # Load only head weights
        with safe_open(safetensors_path, framework="pt", device="cpu") as f:
            state_dict = {
                key: f.get_tensor(key) for key in f.keys() if key.startswith("head.")
            }

        if not state_dict:
            # strict=False below would otherwise hide this completely and the
            # caller would get embeddings from a randomly initialised head.
            warnings.warn(
                f"No 'head.*' tensors found in {safetensors_path}; "
                "the projection head keeps its random initialization."
            )

        # Load into model (only affects head)
        model.load_state_dict(state_dict, strict=False)

        # Match the stock from_pretrained contract: return the model in eval
        # mode so encoder dropout is off for inference. Call .train() to
        # fine-tune.
        model.eval()

        return model

    @property
    def tokenizer(self):
        """Lazy load tokenizer when needed"""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.config.base_model,
                use_fast=True,
            )
        return self._tokenizer

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        """
        Forward pass through encoder and projection head

        Args:
            input_ids: Token ids, [batch_size, seq_len].
            attention_mask: [batch_size, seq_len]; when omitted, every
                position is treated as a real token (all-ones mask).
            token_type_ids, position_ids, head_mask, inputs_embeds,
            output_attentions, output_hidden_states: Passed through to the
                base encoder unchanged.
            return_dict: Return a ``BaseModelOutput`` instead of a tuple
                (default: ``config.use_return_dict``).

        Returns:
            BaseModelOutput with last_hidden_state = projected embeddings [batch_size, d_out],
            or a 1-tuple ``(projected,)`` when ``return_dict`` is False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # mean_pool needs a concrete mask; without one, pool over every
        # position. If neither input is given the encoder call below raises.
        if attention_mask is None:
            if input_ids is not None:
                attention_mask = torch.ones_like(input_ids)
            elif inputs_embeds is not None:
                attention_mask = torch.ones(
                    inputs_embeds.shape[:-1],
                    dtype=torch.long,
                    device=inputs_embeds.device,
                )

        # Encode with base model. Gradients flow through the encoder only if
        # it is not frozen AND the caller has not disabled autograd; a bare
        # set_grad_enabled(True) would re-enable grad inside torch.no_grad().
        track_encoder_grad = torch.is_grad_enabled() and not self.config.freeze_encoder
        with torch.set_grad_enabled(track_encoder_grad):
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
            )

        # Mean pooling
        pooled = mean_pool(
            encoder_outputs.last_hidden_state,
            attention_mask,
        )  # [batch_size, d_in]

        # Project to d_out
        projected = self.head(pooled)  # [batch_size, d_out], L2-normalized

        if not return_dict:
            return (projected,)

        return BaseModelOutput(
            last_hidden_state=projected,
            hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
            attentions=encoder_outputs.attentions if output_attentions else None,
        )

    @torch.no_grad()
    def encode(
        self,
        texts: Union[str, List[str]],
        batch_size: int = 32,
        max_length: Optional[int] = None,
        show_progress: bool = False,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Encode texts to embeddings (convenience method)

        Switches the model to evaluation mode as a side effect.

        Args:
            texts: Single text or list of texts
            batch_size: Batch size for encoding
            max_length: Maximum sequence length (default: config.max_length)
            show_progress: Show progress bar
            device: Target device (default: model device)

        Returns:
            Tensor of shape [num_texts, d_out], L2-normalized, on CPU.
            An empty input yields a tensor of shape [0, d_out].
        """
        if isinstance(texts, str):
            texts = [texts]

        # torch.cat cannot concatenate an empty list, so answer directly.
        if len(texts) == 0:
            return torch.empty((0, self.config.d_out))

        if device is None:
            device = next(self.parameters()).device

        if max_length is None:
            max_length = self.config.max_length

        self.eval()

        # Optional progress bar
        iterator = range(0, len(texts), batch_size)
        if show_progress:
            try:
                from tqdm import tqdm

                iterator = tqdm(iterator, desc="Encoding")
            except ImportError:
                # tqdm is optional; fall back to a silent loop.
                pass

        all_embeddings = []
        for start in iterator:
            batch_texts = texts[start:start + batch_size]

            # Tokenize
            inputs = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )

            # Move to device
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Forward pass. Go through __call__ so module hooks run, and force
            # return_dict so this works even if config.use_return_dict is off.
            outputs = self(**inputs, return_dict=True)
            all_embeddings.append(outputs.last_hidden_state.cpu())

        return torch.cat(all_embeddings, dim=0)

    def compute_similarity(
        self,
        text1: Union[str, List[str]],
        text2: Union[str, List[str]],
    ) -> torch.Tensor:
        """
        Compute cosine similarity between texts

        Args:
            text1: Single text or list of texts
            text2: Single text or list of texts

        Returns:
            Similarity scores (cosine similarity); the [len(text1), len(text2)]
            matrix with size-1 dimensions squeezed away.
        """
        # encode() always returns a 2-D [num_texts, d_out] tensor.
        emb1 = self.encode(text1)
        emb2 = self.encode(text2)

        # Cosine similarity (already L2-normalized, so just dot product)
        similarity = emb1 @ emb2.T

        return similarity.squeeze()


# Register model for AutoModel
try:
    AutoConfig.register("bgem3_projection", BGEM3ProjectionConfig)
    AutoModel.register(BGEM3ProjectionConfig, BGEM3ProjectionModel)
except ValueError:
    # Registration is best-effort: register() reports an already-registered
    # model type / config with ValueError (e.g. when this module is loaded
    # more than once), which is safe to ignore.
    pass