"""
Streamlit Dashboard for Plutchik Emotion Recognition with Explainability.
"""
from pathlib import Path
import sys
# This file sits at the project root: make that directory importable so the
# local packages (models/, utils/) resolve regardless of the launch directory.
import sys
from pathlib import Path
root_dir = Path(__file__).resolve().parent
if str(root_dir) not in sys.path:
    # Prepend (not append) so the project's own packages win over any installed
    # distribution that happens to use the same generic name (e.g. "utils").
    sys.path.insert(0, str(root_dir))
import streamlit as st
import pandas as pd
import numpy as np
import torch
import plotly.graph_objects as go
import plotly.express as px
import pickle
import os
import requests
import html
from typing import Dict, List
from models.multitask_emotion_model import PluTchikMultiTaskModel
from utils.preprocessing import ERCPreprocessor
from utils.explainability import ExplainabilityEngine
from utils.explainability_v2 import CaptumExplainer
from utils.trainer import PluTchikTrainer
from utils.llm_inference import NemotronClient
from utils.constants import PLUTCHIK, PRIMARY_EMOTIONS, EMOTION_NAMES, NUM_EMOTIONS
from dotenv import load_dotenv
# Load environment variables (such as PLUTCHIK_API_URL / PLUTCHIK_API_KEY,
# read further below) from the .env file located next to this script; this is
# meant for local development.
# In a production environment, these should be set as actual environment variables.
load_dotenv(dotenv_path=Path(__file__).resolve().parent / ".env")
# ============== PAGE CONFIG ==============
st.set_page_config(page_title="Plutchik ERC", page_icon="🎭", layout="wide")
# ============== SESSION STATE INIT ==============
# Seed each per-session key only when it is absent, so values already stored
# in the session are preserved. Each default comes from a factory so every
# session gets its own fresh list object.
for _key, _make_default in (
    ("history", list),
    ("history_buffer", list),
    ("prediction", lambda: None),
):
    if _key not in st.session_state:
        st.session_state[_key] = _make_default()
