"""doctrine_classifier.py — match a Hebrew query/text to top-K doctrines from the curated catalog at `tau_rag/data/doctrines.json`. V1 (no-ML): keyword + element matching with simple scoring. - Each keyword hit in the query → +1 to that doctrine's score - Each element string substring match → +0.5 - Each statute reference (e.g. 'סעיף 39') → +0.5 if mentioned in query - Doctrines are ranked by total score; ties broken by len(keywords). V2 ideas (not implemented): - Hebrew lemmatization (e.g. 'מטעים' → 'הטעיה') - Fine-tuned classifier on labeled corpus - Embedding similarity between query and doctrine description Why no ML in V1: a 100-doctrine catalog with ~10 keywords each → ~1000 patterns → regex sweep in <1ms per query. ML adds latency + complexity without clear quality win on the small catalog scale. Public API: catalog = load_doctrine_catalog() # singleton-cached matches = classify_doctrines(text, k=3) # → List[DoctrineMatch] doctrine = get_doctrine_by_id(doc_id) # for hydration Each DoctrineMatch has: doctrine_id, name_he, score, matched_keywords. """ from __future__ import annotations import json import os import re import threading from dataclasses import dataclass, field from pathlib import Path from typing import Dict, List, Optional # Default location relative to package _DEFAULT_CATALOG_PATH = (Path(__file__).resolve().parent.parent / "data" / "doctrines.json") @dataclass class DoctrineMatch: doctrine_id: str name_he: str domain: str score: float matched_keywords: List[str] = field(default_factory=list) matched_elements: List[str] = field(default_factory=list) matched_statutes: List[str] = field(default_factory=list) def to_dict(self) -> dict: return { "doctrine_id": self.doctrine_id, "name_he": self.name_he, "domain": self.domain, "score": round(self.score, 2), "matched_keywords": self.matched_keywords, "matched_elements": self.matched_elements, "matched_statutes": self.matched_statutes, } # ────────────────────────────────────────────────────────────────────── # Catalog loading — singleton, lazy # ────────────────────────────────────────────────────────────────────── _CATALOG: Optional[dict] = None _LOCK = threading.Lock() def load_doctrine_catalog(path: Optional[Path] = None) -> dict: """Return the parsed doctrines.json. Cached as a process singleton. Pass `path` to override default location (for tests / alt catalogs). """ global _CATALOG if _CATALOG is not None and path is None: return _CATALOG with _LOCK: if _CATALOG is not None and path is None: return _CATALOG catalog_path = Path(path or os.environ.get( "TAU_RAG_DOCTRINES_PATH", _DEFAULT_CATALOG_PATH)) if not catalog_path.exists(): # Empty catalog — classifier still works, just returns [] cat = {"_schema_version": 1, "doctrines": []} else: cat = json.loads(catalog_path.read_text(encoding="utf-8")) # Build lookup index for O(1) by-id access cat["_index"] = {d["id"]: d for d in cat.get("doctrines", [])} if path is None: _CATALOG = cat return cat def get_doctrine_by_id(doctrine_id: str) -> Optional[dict]: cat = load_doctrine_catalog() return cat.get("_index", {}).get(doctrine_id) # ────────────────────────────────────────────────────────────────────── # Classifier # ────────────────────────────────────────────────────────────────────── # Statute reference pattern: matches "סעיף N" + optional law name. # Captures the section number so we can compare to doctrine.statute_refs. _STATUTE_PAT = re.compile( r"סעיף\s*(\d+(?:[א-ת]\d?)?)\s*(?:לחוק\s+([^,\.\n]+))?", ) def _normalize_hebrew(text: str) -> str: """Normalize whitespace + dashes between Hebrew words. Day 47c note: tried to add programmatic prefix-stripping (ה/ב/ל/מ/ש/כ/ו from start of words) — broke real words that start with those letters legitimately (e.g. "שימוע" → "ימוע"). Backed out. Keyword variants are now expanded MANUALLY in doctrines.json instead (e.g. both "רשומה רפואית" AND "הרשומה הרפואית" as keywords). """ return re.sub(r"[\s\-־]+", " ", text) def _extract_statute_sections(text: str) -> List[str]: """Pull statute section numbers from text. Returns ['12', '39', ...].""" return [m.group(1) for m in _STATUTE_PAT.finditer(text)] def classify_doctrines(text: str, k: int = 3, min_score: float = 1.0) -> List[DoctrineMatch]: """Match `text` against catalog. Returns top-`k` DoctrineMatch objects with score >= `min_score`. Empty list if no matches. Scoring: • +1.0 per keyword hit (case-insensitive substring after normalization) • +0.5 per element string substring hit • +0.5 per matching statute section reference • Each match also tracks the matched terms for explainability. """ if not text: return [] catalog = load_doctrine_catalog() norm_text = _normalize_hebrew(text) text_lower = norm_text.lower() text_statutes = set(_extract_statute_sections(text)) matches: List[DoctrineMatch] = [] for doc in catalog.get("doctrines", []): score = 0.0 kw_hits: List[str] = [] el_hits: List[str] = [] st_hits: List[str] = [] # Keywords (high weight). Long multi-word keywords get a # bonus: a 3+ word phrase match is much higher signal than a # single-word match. Examples: # "פיטורים ללא שימוע" (3 words) → 1.5 # "חובת שימוע" (2 words) → 1.0 # "שימוע" (1 word) → 0.6 for kw in doc.get("keywords_he", []): norm_kw = _normalize_hebrew(kw).lower() if norm_kw and norm_kw in text_lower: n_words = len(norm_kw.split()) if n_words >= 3: weight = 1.5 elif n_words == 2: weight = 1.0 else: weight = 0.6 score += weight kw_hits.append(kw) # Elements (medium weight) — substring of element string for el in doc.get("elements", []): norm_el = _normalize_hebrew(el).lower() # Element strings are full sentences; check if any 3+ chars # of the element appear in the text. Conservative substring # match — could be improved with bag-of-words overlap. if norm_el and len(norm_el) >= 6 and norm_el in text_lower: score += 0.5 el_hits.append(el) # Statute matches (medium weight) for ref in doc.get("statute_refs", []): for section in ref.get("sections", []): if section in text_statutes: score += 0.5 st_hits.append(f"{ref['law']} ס׳ {section}") if score >= min_score: matches.append(DoctrineMatch( doctrine_id=doc["id"], name_he=doc["name_he"], domain=doc.get("domain", "general"), score=score, matched_keywords=kw_hits, matched_elements=el_hits, matched_statutes=st_hits, )) matches.sort(key=lambda m: (-m.score, m.doctrine_id)) return matches[:k] # ────────────────────────────────────────────────────────────────────── # CLI smoke test — run module directly to test catalog matching # ────────────────────────────────────────────────────────────────────── if __name__ == "__main__": import sys queries = sys.argv[1:] or [ "פיטורים ללא שימוע בעובד שעבד 3 חודשים", "הפרת חוזה בתום-לב — חובת הגילוי", "רשלנות רפואית באבחון מחלה", "מצג שווא בעסקת מכר דירה", "אחריות מנהל לפי סעיף 39 לחוק החוזים", ] cat = load_doctrine_catalog() print(f"Catalog loaded: {len(cat.get('doctrines', []))} doctrines") print() for q in queries: print(f" Query: {q}") matches = classify_doctrines(q, k=3) if not matches: print(f" → (no matches)") for m in matches: print(f" → {m.name_he}") print(f" score={m.score:.1f} · keywords={m.matched_keywords}") print()