| |
| |
| |
| |
|
|
| """ |
Database-backed conversation memory that persists across restarts.
Provides the same interface as the in-memory store, so it can be used as a drop-in replacement.
| """ |
|
|
| import uuid |
| from typing import Dict, List, Optional, Any |
| from datetime import datetime |
| from sqlalchemy import select, func |
| from sqlalchemy.orm import selectinload |
|
|
| from app.database.db import get_db_manager |
| from app.database.models import Conversation, Message, Intelligence |
|
|
|
|
class DatabaseMemoryStore:
    """
    Persistent conversation storage using SQLAlchemy.

    Compatible with the original ConversationMemory interface but stores
    data in SQLite/PostgreSQL. Conversation dicts are also kept in an
    in-process cache (``self._cache``) keyed by conversation id. The dict
    returned by each public method is the cached object itself, so callers
    that mutate it are mutating the cache.

    NOTE(review): ``get_or_create``, ``update`` and ``update_intelligence``
    only ``flush()``; durability depends on ``get_db_manager().session()``
    committing when its context exits (``clear`` commits explicitly).
    Confirm the session manager's behaviour.
    """

    # Maximum number of turns kept in a cached conversation's "history" list.
    _MAX_CACHED_HISTORY = 20
    # Maximum number of entries kept in aggregated "reasoning_history".
    _MAX_REASONING_HISTORY = 5

    def __init__(self):
        # conversation_id -> conversation dict. Unbounded; entries are only
        # removed by clear().
        self._cache: Dict[str, Dict] = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _empty_intelligence() -> Dict[str, List]:
        """Return a fresh aggregated-intelligence dict with every entity list empty."""
        return {
            "phone_numbers": [],
            "upi_ids": [],
            "bank_accounts": [],
            "ifsc_codes": [],
            "emails": [],
            "urls": [],
            "credit_cards": [],
            "otps": [],
            "rat_apps": [],
            "pan_cards": [],
            "aadhar_numbers": [],
        }

    @classmethod
    def _new_conversation_dict(cls, conversation_id: str, sender_id: Optional[str]) -> Dict:
        """Build the cache dict for a conversation that has just been created."""
        now = datetime.utcnow().isoformat()
        return {
            "id": conversation_id,
            "sender_id": sender_id,
            "created_at": now,
            "updated_at": now,
            "message_count": 0,
            "phase": "hook",
            "trust_score": 0.0,
            "scam_type": None,
            "persona": None,
            "history": [],
            "aggregated_intelligence": cls._empty_intelligence(),
            "threat_intel": None,
            "risk_score": 0.0,
        }

    @staticmethod
    def _aggregated(conv_dict: Dict) -> Dict[str, Any]:
        """Return conv_dict["aggregated_intelligence"], installing ``{}`` if it is missing or not a dict."""
        aggregated = conv_dict.get("aggregated_intelligence")
        if not isinstance(aggregated, dict):
            aggregated = {}
            conv_dict["aggregated_intelligence"] = aggregated
        return aggregated

    @staticmethod
    def _as_list(aggregated: Dict[str, Any], key: str) -> List:
        """Return aggregated[key] as a list stored back under ``key``.

        A truthy non-list value is wrapped in a one-element list; a falsy or
        missing value becomes an empty list.
        """
        current = aggregated.get(key)
        if not isinstance(current, list):
            current = [current] if current else []
            aggregated[key] = current
        return current

    @classmethod
    def _trim_reasoning_history(cls, aggregated: Dict[str, Any]) -> None:
        """Keep only the newest ``_MAX_REASONING_HISTORY`` reasoning_history entries."""
        history = aggregated.get("reasoning_history")
        if isinstance(history, list) and len(history) > cls._MAX_REASONING_HISTORY:
            aggregated["reasoning_history"] = history[-cls._MAX_REASONING_HISTORY:]

    @staticmethod
    async def _persist_intelligence(session, conversation_id: str, intelligence: Dict[str, Any]) -> None:
        """Add Intelligence rows for list-valued entries not already stored.

        Only non-empty list values are persisted; each element is stored as
        ``str(value)``. Existing (entity_type, entity_value) pairs for the
        conversation are fetched with one SELECT instead of one per value,
        and duplicates within ``intelligence`` are inserted only once.
        """
        wanted = []
        for entity_type, values in intelligence.items():
            if values and isinstance(values, list):
                wanted.extend((entity_type, str(value)) for value in values)
        if not wanted:
            return

        result = await session.execute(
            select(Intelligence.entity_type, Intelligence.entity_value)
            .where(Intelligence.conversation_id == conversation_id)
        )
        seen = {(entity_type, value) for entity_type, value in result.all()}

        for entity_type, value in wanted:
            if (entity_type, value) in seen:
                continue
            seen.add((entity_type, value))
            session.add(
                Intelligence(
                    conversation_id=conversation_id,
                    entity_type=entity_type,
                    entity_value=value,
                )
            )

    @staticmethod
    async def _rollback_quietly(session) -> None:
        """Roll back ``session``; if the rollback itself fails, log it instead of raising."""
        import logging

        try:
            await session.rollback()
        except Exception:
            logging.getLogger(__name__).debug("Session rollback failed", exc_info=True)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def get_or_create(
        self,
        conversation_id: Optional[str] = None,
        sender_id: Optional[str] = None
    ) -> Dict:
        """Return the conversation dict for ``conversation_id``, creating it if needed.

        Lookup order: in-process cache, then the database (with messages and
        intelligence eagerly loaded), otherwise a new ``Conversation`` row in
        phase "hook" is added and flushed.

        Args:
            conversation_id: Conversation id. A falsy value generates a new
                ``conv_<12 hex chars>`` id.
            sender_id: Stored on the row only when a new conversation is created.

        Returns:
            The cached conversation dict. Its "history" holds at most
            ``_MAX_CACHED_HISTORY`` entries.
        """
        if not conversation_id:
            conversation_id = f"conv_{uuid.uuid4().hex[:12]}"

        if conversation_id in self._cache:
            return self._cache[conversation_id]

        db = get_db_manager()
        async with db.session() as session:
            result = await session.execute(
                select(Conversation)
                .where(Conversation.id == conversation_id)
                .options(
                    selectinload(Conversation.messages),
                    selectinload(Conversation.intelligence_items)
                )
            )
            conv = result.scalar_one_or_none()

            if conv is not None:
                # to_dict() must run while the session is open.
                # NOTE(review): assumes it yields the same keys as
                # _new_conversation_dict — confirm against the model.
                conv_dict = conv.to_dict()
            else:
                session.add(Conversation(
                    id=conversation_id,
                    sender_id=sender_id,
                    phase="hook"
                ))
                await session.flush()
                conv_dict = self._new_conversation_dict(conversation_id, sender_id)

        # Cache only after the session has exited cleanly, so a failed
        # commit does not leave a never-persisted conversation cached.
        history = conv_dict.get("history")
        if history and len(history) > self._MAX_CACHED_HISTORY:
            conv_dict["history"] = history[-self._MAX_CACHED_HISTORY:]

        self._cache[conversation_id] = conv_dict
        return conv_dict

    async def get(self, conversation_id: str) -> Optional[Dict]:
        """Return the conversation dict for ``conversation_id``.

        Delegates to get_or_create (which checks the cache first). An unknown
        id is therefore created rather than reported as missing.
        """
        return await self.get_or_create(conversation_id)

    async def update(
        self,
        conversation_id: str,
        scammer_message: str,
        honeypot_response: str,
        intelligence: Dict,
        phase: str,
        scam_type: Optional[str] = None,
        persona: Optional[str] = None,
        risk_score: float = 0.0,
        trust_score: float = 0.0
    ) -> Dict:
        """Record one scammer/honeypot exchange in the database and the cache.

        Increments message_count, updates phase, scores and (when truthy)
        scam_type/persona, and adds a ``Message`` row for the new turn.
        List-valued ``intelligence`` entries are stored as ``Intelligence``
        rows and merged without duplicates into the cached
        "aggregated_intelligence".

        Returns:
            The updated cached conversation dict. If the conversation row is
            not found in the database, the cached dict is returned unchanged.
        """
        intelligence = intelligence or {}
        conv_dict = await self.get_or_create(conversation_id)
        now = datetime.utcnow()

        db = get_db_manager()
        async with db.session() as session:
            result = await session.execute(
                select(Conversation).where(Conversation.id == conversation_id)
            )
            conv = result.scalar_one_or_none()

            if conv is None:
                return conv_dict

            conv.message_count = (conv.message_count or 0) + 1
            # Capture the turn while the session is open. ORM attributes may
            # be expired or detached once the session context exits.
            turn = conv.message_count
            conv.phase = phase
            conv.updated_at = now

            if scam_type:
                conv.scam_type = scam_type
            if persona:
                conv.persona = persona

            conv.risk_score = risk_score
            conv.trust_score = trust_score

            session.add(Message(
                conversation_id=conversation_id,
                turn=turn,
                scammer_message=scammer_message,
                honeypot_response=honeypot_response,
                phase=phase
            ))

            await self._persist_intelligence(session, conversation_id, intelligence)
            await session.flush()

        # Mirror the change into the cached dict.
        conv_dict["message_count"] = turn
        conv_dict["phase"] = phase
        conv_dict["updated_at"] = now.isoformat()

        if scam_type:
            conv_dict["scam_type"] = scam_type
        if persona:
            conv_dict["persona"] = persona

        conv_dict["risk_score"] = risk_score
        conv_dict["trust_score"] = trust_score

        history = conv_dict.get("history")
        if not isinstance(history, list):
            history = []
            conv_dict["history"] = history
        history.append({
            "turn": turn,
            "timestamp": now.isoformat(),
            "scammer_message": scammer_message,
            "honeypot_response": honeypot_response,
            "phase": phase,
            "intelligence": intelligence
        })
        if len(history) > self._MAX_CACHED_HISTORY:
            conv_dict["history"] = history[-self._MAX_CACHED_HISTORY:]

        # Each key gets a list in the aggregate; only list values contribute items.
        aggregated = self._aggregated(conv_dict)
        for key, values in intelligence.items():
            current = self._as_list(aggregated, key)
            if isinstance(values, list):
                for item in values:
                    if item not in current:
                        current.append(item)

        self._trim_reasoning_history(aggregated)

        self._cache[conversation_id] = conv_dict
        return conv_dict

    async def update_intelligence(self, conversation_id: str, intelligence: Dict[str, Any]) -> Dict:
        """Merge extra intelligence fields (e.g. keywords) into a conversation.

        List-valued entries are persisted as ``Intelligence`` rows (as
        strings). In the cached aggregate, list values are merged item by
        item. Non-``None`` scalar values are appended as single items; only
        list values are written to the database.

        Returns:
            The updated cached conversation dict.
        """
        intelligence = intelligence or {}
        conv_dict = await self.get_or_create(conversation_id)

        db = get_db_manager()
        async with db.session() as session:
            await self._persist_intelligence(session, conversation_id, intelligence)
            await session.flush()

        aggregated = self._aggregated(conv_dict)
        for key, values in intelligence.items():
            current = self._as_list(aggregated, key)
            if values is None:
                # Nothing to add; avoid polluting the entity list with None.
                continue
            for val in (values if isinstance(values, list) else [values]):
                if val not in current:
                    current.append(val)

        self._trim_reasoning_history(aggregated)

        self._cache[conversation_id] = conv_dict
        return conv_dict

    async def get_statistics(self) -> Dict[str, Any]:
        """Return global counts from the database plus the current cache size.

        "scams_detected" counts conversations with a non-NULL scam_type.
        "active_conversations" is the number of cached conversations.
        "scam_distribution" is currently always empty.
        """
        db = get_db_manager()
        async with db.session() as session:
            total_conv = await session.execute(select(func.count(Conversation.id)))
            total_conversations = total_conv.scalar() or 0

            total_msg = await session.execute(select(func.count(Message.id)))
            total_messages = total_msg.scalar() or 0

            scams = await session.execute(
                select(func.count(Conversation.id))
                .where(Conversation.scam_type.isnot(None))
            )
            scams_detected = scams.scalar() or 0

            intel = await session.execute(select(func.count(Intelligence.id)))
            intelligence_extracted = intel.scalar() or 0

        return {
            "total_conversations": total_conversations,
            "total_messages": total_messages,
            "scams_detected": scams_detected,
            "intelligence_extracted": intelligence_extracted,
            "active_conversations": len(self._cache),
            # TODO: populate per-scam_type counts; left empty as before.
            "scam_distribution": {}
        }

    def get_history_text(self, conversation_id: str, max_turns: int = 10) -> str:
        """Format the newest ``max_turns`` cached turns as "Caller:"/"Me:" lines.

        Only the in-process cache is consulted, so a conversation that has
        not been loaded through get/get_or_create/update yields "".
        A non-positive ``max_turns`` also yields "".
        """
        conv = self._cache.get(conversation_id)
        # history[-0:] would be the whole list, so treat max_turns <= 0 as "no turns".
        if not conv or max_turns <= 0:
            return ""

        history = (conv.get("history") or [])[-max_turns:]
        lines = []
        for msg in history:
            lines.append(f"Caller: {msg.get('scammer_message', '')}")
            lines.append(f"Me: {msg.get('honeypot_response', '')}")

        return "\n".join(lines)

    async def clear(self, conversation_id: str) -> bool:
        """Remove a conversation from the cache and delete its rows from the database.

        The cache entry is dropped first, whether or not the database delete
        succeeds. Message, Intelligence and Conversation rows are deleted and
        committed. If that fails, the transaction is rolled back and the same
        deletes are retried as raw SQL against the ``messages``,
        ``intelligence`` and ``conversations`` tables.

        Returns:
            True if either attempt committed, False if both failed.
        """
        import logging

        from sqlalchemy import delete, text

        log = logging.getLogger(__name__)

        self._cache.pop(conversation_id, None)

        db = get_db_manager()
        async with db.session() as session:
            try:
                await session.execute(
                    delete(Message).where(Message.conversation_id == conversation_id)
                )
                await session.execute(
                    delete(Intelligence).where(Intelligence.conversation_id == conversation_id)
                )
                await session.execute(
                    delete(Conversation).where(Conversation.id == conversation_id)
                )
                await session.commit()
                return True
            except Exception:
                log.warning(
                    "ORM delete of conversation %s failed; retrying with raw SQL",
                    conversation_id,
                    exc_info=True,
                )
                # A failed statement leaves the transaction unusable; it must
                # be rolled back before the same session can run the retry.
                await self._rollback_quietly(session)

            params = {"id": conversation_id}
            try:
                await session.execute(text("DELETE FROM messages WHERE conversation_id = :id"), params)
                await session.execute(text("DELETE FROM intelligence WHERE conversation_id = :id"), params)
                await session.execute(text("DELETE FROM conversations WHERE id = :id"), params)
                await session.commit()
                return True
            except Exception:
                # `except Exception` (not bare) so asyncio cancellation still propagates.
                log.exception("Failed to delete conversation %s", conversation_id)
                await self._rollback_quietly(session)
                return False
|
|
|
|
| |
# Module-level instance created at import time; exported below.
db_memory_store: DatabaseMemoryStore = DatabaseMemoryStore()

# Public names re-exported by `from <module> import *`.
__all__ = [
    "DatabaseMemoryStore",
    "db_memory_store",
]
|
|