# legal-eye/tau_rag/intelligence/outcome_signals.py
# Legal-i's picture
# Initial deploy: legal-eye Hebrew legal RAG (17K corpus, verbatim-from-precedent)
# 3be54c6 verified
"""outcome_signals.py — Ω-based outcome prediction from argument templates.
Replaces the heuristic CaseAnalyzer.predict_outcome with a principled
geometric-mean health score over universal signals (τ, ψ, φ, ξ, Ω).
The math is borrowed verbatim from your universal_signals/core.py and
adapted from training-time signals (loss/gradient stability) to
inference-time signals (argument-quality / pattern-alignment / outcome-
distribution-anomaly).
────────────────────────────────────────────────────────────────────────────
Signal definitions (all clipped to [0, 1])
────────────────────────────────────────────────────────────────────────────
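τ (tau, "progress / strength"):
How strongly the user's argument templates exhibit ACCEPT polarity
in the lexicon.
τ = sigmoid( Σ_t (accept_score(t) − reject_score(t)) / N_t )
Where accept_score / reject_score come from the polarity lexicon
already validated at 100% on 8 paraphrases. Drafted arguments are
scored the same way at half weight, and N_t counts templates plus
drafts.
When outcome-labelled retrieval hits are supplied (at least two
accepted and two rejected), τ is instead derived from the gap between
their mean retrieval scores, using the default threshold and scale:
τ = sigmoid( (mean_score_accepted − mean_score_rejected − 2.3) / 30 )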
ψ (psi, "stability / coherence"):
Are the user's templates consistent with each other, or do they
pull in opposite directions?
ψ = 1 − std(template_polarities) / (mean(|polarities|) + ε)
A user with 5 strong-ACCEPT templates: ψ ≈ 1.0 (coherent).
A user with 3 ACCEPT + 2 REJECT templates (polarities +1/−1): ψ ≈ 0.02 (incoherent).
φ (phi, "alignment with precedent pattern"):
How often arguments like the user's were accepted in the corpus:
the frequency-weighted share of templates whose outcome_pattern is
"accepted" among those labelled "accepted" or "rejected".
φ = n_accepted / (n_accepted + n_rejected)
With no labelled templates, φ falls back to the neutral prior 0.5.
ξ (xi, "epistemic uncertainty"):
How thin is the corpus evidence? Decays with sample size N:
ξ = 1 / (1 + log(1 + N))
With N=0 → ξ=1 (max uncertainty), N=20 → ξ≈0.25 (strong evidence).
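Ω (omega, "outcome health" → P(success)):
Geometric mean with omega_weights (used when τ comes from the lexicon):
Ω = (τ^α · φ^β · ψ^γ · (1−ξ)^δ)^(1/Σ), where Σ = α + β + γ + δ
Default weights mirror your universal_signals/core.py: (1,1,1,1.5).
When τ comes from the retrieval-score gap, a τ-dominated blend is used
instead, so that a low ψ cannot zero out the prediction:
Ω = 0.5 + (τ − 0.5) · max(1−ξ, 0.3) · (0.5 + 0.5·φ)
Ω ∈ [0,1] — directly used as outcome_probability.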
────────────────────────────────────────────────────────────────────────────
Why this is better than the previous heuristic
────────────────────────────────────────────────────────────────────────────
The old CaseAnalyzer.predict_outcome multiplied factor weights chosen by
hand and clipped to [0,1]. It was opaque, untestable, and treated all
"factors" identically.
This new computation:
• Is mathematically transparent — every signal has a clear definition
• Decomposes the prediction into 4 interpretable components
• Each component is independently auditable
• Geometric mean punishes weak links — a single very-low signal drags
the result down (more conservative, what a litigator wants)
• Works with ZERO training — purely runtime computation on existing
lexicon + retrieval outputs
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
def _clip01(x: float) -> float:
return float(max(0.0, min(1.0, x)))
def _sigmoid(x: float) -> float:
return 1.0 / (1.0 + math.exp(-x))
@dataclass
class OutcomeSignals:
    """Container for the inference-time signals behind an outcome prediction.

    Holds the four component signals, the combined score ``omega`` (which is
    also reported as ``outcome_probability``), and the raw counts that are
    surfaced for explainability.
    """

    tau: float    # argument strength
    psi: float    # coherence between templates
    phi: float    # alignment with precedent
    xi: float     # uncertainty (interpreted below: lower means more evidence)
    omega: float  # combined health score, reported as P(success)

    # Component breakdown (for explainability in the UI)
    accept_score_total: float = 0.0
    reject_score_total: float = 0.0
    n_templates: int = 0
    n_accepted_in_corpus: int = 0
    n_rejected_in_corpus: int = 0
    debug: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict with rounded values and an interpretation.

        Signals are rounded to 4 decimals, score totals to 3. ``omega`` is
        emitted twice: under its own key and as ``outcome_probability``.
        The ``debug`` dict is returned as-is (not copied).
        """
        payload: Dict[str, Any] = {
            name: round(getattr(self, name), 4)
            for name in ("tau", "psi", "phi", "xi", "omega")
        }
        payload["outcome_probability"] = round(self.omega, 4)
        for name in ("accept_score_total", "reject_score_total"):
            payload[name] = round(getattr(self, name), 3)
        for name in ("n_templates", "n_accepted_in_corpus", "n_rejected_in_corpus"):
            payload[name] = getattr(self, name)
        payload["interpretation"] = self._interpret()
        payload["debug"] = self.debug
        return payload

    def _interpret(self) -> Dict[str, str]:
        """Build the plain-language (Hebrew) explanation shown to end users.

        omega, tau, psi and phi share the same bands: >= 0.65 is the strong
        tier, >= 0.45 the middle tier, anything else the weak tier. xi is
        read the other way round: <= 0.35 is good, <= 0.65 is partial.
        """

        def tier(value: float, strong: str, medium: str, weak: str) -> str:
            # Higher-is-better banding shared by omega / tau / psi / phi.
            if value >= 0.65:
                return strong
            if value >= 0.45:
                return medium
            return weak

        # xi is lower-is-better, so it gets its own ascending thresholds.
        if self.xi <= 0.35:
            evidence = "ראיות מספיקות מהקורפוס"
        elif self.xi <= 0.65:
            evidence = "ראיות חלקיות — תחזית טנטטיבית"
        else:
            evidence = "מעט תקדימים — אי-ודאות גבוהה"

        return {
            "overall": f"סיכוי הצלחה: {self.omega*100:.0f}% — " + tier(
                self.omega,
                "סיכוי גבוה",
                "סיכוי בינוני",
                "סיכוי נמוך",
            ),
            "tau": f"חוזק טיעונים (τ={self.tau:.2f}): הטיעונים שזוהו " + tier(
                self.tau,
                "נושאים פולריות חיובית חזקה",
                "פולריות חיובית בינונית",
                "פולריות חלשה — דרושים טיעונים חזקים יותר",
            ),
            "psi": f"קוהרנטיות (ψ={self.psi:.2f}): התבניות " + tier(
                self.psi,
                "עקביות זו עם זו",
                "עקביות חלקית",
                "סותרות זו את זו — שקול לבחור קו טיעון אחיד",
            ),
            "phi": f"התאמה לתקדים (φ={self.phi:.2f}): הקו המשפטי שלך " + tier(
                self.phi,
                "תואם דפוסים שהתקבלו בעבר",
                "תואם חלקית",
                "סוטה מהדפוס המקובל",
            ),
            "xi": f"אי-ודאות (ξ={self.xi:.2f}): " + evidence,
        }
def _polarity_score(
text: str,
accept_lex: List[str],
reject_lex: List[str],
) -> Tuple[float, float]:
"""Count substring matches against ACCEPT / REJECT lexicons.
Same logic as the validated 100%-accuracy classifier in
case_based_arguments.py — returns (accept_count, reject_count).
"""
a = sum(1 for w in accept_lex if w in text)
r = sum(1 for w in reject_lex if w in text)
return float(a), float(r)
def compute_outcome_signals(
    argument_templates: List[Dict[str, Any]],
    drafted_arguments: List[Dict[str, Any]],
    accept_lex: Optional[List[str]] = None,
    reject_lex: Optional[List[str]] = None,
    retrieved_hits: Optional[List[Any]] = None,
    outcome_map: Optional[Dict[str, Optional[str]]] = None,
    omega_weights: Tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.5),
    eps: float = 1e-6,
    delta_threshold: float = 2.3,
    delta_scale: float = 30.0,
) -> OutcomeSignals:
    """Compute the universal signals (τ, ψ, φ, ξ) and the outcome probability Ω.

    τ comes from the retrieval-score gap when ``retrieved_hits`` and
    ``outcome_map`` yield at least two "accepted" and two "rejected" hits;
    otherwise it falls back to lexicon polarity over templates and drafts.
    Ω uses a τ-dominated blend on the retrieval path and the weighted
    geometric mean on the lexicon path.

    Args:
        argument_templates: clustered templates from CaseBasedArgumentExtractor.
            The keys read here are ``thesis``, ``outcome_pattern`` and
            ``frequency`` (a missing frequency counts as 1).
        drafted_arguments: user-adapted argument drafts; the ``argument``
            text of each is scored at half the weight of a template thesis.
        accept_lex / reject_lex: polarity lexicons. A lexicon left as None is
            loaded from case_based_arguments (or is empty if that import
            fails). An explicitly supplied list, even an empty one, is used
            exactly as given.
        retrieved_hits: optional retrieval results. Each hit is read through
            ``hit.chunk.doc_id`` and ``hit.score`` (missing score → 0.0).
        outcome_map: optional mapping doc_id → outcome label; only the labels
            "accepted" and "rejected" are counted.
        omega_weights: (α_τ, β_φ, γ_ψ, δ_ξ) — defaults match your
            universal_signals/core.py. Used as exponents in the geometric
            mean; their sum must be non-zero on the lexicon path.
        eps: numerical-stability constant (floor for the power terms and
            for the ψ denominator).
        delta_threshold: score gap at which the retrieval-based τ equals 0.5.
        delta_scale: divisor applied to the centred score gap before the
            sigmoid; must be non-zero.

    Returns:
        OutcomeSignals with all 5 signals (rounded to 4 decimals), the
        accept/reject totals, corpus counts, and a ``debug`` breakdown.
    """
    # Resolve the lexicons. Only a lexicon that was NOT supplied is replaced
    # by the packaged default; an explicit (possibly empty) list is respected.
    if accept_lex is None or reject_lex is None:
        try:
            from .case_based_arguments import _ACCEPT_MARKERS, _REJECT_MARKERS
        except Exception:
            # Deliberate best-effort: without the packaged lexicons the
            # polarity scores are simply all zero.
            default_accept: List[str] = []
            default_reject: List[str] = []
        else:
            default_accept, default_reject = _ACCEPT_MARKERS, _REJECT_MARKERS
        if accept_lex is None:
            accept_lex = default_accept
        if reject_lex is None:
            reject_lex = default_reject

    n_templates = len(argument_templates)
    n_drafts = len(drafted_arguments)

    # Score every template thesis exactly once; both τ (lexicon path) and ψ
    # are derived from these (accept, reject) pairs.
    template_scores = [
        _polarity_score(t.get("thesis", "") or "", accept_lex, reject_lex)
        for t in argument_templates
    ]
    per_template_polarity: List[float] = [a - r for a, r in template_scores]

    # ─────────────────────────────────────────────────────────────────
    # τ — argument strength
    # ─────────────────────────────────────────────────────────────────
    # PREFERRED: mean retrieval score of accepted hits minus that of
    # rejected hits, centred on delta_threshold and scaled by delta_scale.
    # FALLBACK: lexicon polarity, used when hits / outcome labels are
    # missing or too few.
    raw_balance = 0.0   # lexicon path only; stays 0.0 on the delta path
    delta_value = 0.0   # delta path only; stays 0.0 on the lexicon path
    acc_scores: List[float] = []
    rej_scores: List[float] = []
    if retrieved_hits is not None and outcome_map is not None:
        for hit in retrieved_hits:
            doc_id = getattr(getattr(hit, "chunk", None), "doc_id", None)
            outcome = outcome_map.get(doc_id) if doc_id else None
            score = float(getattr(hit, "score", 0.0))
            if outcome == "accepted":
                acc_scores.append(score)
            elif outcome == "rejected":
                rej_scores.append(score)

    # Need at least two hits per side for the means to be worth comparing.
    delta_used = len(acc_scores) >= 2 and len(rej_scores) >= 2
    if delta_used:
        accept_total = sum(acc_scores) / len(acc_scores)
        reject_total = sum(rej_scores) / len(rej_scores)
        delta_value = accept_total - reject_total
        tau = _sigmoid((delta_value - delta_threshold) / delta_scale)
    else:
        accept_total = 0.0
        reject_total = 0.0
        for a_count, r_count in template_scores:
            accept_total += a_count
            reject_total += r_count
        for draft in drafted_arguments:
            a_count, r_count = _polarity_score(
                draft.get("argument", "") or "", accept_lex, reject_lex
            )
            # Drafts are derived from the templates, so they count half.
            accept_total += a_count * 0.5
            reject_total += r_count * 0.5
        n_total = max(n_templates + n_drafts, 1)
        raw_balance = (accept_total - reject_total) / n_total
        tau = _sigmoid(raw_balance)

    # ─────────────────────────────────────────────────────────────────
    # ψ — coherence: do templates agree with each other?
    # ─────────────────────────────────────────────────────────────────
    if len(per_template_polarity) >= 2:
        # 1 − std / (mean |polarity| + ε), clipped to [0, 1]
        count = len(per_template_polarity)
        mean_pol = sum(per_template_polarity) / count
        variance = sum((p - mean_pol) ** 2 for p in per_template_polarity) / count
        mean_abs = sum(abs(p) for p in per_template_polarity) / count
        psi = _clip01(1.0 - (math.sqrt(variance) / (mean_abs + eps)))
    elif n_templates == 1:
        psi = 0.7  # single template — assume moderate coherence
    else:
        psi = 0.5  # no signal

    # ─────────────────────────────────────────────────────────────────
    # φ — alignment: frequency-weighted share of templates whose corpus
    #     outcome was "accepted"
    # ─────────────────────────────────────────────────────────────────
    n_accepted_corpus = 0
    n_rejected_corpus = 0
    for t in argument_templates:
        outcome = t.get("outcome_pattern", "unknown")
        freq = int(t.get("frequency", 1))
        if outcome == "accepted":
            n_accepted_corpus += freq
        elif outcome == "rejected":
            n_rejected_corpus += freq
    total_corpus_outcomes = n_accepted_corpus + n_rejected_corpus
    if total_corpus_outcomes > 0:
        phi = _clip01(n_accepted_corpus / total_corpus_outcomes)
    else:
        phi = 0.5  # no corpus signal — neutral prior

    # ─────────────────────────────────────────────────────────────────
    # ξ — epistemic uncertainty from sample size:
    #     ξ = 1 / (1 + log(1 + N)); N=0 → 1.0, N=9 → ≈0.30, N=50 → ≈0.20
    # A unanimous corpus signal is high confidence, so only N matters here.
    # ─────────────────────────────────────────────────────────────────
    if total_corpus_outcomes >= 1:
        xi = _clip01(1.0 / (1.0 + math.log(1.0 + total_corpus_outcomes)))
    else:
        xi = 1.0  # zero data → max uncertainty

    # ─────────────────────────────────────────────────────────────────
    # Ω — outcome probability
    # ─────────────────────────────────────────────────────────────────
    w_tau, w_phi, w_psi, w_xi = omega_weights
    # Geometric-mean factors; always computed so they can be inspected in
    # the debug output even when the τ-dominated blend is used.
    tau_pow = max(tau, eps) ** w_tau
    phi_pow = max(phi, eps) ** w_phi
    psi_pow = max(psi, eps) ** w_psi
    conf_pow = max(1.0 - xi, eps) ** w_xi
    if delta_used:
        # τ-dominated blend: pull τ toward 0.5 by uncertainty (floored at
        # 0.3), then scale by alignment in [0.5, 1.0]. ψ is left out on
        # purpose so mixed-polarity templates cannot zero out the result.
        confidence_factor = max(1.0 - xi, 0.3)
        alignment_factor = 0.5 + 0.5 * phi
        omega = _clip01(0.5 + (tau - 0.5) * confidence_factor * alignment_factor)
    else:
        total_weight = w_tau + w_phi + w_psi + w_xi
        omega = _clip01(
            (tau_pow * phi_pow * psi_pow * conf_pow) ** (1.0 / total_weight)
        )

    return OutcomeSignals(
        tau=round(tau, 4),
        psi=round(psi, 4),
        phi=round(phi, 4),
        xi=round(xi, 4),
        omega=round(omega, 4),
        accept_score_total=accept_total,
        reject_score_total=reject_total,
        n_templates=n_templates,
        n_accepted_in_corpus=n_accepted_corpus,
        n_rejected_in_corpus=n_rejected_corpus,
        debug={
            "raw_balance": round(raw_balance, 3),
            "per_template_polarity": [round(p, 2) for p in per_template_polarity],
            "omega_components": {
                "tau_pow": round(tau_pow, 4),
                "phi_pow": round(phi_pow, 4),
                "psi_pow": round(psi_pow, 4),
                "1-xi_pow": round(conf_pow, 4),
            },
            "weights": list(omega_weights),
            "n_corpus_outcomes": total_corpus_outcomes,
            # Which τ path was taken, and the raw score gap behind it.
            "delta_used": delta_used,
            "delta": round(delta_value, 4),
        },
    )
# Public API of this module; the underscore-prefixed helpers are not exported.
__all__ = ["OutcomeSignals", "compute_outcome_signals"]