# ═══════════════════════════════════════════════════════════════════════════════
# File: app/database/memory_db.py
# Description: Database-backed conversation memory store
# ═══════════════════════════════════════════════════════════════════════════════
"""
Database-backed conversation memory that persists across restarts.

Provides the same interface as the in-memory store, so it can be used as a
drop-in replacement.
"""

import logging
import uuid
from typing import Dict, List, Optional, Any
from datetime import datetime

from sqlalchemy import delete, func, select, text
from sqlalchemy.orm import selectinload

from app.database.db import get_db_manager
from app.database.models import Conversation, Message, Intelligence

logger = logging.getLogger(__name__)


class DatabaseMemoryStore:
    """
    Persistent conversation storage using SQLAlchemy.

    Compatible with the original ConversationMemory interface but stores data
    in SQLite/PostgreSQL. Every conversation that is loaded or created is also
    kept as a plain dict in ``self._cache``; entries leave the cache only via
    ``clear()``.
    """

    # Maximum number of cached history records. update() appends one record per
    # scammer/honeypot exchange, so this keeps the last 20 exchanges (turns).
    MAX_HISTORY_RECORDS = 20
    # Maximum number of cached aggregated_intelligence["reasoning_history"] entries.
    MAX_REASONING_SEGMENTS = 5
    # Entity buckets a freshly created conversation starts with (all empty).
    INTEL_KEYS = (
        "phone_numbers",
        "upi_ids",
        "bank_accounts",
        "ifsc_codes",
        "emails",
        "urls",
        "credit_cards",
        "otps",
        "rat_apps",
        "pan_cards",
        "aadhar_numbers",
    )
    # Raw-SQL fallback used by clear(): child tables first, bound on :id.
    _RAW_DELETE_STATEMENTS = (
        "DELETE FROM messages WHERE conversation_id = :id",
        "DELETE FROM intelligence WHERE conversation_id = :id",
        "DELETE FROM conversations WHERE id = :id",
    )

    def __init__(self):
        self._cache: Dict[str, Dict] = {}  # Hot cache: conversation_id -> conversation dict

    # ------------------------------------------------------------------ helpers

    def _new_conversation_dict(self, conversation_id: str, sender_id: Optional[str]) -> Dict:
        """Build the cache dict for a conversation that was just created."""
        now = datetime.utcnow().isoformat()  # one timestamp for created/updated
        return {
            "id": conversation_id,
            "sender_id": sender_id,
            "created_at": now,
            "updated_at": now,
            "message_count": 0,
            "phase": "hook",
            "trust_score": 0.0,
            "scam_type": None,
            "persona": None,
            "history": [],
            "aggregated_intelligence": {key: [] for key in self.INTEL_KEYS},
            "threat_intel": None,
            "risk_score": 0.0,
        }

    def _prune_history(self, conv_dict: Dict) -> None:
        """Keep only the newest MAX_HISTORY_RECORDS history records."""
        history = conv_dict.get("history")
        cap = self.MAX_HISTORY_RECORDS
        if isinstance(history, list) and len(history) > cap:
            conv_dict["history"] = history[len(history) - cap:]

    def _prune_reasoning(self, aggregated: Dict[str, Any]) -> None:
        """Keep only the newest MAX_REASONING_SEGMENTS reasoning_history entries."""
        reasoning = aggregated.get("reasoning_history")
        cap = self.MAX_REASONING_SEGMENTS
        if isinstance(reasoning, list) and len(reasoning) > cap:
            aggregated["reasoning_history"] = reasoning[len(reasoning) - cap:]

    @staticmethod
    def _merge_cached_intelligence(
        aggregated: Dict[str, Any], intelligence: Dict[str, Any], wrap_scalars: bool
    ) -> None:
        """Merge ``intelligence`` into the cached aggregate buckets, de-duplicated.

        Every key in ``intelligence`` ends up as a list bucket (a legacy scalar
        entry is wrapped). List values are merged item by item. Non-list values
        are merged as a single item only when ``wrap_scalars`` is true, and
        ``None`` is never added as an item.
        """
        for key, values in intelligence.items():
            bucket = aggregated.get(key)
            if not isinstance(bucket, list):
                bucket = [bucket] if bucket else []
                aggregated[key] = bucket

            if isinstance(values, list):
                new_items = values
            elif wrap_scalars and values is not None:
                new_items = [values]
            else:
                new_items = []

            for item in new_items:
                if item not in bucket:
                    bucket.append(item)

    @staticmethod
    async def _persist_intelligence(
        session, conversation_id: str, intelligence: Dict[str, Any]
    ) -> None:
        """Add Intelligence rows for list-valued entries that are not yet stored.

        Only non-empty list values are persisted, each item as ``str(item)``.
        Existing rows are fetched with a single query (instead of one SELECT
        per value), and duplicates inside ``intelligence`` are added only once.
        """
        candidates = [
            (entity_type, str(value))
            for entity_type, values in intelligence.items()
            if values and isinstance(values, list)
            for value in values
        ]
        if not candidates:
            return

        entity_types = list({entity_type for entity_type, _ in candidates})
        result = await session.execute(
            select(Intelligence.entity_type, Intelligence.entity_value).where(
                Intelligence.conversation_id == conversation_id,
                Intelligence.entity_type.in_(entity_types),
            )
        )
        known = {(etype, evalue) for etype, evalue in result.all()}

        for entity_type, entity_value in candidates:
            if (entity_type, entity_value) in known:
                continue
            known.add((entity_type, entity_value))
            session.add(
                Intelligence(
                    conversation_id=conversation_id,
                    entity_type=entity_type,
                    entity_value=entity_value,
                )
            )

    @staticmethod
    async def _safe_rollback(session) -> None:
        """Roll back ``session``, logging (not raising) if the rollback itself fails."""
        try:
            await session.rollback()
        except Exception:
            logger.debug("Session rollback failed", exc_info=True)

    # --------------------------------------------------------------- public API

    async def get_or_create(
        self,
        conversation_id: Optional[str] = None,
        sender_id: Optional[str] = None
    ) -> Dict:
        """Return the cached/stored conversation dict, creating it if missing.

        Args:
            conversation_id: Conversation key; a random ``conv_<12 hex>`` id is
                generated when it is falsy.
            sender_id: Stored on the conversation only when it is created here.

        Returns:
            The conversation dict, which is also kept in the hot cache.
        """
        if not conversation_id:
            conversation_id = f"conv_{uuid.uuid4().hex[:12]}"

        cached = self._cache.get(conversation_id)
        if cached is not None:
            return cached

        db = get_db_manager()
        async with db.session() as session:
            result = await session.execute(
                select(Conversation)
                .where(Conversation.id == conversation_id)
                .options(
                    selectinload(Conversation.messages),
                    selectinload(Conversation.intelligence_items),
                )
            )
            conv = result.scalar_one_or_none()

            if conv is not None:
                # NOTE(review): assumes to_dict() yields the same keys as
                # _new_conversation_dict() (history, aggregated_intelligence, ...).
                conv_dict = conv.to_dict()
                # Loaded history may be long; cap it like update() does.
                self._prune_history(conv_dict)
            else:
                session.add(
                    Conversation(id=conversation_id, sender_id=sender_id, phase="hook")
                )
                await session.flush()
                conv_dict = self._new_conversation_dict(conversation_id, sender_id)

        # Cache only after the session block exited without raising.
        self._cache[conversation_id] = conv_dict
        return conv_dict

    async def get(self, conversation_id: str) -> Optional[Dict]:
        """Return the conversation for ``conversation_id``.

        Note: this never returns ``None`` — a missing conversation is created
        via ``get_or_create`` (which also serves cache hits).
        """
        return await self.get_or_create(conversation_id)

    async def update(
        self,
        conversation_id: str,
        scammer_message: str,
        honeypot_response: str,
        intelligence: Dict,
        phase: str,
        scam_type: Optional[str] = None,
        persona: Optional[str] = None,
        risk_score: float = 0.0,
        trust_score: float = 0.0
    ) -> Dict:
        """Record one scammer/honeypot exchange in the DB and the cache.

        Increments the conversation's message count, stores a Message row for
        the new turn, persists new list-valued intelligence, and mirrors all of
        it into the cached dict (history capped at MAX_HISTORY_RECORDS).
        ``scam_type``/``persona`` are only overwritten when truthy.

        Returns:
            The updated conversation dict. If the conversation row cannot be
            found in the DB, the cached dict is returned unchanged.
        """
        intelligence = intelligence or {}
        conv_dict = await self.get_or_create(conversation_id)
        now = datetime.utcnow()

        db = get_db_manager()
        async with db.session() as session:
            result = await session.execute(
                select(Conversation).where(Conversation.id == conversation_id)
            )
            conv = result.scalar_one_or_none()
            if conv is None:
                return conv_dict

            conv.message_count += 1
            conv.phase = phase
            conv.updated_at = now
            if scam_type:
                conv.scam_type = scam_type
            if persona:
                conv.persona = persona
            conv.risk_score = risk_score
            conv.trust_score = trust_score
            # Capture while the ORM object is live; it is not read after the session.
            turn = conv.message_count

            session.add(
                Message(
                    conversation_id=conversation_id,
                    turn=turn,
                    scammer_message=scammer_message,
                    honeypot_response=honeypot_response,
                    phase=phase,
                )
            )
            await self._persist_intelligence(session, conversation_id, intelligence)
            await session.flush()

        # Mirror the change into the cache only after the DB work succeeded.
        timestamp = now.isoformat()
        conv_dict["message_count"] = turn
        conv_dict["phase"] = phase
        conv_dict["updated_at"] = timestamp
        if scam_type:
            conv_dict["scam_type"] = scam_type
        if persona:
            conv_dict["persona"] = persona
        conv_dict["risk_score"] = risk_score
        conv_dict["trust_score"] = trust_score

        conv_dict.setdefault("history", []).append({
            "turn": turn,
            "timestamp": timestamp,
            "scammer_message": scammer_message,
            "honeypot_response": honeypot_response,
            "phase": phase,
            "intelligence": intelligence,
        })
        self._prune_history(conv_dict)

        aggregated = conv_dict.setdefault("aggregated_intelligence", {})
        self._merge_cached_intelligence(aggregated, intelligence, wrap_scalars=False)
        self._prune_reasoning(aggregated)

        self._cache[conversation_id] = conv_dict
        return conv_dict

    async def update_intelligence(self, conversation_id: str, intelligence: Dict[str, Any]) -> Dict:
        """Explicitly merge intelligence fields (e.g. keywords) into a conversation.

        List values are persisted as Intelligence rows (``str(item)``, skipping
        ones already stored). In the cache, list values are merged item by item
        and non-``None`` scalar values are merged as a single item; the scalar
        values are not written to the DB.

        Returns:
            The updated conversation dict.
        """
        intelligence = intelligence or {}
        conv_dict = await self.get_or_create(conversation_id)

        db = get_db_manager()
        async with db.session() as session:
            await self._persist_intelligence(session, conversation_id, intelligence)
            await session.flush()

        aggregated = conv_dict.setdefault("aggregated_intelligence", {})
        self._merge_cached_intelligence(aggregated, intelligence, wrap_scalars=True)
        self._prune_reasoning(aggregated)

        self._cache[conversation_id] = conv_dict
        return conv_dict

    async def get_statistics(self) -> Dict[str, Any]:
        """Return global counts from the DB plus the number of cached conversations.

        ``scams_detected`` counts conversations with a non-NULL scam_type.
        ``scam_distribution`` is currently always an empty dict.
        """
        db = get_db_manager()
        async with db.session() as session:
            total_conversations = (
                await session.execute(select(func.count(Conversation.id)))
            ).scalar() or 0
            total_messages = (
                await session.execute(select(func.count(Message.id)))
            ).scalar() or 0
            scams_detected = (
                await session.execute(
                    select(func.count(Conversation.id))
                    .where(Conversation.scam_type.isnot(None))
                )
            ).scalar() or 0
            intelligence_extracted = (
                await session.execute(select(func.count(Intelligence.id)))
            ).scalar() or 0

        return {
            "total_conversations": total_conversations,
            "total_messages": total_messages,
            "scams_detected": scams_detected,
            "intelligence_extracted": intelligence_extracted,
            "active_conversations": len(self._cache),
            "scam_distribution": {}
        }

    def get_history_text(self, conversation_id: str, max_turns: int = 10) -> str:
        """Format the last ``max_turns`` cached exchanges as "Caller:"/"Me:" lines.

        This is synchronous and reads only the in-process cache; it does not
        query the database. Returns "" when the conversation is not cached or
        ``max_turns`` is not positive.
        """
        conv = self._cache.get(conversation_id)
        if not conv or max_turns <= 0:
            return ""

        history = (conv.get("history") or [])[-max_turns:]
        lines = []
        for msg in history:
            lines.append(f"Caller: {msg.get('scammer_message', '')}")
            lines.append(f"Me: {msg.get('honeypot_response', '')}")
        return "\n".join(lines)

    async def clear(self, conversation_id: str) -> bool:
        """Remove a conversation from the cache and the database.

        Deletes messages and intelligence rows first, then the conversation.
        If the ORM deletes fail, the transaction is rolled back and the same
        deletes are retried as parameterised raw SQL.

        Returns:
            True if the deletes were committed, False if both attempts failed.
            The cache entry is removed either way.
        """
        self._cache.pop(conversation_id, None)

        db = get_db_manager()
        async with db.session() as session:
            try:
                await session.execute(
                    delete(Message).where(Message.conversation_id == conversation_id)
                )
                await session.execute(
                    delete(Intelligence).where(Intelligence.conversation_id == conversation_id)
                )
                await session.execute(
                    delete(Conversation).where(Conversation.id == conversation_id)
                )
                await session.commit()
                return True
            except Exception:
                logger.warning(
                    "ORM delete failed for conversation %s; retrying with raw SQL",
                    conversation_id,
                    exc_info=True,
                )
                # A failed statement can leave the transaction unusable (e.g.
                # PostgreSQL aborts it), so discard it before retrying.
                await self._safe_rollback(session)

            try:
                for statement in self._RAW_DELETE_STATEMENTS:
                    await session.execute(text(statement), {"id": conversation_id})
                await session.commit()
                return True
            except Exception:
                logger.exception("Failed to delete conversation %s", conversation_id)
                await self._safe_rollback(session)
                return False


# Global instance
db_memory_store = DatabaseMemoryStore()

__all__ = ["DatabaseMemoryStore", "db_memory_store"]