ai-talent-finder-backend / ai_module /matching /bert_classifier_adapter.py
ilyass yani
Deploiement backend dans HF Spaces
9df97a2
Raw
History Blame
5 kB
"""BERT classifier adapter for camembert-based candidate/job compatibility scoring.
The underlying model (CamembertForSequenceClassification) outputs 3 classes:
0 = incompatible, 1 = partial, 2 = compatible
The adapter converts softmax probabilities to a continuous [0, 1] score via
weighted class expectation: score = 0.0*P(incompatible) + 0.5*P(partial) + 1.0*P(compatible)
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional
import numpy as np
logger = logging.getLogger(__name__)
try:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
logger.warning("transformers/torch not installed — BertClassifierAdapter will be disabled.")
class BertClassifierAdapter:
    """Thin wrapper around a fine-tuned CamembertForSequenceClassification model.

    The model is expected to emit three logits per (candidate, job) pair, in
    the order incompatible / partial / compatible.  Both prediction methods
    are best-effort: any failure is logged and mapped to the most pessimistic
    answer (score ``0.0`` / label ``"incompatible"``) instead of raising.
    """

    # Per-class weights for the expected-score computation:
    # incompatible=0, partial=0.5, compatible=1.0
    _CLASS_WEIGHTS = np.array([0.0, 0.5, 1.0], dtype=np.float32)

    # Class index -> label name, in the same order as _CLASS_WEIGHTS.
    _ID2LABEL = {0: "incompatible", 1: "partial", 2: "compatible"}

    def __init__(self, model_dir: str | Path, max_length: int = 384):
        """Load the tokenizer and the classification model from *model_dir*.

        Args:
            model_dir: Directory holding the saved tokenizer and model files.
            max_length: Maximum number of tokens kept for the encoded
                candidate/job pair; longer inputs are truncated.

        Raises:
            RuntimeError: If ``transformers`` / ``torch`` could not be imported.
            FileNotFoundError: If *model_dir* does not exist.
        """
        if not TRANSFORMERS_AVAILABLE:
            raise RuntimeError(
                "transformers and torch are required. "
                "Install them with: pip install transformers torch"
            )

        model_dir = Path(model_dir)
        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory not found: {model_dir}")

        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
        self._model = AutoModelForSequenceClassification.from_pretrained(str(model_dir))
        self._model.to(self._device)
        self._model.eval()  # inference only: disables dropout etc.
        self._max_length = max_length

        # Surface a head/weights mismatch once at load time; otherwise the
        # only symptom would be a warning plus a 0.0 score on every call.
        num_labels = getattr(getattr(self._model, "config", None), "num_labels", None)
        if num_labels is not None and num_labels != len(self._CLASS_WEIGHTS):
            logger.warning(
                "BertClassifierAdapter expects %d classes but the model in %s "
                "reports %s; predict_score will fall back to 0.0.",
                len(self._CLASS_WEIGHTS),
                model_dir,
                num_labels,
            )

        logger.info("BertClassifierAdapter loaded from %s on %s", model_dir, self._device)

    @classmethod
    def load(cls, model_dir: str | Path, max_length: int = 384) -> "BertClassifierAdapter":
        """Alternate constructor; equivalent to ``cls(model_dir, max_length)``."""
        return cls(model_dir, max_length=max_length)

    def _logits(self, candidate_text: str, job_text: str):
        """Encode the text pair and return the model's raw logits tensor."""
        encoded = self._tokenizer(
            candidate_text,
            job_text,
            return_tensors="pt",
            truncation=True,
            max_length=self._max_length,
            padding=True,
        )
        encoded = {name: tensor.to(self._device) for name, tensor in encoded.items()}
        with torch.no_grad():
            return self._model(**encoded).logits

    @classmethod
    def _expected_score(cls, probs) -> float:
        """Collapse per-class probabilities into one score in [0, 1].

        Computes the dot product of *probs* with the class weights and clips
        the result to [0, 1].

        Raises:
            ValueError: If *probs* does not match the number of class weights
                (raised by ``np.dot``) or if the result is NaN/inf, which
                ``np.clip`` would otherwise pass through unchanged.
        """
        score = float(np.dot(np.asarray(probs), cls._CLASS_WEIGHTS))
        if not np.isfinite(score):
            raise ValueError(f"non-finite compatibility score: {score!r}")
        return float(np.clip(score, 0.0, 1.0))

    def predict_score(self, candidate_text: str, job_text: str) -> float:
        """Return a compatibility score in [0, 1].

        Combines candidate and job text as a text-pair sequence and returns
        the soft-label expectation over (incompatible=0, partial=0.5, compatible=1.0).
        Returns ``0.0`` when the dependencies are missing or inference fails.
        """
        if not TRANSFORMERS_AVAILABLE:
            return 0.0
        try:
            logits = self._logits(candidate_text, job_text)
            # Upcast first: NumPy cannot represent every torch dtype (e.g. bfloat16).
            probs = torch.softmax(logits.float(), dim=-1).cpu().numpy()[0]
            return self._expected_score(probs)
        except Exception as exc:
            # Deliberate best-effort: scoring must never break the caller.
            logger.warning(
                "BertClassifierAdapter.predict_score failed: %s", exc, exc_info=True
            )
            return 0.0

    def predict_label(self, candidate_text: str, job_text: str) -> str:
        """Return the top predicted label: 'incompatible', 'partial', or 'compatible'.

        Returns ``'incompatible'`` when the dependencies are missing, when
        inference fails, or when the predicted class index has no known label.
        """
        if not TRANSFORMERS_AVAILABLE:
            return "incompatible"
        try:
            logits = self._logits(candidate_text, job_text)
            label_id = int(torch.argmax(logits, dim=-1).item())
            return self._ID2LABEL[label_id]
        except Exception as exc:
            # Deliberate best-effort: labelling must never break the caller.
            logger.warning(
                "BertClassifierAdapter.predict_label failed: %s", exc, exc_info=True
            )
            return "incompatible"

    @property
    def available(self) -> bool:
        """Whether ``transformers`` and ``torch`` were importable at module load."""
        return TRANSFORMERS_AVAILABLE
# Process-wide adapter shared by every caller of get_default_adapter().
_default_adapter: Optional[BertClassifierAdapter] = None


def get_default_adapter(model_dir: Optional[str | Path] = None) -> Optional[BertClassifierAdapter]:
    """Return the shared adapter, creating it on first successful use.

    Once an adapter has been loaded it is returned on every later call and
    *model_dir* is ignored.  If loading fails, a warning is logged, nothing
    is cached and ``None`` is returned, so the next call tries again.

    Args:
        model_dir: Directory to load from; when ``None`` the bundled
            ``models/bert_matching`` directory is used.

    Returns:
        The shared adapter, or ``None`` if it could not be loaded.
    """
    global _default_adapter

    if _default_adapter is None:
        if model_dir is None:
            # Resolve relative to this file:
            # backend/ai_module/matching/ -> backend/models/bert_matching/
            backend_root = Path(__file__).parent.parent.parent
            source_dir = backend_root.joinpath("models", "bert_matching")
        else:
            source_dir = model_dir

        try:
            _default_adapter = BertClassifierAdapter.load(source_dir)
        except Exception as exc:
            logger.warning("Could not load default BertClassifierAdapter: %s", exc)
            return None

    return _default_adapter