# Keep the module namespace identical to before: drop the loop temporaries.
del _key, _make_default
# ============== CUSTOM CSS: PREMIUM AESTHETICS ==============
# Inject page-wide styling as raw HTML (unsafe_allow_html=True).
# NOTE(review): the CSS payload below is empty in this copy of the file —
# confirm the <style> block was not lost; as written this renders nothing.
st.markdown("""
""", unsafe_allow_html=True)
# ============== UI HEADER ==============
# Static branded title/tagline, rendered as raw HTML.
st.markdown("""
PLUTCHIK AI
BEYOND WORDS: DECODING THE EMOTIONAL DNA • v2.5 Hardened
""", unsafe_allow_html=True)
# ============== API INTEGRATION (Thin Client) ==============
# Base URL of the inference server; defaults to a local server on port 8000.
API_BASE = os.getenv("PLUTCHIK_API_URL", "http://127.0.0.1:8000")
# Key sent as the X-API-Key header by call_api(use_auth=True); None when unset.
API_KEY = os.getenv("PLUTCHIK_API_KEY")
import subprocess
import socket
import time
def is_port_open(port, host="127.0.0.1", timeout=1.0):
    """Return True if a TCP connection to ``host:port`` succeeds.

    Args:
        port: TCP port to probe.
        host: Address to probe. Defaults to the IPv4 loopback, where the
            local inference server listens.
        timeout: Seconds to wait for the connection attempt, so a filtered
            or unreachable host cannot stall the caller indefinitely.

    Returns:
        True when something accepts connections on the port, otherwise False
        (including when the address cannot be resolved or the socket errors).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(timeout)
        try:
            # connect_ex reports connect() failures (refused, timed out) as a
            # non-zero code instead of raising.
            return s.connect_ex((host, port)) == 0
        except OSError:
            # Resolution failures (socket.gaierror) and similar still raise.
            return False
@st.cache_resource
def ensure_backend():
    """Start the local inference server when API_BASE targets it and it is down.

    The server is only launched when ``API_BASE`` is plain HTTP on
    127.0.0.1 / 0.0.0.0 / localhost at port 8000. Any other URL is treated as
    an externally managed server, and True is returned immediately.

    Returns:
        True if the backend is remote, already listening on port 8000, or
        comes up within the polling window. False if the spawned server exits
        early or does not open the port within 45 polls (~2 s apart).

    Note:
        Because of ``st.cache_resource`` the result, including False, is
        reused by later calls rather than retried.
    """
    from urllib.parse import urlparse

    parsed = urlparse(API_BASE)
    try:
        api_port = parsed.port
    except ValueError:
        # Malformed port in PLUTCHIK_API_URL: treat as a non-local server.
        api_port = None
    is_local = (
        parsed.scheme == "http"
        and parsed.hostname in {"127.0.0.1", "0.0.0.0", "localhost"}
        and api_port == 8000
    )
    if not is_local or is_port_open(8000):
        return True

    with st.status("🚀 Initializing Plutchik Neural Engine...", expanded=True) as status:
        st.write("Loading RoBERTa weights and trajectory forecasting ODEs...")
        # Resolve the server script next to this file and run it from there,
        # so launching Streamlit from another working directory still works.
        app_dir = Path(__file__).resolve().parent
        proc = subprocess.Popen(
            [sys.executable, str(app_dir / "inference_server.py")],
            cwd=str(app_dir),
        )
        # Polling for availability
        max_retries = 45
        for i in range(max_retries):
            if is_port_open(8000):
                status.update(label="✅ Neural Engine Active", state="complete", expanded=False)
                return True
            # Stop waiting as soon as the server process has died, instead of
            # polling for the full window.
            if proc.poll() is not None:
                status.update(
                    label=f"❌ Engine exited during startup (code {proc.returncode})",
                    state="error",
                )
                return False
            time.sleep(2)
            if i % 5 == 0:
                st.write(f"Waking up the engine... ({i}/{max_retries})")
        status.update(label="❌ Engine Initialization Timeout", state="error")
        return False
# Ensure backend is running before proceeding.
# A False result is non-fatal: the page keeps rendering and only shows a
# warning, so the Nemotron-3 (LLM) path remains usable.
if not ensure_backend():
st.warning("⚠️ The local inference core could not be started. Some features (RoBERTa analysis, attributions) will be unavailable. Nemotron-3 (LLM) mode is still functional.")
def call_api(endpoint: str, payload: dict, use_auth: bool = True, timeout: float = 60.0):
    """POST ``payload`` as JSON to ``{API_BASE}/{endpoint}`` and return the decoded JSON.

    Failures are reported to the user via ``st.error`` and signalled to the
    caller by returning None, so call sites only need a None check.

    Args:
        endpoint: Path relative to ``API_BASE``; a leading slash is tolerated.
        payload: JSON-serialisable request body.
        use_auth: Attach the ``X-API-Key`` header when an API key is configured.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The parsed JSON response body, or None on any failure.
    """
    headers = {"Content-Type": "application/json"}
    # Only attach the key header when a key is actually configured.
    if use_auth and API_KEY:
        headers["X-API-Key"] = API_KEY
    # Normalise slashes so "explain", "/explain" and a trailing-slash base all work.
    url = f"{API_BASE.rstrip('/')}/{endpoint.lstrip('/')}"
    try:
        response = requests.post(url, json=payload, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.HTTPError as e:
        # Prefer the server's "detail" message (when the body is a JSON object).
        try:
            err_detail = e.response.json().get("detail", str(e))
        except Exception:
            err_detail = str(e)
        st.error(f"❌ Server Error: {err_detail}")
        return None
    # Timeout must be caught before ConnectionError: ConnectTimeout subclasses
    # both, and a connect timeout is not a refused connection.
    except requests.exceptions.Timeout:
        st.error("❌ Timeout: The inference server took too long to respond.")
        return None
    except requests.exceptions.ConnectionError:
        st.error("❌ Connection Refused: The inference server is unreachable.")
        return None
    except Exception as e:
        # UI boundary (e.g. a non-JSON success body): report instead of crashing the page.
        st.error(f"❌ Unexpected Error: {e}")
        return None
# Load utilities
# NOTE(review): unlike load_centroids, these are not wrapped in
# @st.cache_resource, so they are constructed every time this script
# executes — confirm their construction is cheap (no model/network loading).
preprocessor = ERCPreprocessor(PLUTCHIK)
llm_client = NemotronClient()
# Load precomputed emotion centroids for the embedding-similarity view (lightweight).
@st.cache_resource
def load_centroids():
    """Load emotion centroids from ``my_plutchik_model/emotion_centroids.pkl``.

    Returns:
        The unpickled mapping (the similarity view iterates it with
        ``.items()``), or an empty dict when the file is missing or cannot be
        read. The UI treats an empty dict as "no centroids available".
    """
    centroids_path = Path(__file__).resolve().parent / "my_plutchik_model" / "emotion_centroids.pkl"
    if not centroids_path.exists():
        return {}
    try:
        # SECURITY: pickle can execute arbitrary code on load. This is only
        # acceptable because the file is a local model artifact; never point
        # this at user-supplied or downloaded files.
        with open(centroids_path, "rb") as f:
            return pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError) as exc:
        # A corrupt or version-incompatible artifact degrades the similarity
        # view instead of taking down the whole dashboard.
        st.warning(f"⚠️ Could not load emotion centroids ({exc}); similarity view disabled.")
        return {}
emotion_centroids = load_centroids()
# ============== SIDEBAR CONFIGURATION ==============
# Sidebar controls. Each widget's value is stored in a module-level variable
# (analysis_mode, model_type, scenario, topic, speaker, use_history,
# use_captum_explain, batch_max_rows, ...) that later code reads — e.g. the
# input section branches on analysis_mode.
with st.sidebar:
st.markdown("### ⚙️ Engine Control")
analysis_mode = st.radio(
"Analysis Protocol",
["Single Utterance", "Conversation Arc", "Comparative Analysis", "Dynamic Intelligence", "Batch File Upload"],
help="Choose the scale and type of emotional analysis."
)
model_type = st.radio("Inference Core", ["Local RoBERTa", "Nemotron-3 (LLM)", "Compare Both Models"])
st.markdown("---")
st.markdown("### 📍 Context Matrix")
scenario = st.selectbox(
"Scenario Environment",
["workplace", "friendship", "family", "romance", "support", "academic",
"conflict", "casual", "social", "travel", "technology", "creative", "wellbeing", "community"]
)
# Topic: a preset from the dropdown, overridden by any non-empty custom entry.
col_t1, col_t2 = st.columns([3, 1])
with col_t1:
topic_list = ["general", "billing", "technical", "feedback", "deadline", "resolution", "complaint", "other"]
topic = st.selectbox("Operational Topic", topic_list, index=0)
with col_t2:
topic_manual = st.text_input("Custom Topic", placeholder="...")
if topic_manual: topic = topic_manual
# Speaker persona: same preset / custom-override pattern as the topic.
col_s1, col_s2 = st.columns([3, 1])
with col_s1:
persona_list = ["USER", "AGENT", "CUSTOMER", "EMPLOYEE", "MANAGER", "ADMIN", "other"]
speaker = st.selectbox("Source Persona", persona_list, index=0)
with col_s2:
speaker_manual = st.text_input("Custom Persona", placeholder="...")
if speaker_manual: speaker = speaker_manual
st.markdown("---")
# Context source: session history when checked, otherwise a manual buffer.
# NOTE(review): prev_turns_manual is only assigned when use_history is False;
# confirm downstream code never reads it otherwise (NameError risk).
use_history = st.checkbox("Persistent Context", value=True)
if not use_history:
prev_turns_manual = st.text_area("Manual Context Buffer", placeholder="Turn 1 | Turn 2...")
else:
st.caption("Using session history for context-aware inference.")
# Captum attribution is opt-in and only offered in Single Utterance mode.
use_captum_explain = False
if analysis_mode == "Single Utterance":
use_captum_explain = st.checkbox(
"Full explainability (Captum on full context window; slower)",
value=False,
help="Calls POST /explain so token IG matches [CONTEXT]…[CURRENT]… input seen by the model.",
)
st.markdown("---")
# Show whether PLUTCHIK_API_KEY was found in the environment.
if API_KEY:
st.sidebar.success("🔑 API Key Loaded")
else:
st.sidebar.error("❌ API Key Missing")
# Row cap for batch scoring: fixed at 200 unless batch mode exposes the input.
batch_max_rows = 200
if analysis_mode == "Batch File Upload":
batch_max_rows = st.number_input("Max CSV rows to score", min_value=10, max_value=2000, value=200, step=10)
# ============== MAIN UI: INPUT SECTION ==============
input_container = st.container()
with input_container:
if analysis_mode == "Single Utterance":
user_text = st.text_area(
label="Input Signal",
label_visibility="collapsed",
placeholder="Transmit message for emotional decoding...",
height=160
)
elif analysis_mode == "Conversation Arc":
user_text = st.text_area(
label="Dialogue Data",
label_visibility="collapsed",
placeholder="Enter dialogue stream (format SPEAKER: TEXT)\n\nExample:\nUSER: This is unacceptable!\nAGENT: I'm so sorry you're feeling this way.",
height=280
)
elif analysis_mode == "Comparative Analysis":
col_v1, col_v2 = st.columns(2)
with col_v1:
st.markdown("#### Stream A")
user_text_1 = st.text_area("A", label_visibility="collapsed", placeholder="USER: Hello\nAGENT: Hi", height=200)
with col_v2:
st.markdown("#### Stream B")
user_text_2 = st.text_area("B", label_visibility="collapsed", placeholder="USER: Hello\nAGENT: GO AWAY", height=200)
user_text = user_text_1 # For button validation
else:
st.markdown("#### 📦 Batch Configuration")
c1, c2 = st.columns([2, 1])
with c1:
uploaded_file = st.file_uploader("Upload Signal Batch (CSV)", type=["csv"], help="CSV should have at least a 'text' column. Optional: 'speaker', 'topic', 'scenario'")
with c2:
st.markdown("
Inference Insight
The signal indicates {pred['emotion'].title()} with {pred['emotion_confidence']:.1%} statistical confidence.
Subtext analysis suggests the intent is {'ironic/hidden' if pred['sarcasm_confidence'] > 0.5 else 'literal'},
operating within the {PLUTCHIK[pred["emotion"]]["ring"].upper()} intensity layer of the Plutchik ecosystem.
"""
st.markdown(interpretation, unsafe_allow_html=True)
# Metrics row
col1, col2, col3, col4 = st.columns(4)
with col1:
st.markdown(f"""
", "[CONTEXT]", "[/CONTEXT]", "[CURRENT]", "[/CURRENT]"])]
fig_attr = px.bar(
attr_df,
x="score",
y="token",
orientation='h',
color="score",
color_continuous_scale="RdBu",
)
fig_attr.update_layout(
height=400,
paper_bgcolor='rgba(0,0,0,0)',
plot_bgcolor='rgba(0,0,0,0)',
font=dict(color='#e6edf3'),
xaxis=dict(showgrid=True, gridcolor='#30363d', zerolinecolor='#30363d'),
yaxis=dict(showgrid=False)
)
st.plotly_chart(fig_attr, use_container_width=True)
else:
st.info("No token attribution data available for this prediction.")
ctx_top = pred.get("context_span_top") or []
cur_top = pred.get("current_span_top") or []
if ctx_top or cur_top:
st.markdown("##### Context window (T-2 / T-1) vs current turn — top tokens by |IG|")
cc1, cc2 = st.columns(2)
with cc1:
st.caption("Tokens before [CURRENT] span")
st.dataframe(pd.DataFrame(ctx_top), use_container_width=True, height=260)
with cc2:
st.caption("Tokens inside [CURRENT] span (scenario + topic + utterance)")
st.dataframe(pd.DataFrame(cur_top), use_container_width=True, height=260)
else:
st.caption("Context vs current span breakdown appears when running with “Full explainability”.")
else:
st.info("Token attribution data is currently unavailable. Enable “Full explainability” in the sidebar (Single Utterance) or use POST /explain.")
with int_tab2:
st.write("**Embedding Heatmap (Sampled Dims)**")
emb_blob = pred.get("embedding_info", {}).get("all_token_embeddings")
if emb_blob is not None and hasattr(emb_blob, "shape") and emb_blob.size > 0:
heatmap_data = emb_blob
hm_min = heatmap_data.min(axis=0, keepdims=True)
hm_max = heatmap_data.max(axis=0, keepdims=True)
hm_max = np.where(hm_max == hm_min, 1.0, hm_max)
hm_normalized = (heatmap_data - hm_min) / (hm_max - hm_min + 1e-8)
sample_cols = np.linspace(0, hm_normalized.shape[1] - 1, 30, dtype=int)
hm_sampled = hm_normalized[:, sample_cols]
fig_hm = go.Figure(
data=go.Heatmap(
z=hm_sampled,
colorscale='Viridis',
x=[f'Dim {i}' for i in sample_cols],
y=[f'Token {i}' for i in range(hm_sampled.shape[0])]
)
)
fig_hm.update_layout(
height=400,
xaxis_title="Hidden Dims",
yaxis_title="Tokens",
paper_bgcolor='rgba(0,0,0,0)',
plot_bgcolor='rgba(0,0,0,0)',
font=dict(color='#e6edf3')
)
st.plotly_chart(fig_hm, use_container_width=True)
else:
st.warning("Embedding visualization is only available for the Local RoBERTa model.")
with int_tab3:
st.write("**Cosine Similarity to Top Emotions**")
st.caption("Shows how close the prediction is to learned emotion centroids in embedding space.")
if "embedding_info" in pred and "cls_embedding" in pred["embedding_info"] and emotion_centroids:
cls_embedding = pred["embedding_info"]["cls_embedding"]
similarities = {}
# Simple cosine similarity via numpy
def cosine_sim(a, b):
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)
emotion_names_list = sorted(PLUTCHIK.keys())
for emotion_idx, centroid in emotion_centroids.items():
sim = cosine_sim(cls_embedding, centroid)
emotion_str = emotion_names_list[emotion_idx] if isinstance(emotion_idx, (int, np.integer)) else emotion_idx
similarities[emotion_str] = sim
sorted_sims = sorted(similarities.items(), key=lambda x: x[1], reverse=True)[:10]
sim_emotions, sim_values = zip(*sorted_sims)
fig_sim = go.Figure(
data=go.Bar(
x=list(sim_values),
y=[e.title() for e in sim_emotions],
orientation='h',
marker_color=['#58a6ff' if e == pred["emotion"] else '#30363d' for e in sim_emotions],
text=[f'{v:.3f}' for v in sim_values],
textposition='auto'
)
)
fig_sim.update_layout(
xaxis_title="Cosine Similarity",
yaxis_title="",
height=350,
showlegend=False,
paper_bgcolor='rgba(0,0,0,0)',
plot_bgcolor='rgba(0,0,0,0)',
font=dict(color='#e6edf3'),
xaxis=dict(showgrid=True, gridcolor='#30363d', zerolinecolor='#30363d'),
yaxis=dict(showgrid=False)
)
st.plotly_chart(fig_sim, use_container_width=True)
else:
st.warning("Cosine similarity analysis requires the Local RoBERTa model's learned centroids.")
# ============== SIDEBAR: HISTORY & ABOUT ==============
st.sidebar.markdown("---")
st.sidebar.subheader("🕒 Prediction History")
if st.session_state.prediction:
    # Record the current prediction unless it is already the newest entry.
    latest = st.session_state.prediction
    latest_text = str(latest.get("text") or "Unknown text")
    hist_entries = st.session_state.history
    # Compare full texts: the stored "text" is a truncated display string, so
    # comparing it with the full prediction text would never match and every
    # rerun would re-insert the same prediction. Entries without
    # "source_text" fall back to their display text.
    newest_source = (
        hist_entries[0].get("source_text", hist_entries[0]["text"]) if hist_entries else None
    )
    if newest_source != latest_text:
        # Only add an ellipsis when the text was actually shortened.
        display_text = latest_text if len(latest_text) <= 30 else latest_text[:30] + "..."
        hist_entries.insert(0, {
            "emotion": latest["emotion"],
            "confidence": latest["emotion_confidence"],
            "text": display_text,
            "source_text": latest_text,
        })
# Show the five most recent predictions; user text is HTML-escaped because it
# is rendered with unsafe_allow_html=True.
for h in st.session_state.history[:5]:
    st.sidebar.markdown(f"""
{html.escape(h['text'])}
{html.escape(h['emotion'].title())} ({h['confidence']:.0%})
""", unsafe_allow_html=True)
st.sidebar.markdown("---")
# Static reference material describing the Plutchik emotion layers.
st.sidebar.markdown("### 📚 The Plutchik Lexicon")
with st.sidebar.expander("Explore Emotional Layers"):
st.markdown("""
Intense: Raw, visceral reactions (Rage, Grief, Terror).
Primary: Balanced, conscious states (Anger, Sadness, Fear).
Mild: Subtle, transient feelings (Annoyance, Pensiveness, Apprehension).
Dyadic: Complex blends (Contempt, Remorse, Love).
""", unsafe_allow_html=True)
# NOTE(review): "32 classes" is hard-coded here; confirm it matches
# NUM_EMOTIONS from utils.constants.
st.sidebar.info("""
The **Plutchik Wheel** defines emotions as a spectrum of 32 classes.
This AI decodes the **subtext**—detecting when words and intent diverge.
""")
# Version / copyright footer.
st.sidebar.markdown("""
---
**Version 2.5.0 Hardened Production Edition**
© 2026 Plutchik ERC Project
""")