# Last commit: 6af17ac by avinash-rai
# fix: Remove 'scammer' word leak and improve human-likeness
# ═══════════════════════════════════════════════════════════════════════════════
# File: app/database/memory_db.py
# Description: Database-backed conversation memory store
# ═══════════════════════════════════════════════════════════════════════════════
"""
Database-backed conversation memory that persists across restarts.
Provides the same interface as the in-memory store, so it can be used as a drop-in replacement.
"""
import uuid
from typing import Dict, List, Optional, Any
from datetime import datetime
from sqlalchemy import select, func
from sqlalchemy.orm import selectinload
from app.database.db import get_db_manager
from app.database.models import Conversation, Message, Intelligence
class DatabaseMemoryStore:
    """
    Persistent conversation storage using SQLAlchemy.

    Compatible with the original ConversationMemory interface but stores data
    in SQLite/PostgreSQL. Conversations are mirrored into an in-process dict
    cache (``self._cache``) keyed by conversation id; once an id is cached,
    reads are served from the cache without querying the database.

    NOTE(review): only ``clear()`` commits explicitly; the other methods
    assume ``get_db_manager().session()`` commits on a clean exit — TODO
    confirm against app.database.db.
    """

    # Max history records kept in a cached conversation dict. Each record is
    # one exchange (one caller message plus the matching honeypot response).
    _HISTORY_LIMIT = 20
    # Max entries kept in aggregated_intelligence["reasoning_history"].
    _REASONING_LIMIT = 5

    def __init__(self):
        # Hot cache for performance: conversation_id -> conversation dict.
        self._cache: Dict[str, Dict] = {}

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _new_conversation_dict(conversation_id: str, sender_id: Optional[str]) -> Dict:
        """Build the cache dict for a freshly created conversation."""
        now = datetime.utcnow().isoformat()
        return {
            "id": conversation_id,
            "sender_id": sender_id,
            "created_at": now,
            "updated_at": now,
            "message_count": 0,
            "phase": "hook",
            "trust_score": 0.0,
            "scam_type": None,
            "persona": None,
            "history": [],
            "aggregated_intelligence": {
                "phone_numbers": [],
                "upi_ids": [],
                "bank_accounts": [],
                "ifsc_codes": [],
                "emails": [],
                "urls": [],
                "credit_cards": [],
                "otps": [],
                "rat_apps": [],
                "pan_cards": [],
                "aadhar_numbers": [],
            },
            "threat_intel": None,
            "risk_score": 0.0,
        }

    @classmethod
    def _prune(cls, conv_dict: Dict) -> None:
        """Trim cached history and reasoning trace to their caps, in place."""
        history = conv_dict.get("history")
        if isinstance(history, list) and len(history) > cls._HISTORY_LIMIT:
            conv_dict["history"] = history[-cls._HISTORY_LIMIT:]

        aggregated = conv_dict.get("aggregated_intelligence")
        if isinstance(aggregated, dict):
            reasoning = aggregated.get("reasoning_history")
            if isinstance(reasoning, list) and len(reasoning) > cls._REASONING_LIMIT:
                aggregated["reasoning_history"] = reasoning[-cls._REASONING_LIMIT:]

    @staticmethod
    def _merge_aggregated(conv_dict: Dict, intelligence: Dict[str, Any], wrap_scalars: bool) -> None:
        """Merge new entities into conv_dict["aggregated_intelligence"], in place.

        Every key in ``intelligence`` ends up mapped to a list: an existing
        non-list value ``v`` becomes ``[v]`` (or ``[]`` if falsy). New items are
        appended only when not already present, keeping first-seen order.
        A non-list incoming value is added as a single item when
        ``wrap_scalars`` is true and ignored otherwise.
        """
        aggregated = conv_dict.get("aggregated_intelligence")
        if not isinstance(aggregated, dict):
            aggregated = {}
            conv_dict["aggregated_intelligence"] = aggregated

        for key, values in intelligence.items():
            current = aggregated.get(key)
            if not isinstance(current, list):
                current = [current] if current else []
                aggregated[key] = current

            if isinstance(values, list):
                new_items = values
            elif wrap_scalars:
                new_items = [values]
            else:
                continue

            # List membership (not a set) because items may be unhashable.
            for item in new_items:
                if item not in current:
                    current.append(item)

    @staticmethod
    async def _persist_intelligence(session, conversation_id: str, intelligence: Dict[str, Any]) -> None:
        """Add Intelligence rows for entities not yet stored for this conversation.

        Only non-empty list values are persisted, each as ``str(value)``.
        Duplicates within one call and values already present in the table are
        skipped; existing values are looked up with one query per entity type.
        """
        for entity_type, values in intelligence.items():
            if not values or not isinstance(values, list):
                continue

            # dict.fromkeys de-duplicates while preserving first-seen order.
            candidates = list(dict.fromkeys(str(value) for value in values))
            result = await session.execute(
                select(Intelligence.entity_value).where(
                    Intelligence.conversation_id == conversation_id,
                    Intelligence.entity_type == entity_type,
                    Intelligence.entity_value.in_(candidates),
                )
            )
            already_stored = set(result.scalars().all())

            for value in candidates:
                if value not in already_stored:
                    session.add(
                        Intelligence(
                            conversation_id=conversation_id,
                            entity_type=entity_type,
                            entity_value=value,
                        )
                    )

    # --------------------------------------------------------------- public API

    async def get_or_create(
        self,
        conversation_id: Optional[str] = None,
        sender_id: Optional[str] = None
    ) -> Dict:
        """Get an existing conversation or create a new one.

        Args:
            conversation_id: Id to look up. When falsy, a new id of the form
                ``conv_<12 hex chars>`` is generated.
            sender_id: Stored on the row only when a new conversation is created.

        Returns:
            The cached conversation dict. Loaded conversations come from
            ``Conversation.to_dict()`` with history capped at ``_HISTORY_LIMIT``.
        """
        if not conversation_id:
            conversation_id = f"conv_{uuid.uuid4().hex[:12]}"

        if conversation_id in self._cache:
            return self._cache[conversation_id]

        db = get_db_manager()
        async with db.session() as session:
            result = await session.execute(
                select(Conversation)
                .where(Conversation.id == conversation_id)
                .options(
                    selectinload(Conversation.messages),
                    selectinload(Conversation.intelligence_items)
                )
            )
            conv = result.scalar_one_or_none()

            if conv is not None:
                # to_dict() must run while the session is open.
                conv_dict = conv.to_dict()
            else:
                session.add(
                    Conversation(id=conversation_id, sender_id=sender_id, phase="hook")
                )
                await session.flush()
                conv_dict = self._new_conversation_dict(conversation_id, sender_id)

        self._prune(conv_dict)
        self._cache[conversation_id] = conv_dict
        return conv_dict

    async def get(self, conversation_id: str) -> Optional[Dict]:
        """Get a conversation by id, creating it if it does not exist yet."""
        if conversation_id in self._cache:
            return self._cache[conversation_id]
        return await self.get_or_create(conversation_id)

    async def update(
        self,
        conversation_id: str,
        scammer_message: str,
        honeypot_response: str,
        intelligence: Dict,
        phase: str,
        scam_type: Optional[str] = None,
        persona: Optional[str] = None,
        risk_score: float = 0.0,
        trust_score: float = 0.0
    ) -> Dict:
        """Record one message exchange in the database and the cache.

        Increments ``message_count`` and stores a Message row whose ``turn`` is
        the new count. It persists list-valued intelligence (as strings) and
        merges it into the cached aggregated intelligence. ``scam_type`` and
        ``persona`` are only overwritten when truthy.

        Returns:
            The cached conversation dict. It is returned unchanged when no DB
            row exists for ``conversation_id``.
        """
        intelligence = intelligence or {}
        conv_dict = await self.get_or_create(conversation_id)
        now = datetime.utcnow()

        db = get_db_manager()
        async with db.session() as session:
            result = await session.execute(
                select(Conversation).where(Conversation.id == conversation_id)
            )
            conv = result.scalar_one_or_none()
            if conv is None:
                return conv_dict

            conv.message_count += 1
            # Capture now: ORM attributes may be expired after the session ends.
            turn = conv.message_count
            conv.phase = phase
            conv.updated_at = now
            if scam_type:
                conv.scam_type = scam_type
            if persona:
                conv.persona = persona
            conv.risk_score = risk_score
            conv.trust_score = trust_score

            session.add(
                Message(
                    conversation_id=conversation_id,
                    turn=turn,
                    scammer_message=scammer_message,
                    honeypot_response=honeypot_response,
                    phase=phase
                )
            )
            await self._persist_intelligence(session, conversation_id, intelligence)
            await session.flush()

        # Mirror the change into the cache only after the DB work succeeded.
        timestamp = now.isoformat()
        conv_dict["message_count"] = turn
        conv_dict["phase"] = phase
        conv_dict["updated_at"] = timestamp
        if scam_type:
            conv_dict["scam_type"] = scam_type
        if persona:
            conv_dict["persona"] = persona
        conv_dict["risk_score"] = risk_score
        conv_dict["trust_score"] = trust_score
        conv_dict.setdefault("history", []).append({
            "turn": turn,
            "timestamp": timestamp,
            "scammer_message": scammer_message,
            "honeypot_response": honeypot_response,
            "phase": phase,
            "intelligence": intelligence
        })

        self._merge_aggregated(conv_dict, intelligence, wrap_scalars=False)
        # [RISK 5] Bound cached history and reasoning trace.
        self._prune(conv_dict)
        self._cache[conversation_id] = conv_dict
        return conv_dict

    async def update_intelligence(self, conversation_id: str, intelligence: Dict[str, Any]) -> Dict:
        """Explicitly add intelligence fields (e.g., keywords) to a conversation.

        List values are persisted to the DB (as strings). In the cache, both list
        and single (non-list) values are merged into aggregated intelligence.
        """
        intelligence = intelligence or {}
        conv_dict = await self.get_or_create(conversation_id)

        db = get_db_manager()
        async with db.session() as session:
            await self._persist_intelligence(session, conversation_id, intelligence)
            await session.flush()

        self._merge_aggregated(conv_dict, intelligence, wrap_scalars=True)
        # [RISK 5] Bound cached history and reasoning trace.
        self._prune(conv_dict)
        self._cache[conversation_id] = conv_dict
        return conv_dict

    async def get_statistics(self) -> Dict[str, Any]:
        """Return global counts from the database plus the local cache size.

        ``scams_detected`` counts conversations whose ``scam_type`` is not NULL.
        ``active_conversations`` is the number of conversations cached in this
        process, not a database figure. ``scam_distribution`` is not computed
        here and is always an empty dict.
        """
        db = get_db_manager()
        async with db.session() as session:

            async def _count(statement) -> int:
                return (await session.execute(statement)).scalar() or 0

            total_conversations = await _count(select(func.count(Conversation.id)))
            total_messages = await _count(select(func.count(Message.id)))
            scams_detected = await _count(
                select(func.count(Conversation.id))
                .where(Conversation.scam_type.isnot(None))
            )
            intelligence_extracted = await _count(select(func.count(Intelligence.id)))

        return {
            "total_conversations": total_conversations,
            "total_messages": total_messages,
            "scams_detected": scams_detected,
            "intelligence_extracted": intelligence_extracted,
            "active_conversations": len(self._cache),
            "scam_distribution": {}
        }

    def get_history_text(self, conversation_id: str, max_turns: int = 10) -> str:
        """Format the last ``max_turns`` cached exchanges as text.

        Each exchange becomes two lines: ``Caller: <scammer_message>`` and
        ``Me: <honeypot_response>``. Only the in-memory cache is consulted (this
        method is synchronous and does not query the database), so an uncached
        conversation yields "". A non-positive ``max_turns`` also yields "".
        """
        conv = self._cache.get(conversation_id)
        # Guard max_turns <= 0 explicitly: history[-0:] would return everything.
        if not conv or max_turns <= 0:
            return ""

        lines = []
        for record in (conv.get("history") or [])[-max_turns:]:
            lines.append(f"Caller: {record.get('scammer_message', '')}")
            lines.append(f"Me: {record.get('honeypot_response', '')}")
        return "\n".join(lines)

    async def clear(self, conversation_id: str) -> bool:
        """Remove a conversation from the cache and delete its database rows.

        Deletes Message and Intelligence rows, then the Conversation row, and
        commits. If the ORM deletes fail, the transaction is rolled back and the
        same deletes are retried with textual SQL.

        Returns:
            True when a delete path committed, False when both failed.
        """
        from sqlalchemy import delete, text

        self._cache.pop(conversation_id, None)

        db = get_db_manager()
        async with db.session() as session:
            try:
                # Dependent rows first, then the conversation itself
                # (presumably required by foreign keys — not visible here).
                await session.execute(
                    delete(Message).where(Message.conversation_id == conversation_id)
                )
                await session.execute(
                    delete(Intelligence).where(Intelligence.conversation_id == conversation_id)
                )
                await session.execute(
                    delete(Conversation).where(Conversation.id == conversation_id)
                )
                await session.commit()
                return True
            except Exception:
                # Fallback to textual SQL. A failed statement leaves the
                # transaction unusable, so it must be rolled back first.
                # `Exception` (not a bare except) lets CancelledError propagate.
                try:
                    await session.rollback()
                    params = {"id": conversation_id}
                    await session.execute(
                        text("DELETE FROM messages WHERE conversation_id = :id"), params
                    )
                    await session.execute(
                        text("DELETE FROM intelligence WHERE conversation_id = :id"), params
                    )
                    await session.execute(
                        text("DELETE FROM conversations WHERE id = :id"), params
                    )
                    await session.commit()
                    return True
                except Exception:
                    return False
# Module-level singleton: every importer of this module shares this one
# store (and therefore its in-process conversation cache).
db_memory_store: DatabaseMemoryStore = DatabaseMemoryStore()

# Public names exported by `from app.database.memory_db import *`.
__all__ = [
    "DatabaseMemoryStore",
    "db_memory_store",
]