# sentinel-scam-honeypo / tests/test_budget_enforcement.py
# Author: avinash-rai
# Commit 1838600: Deployment Ready: Fixed scam detection low confidence, added production audit report, optimized throttles
# tests/test_budget_enforcement.py
"""
Production Hardening: Budget Enforcement Tests
Tests to verify LLM call budget limits are enforced at turn and session levels, plus the persona-lock flag on TurnContext.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from dataclasses import dataclass, field
from typing import Dict, Optional
# Import the actual context
from app.core.context import TurnContext
# Local stand-in BudgetExceeded exception (not referenced by the tests below)
class BudgetExceeded(Exception):
    """Local stand-in ``BudgetExceeded`` exception (a plain ``Exception`` subclass).

    Not raised or referenced by the tests in this module.
    """
class TestTurnBudgetEnforcement:
    """Tests for turn-level budget enforcement (4 calls max per turn).

    These tests exercise the budget-tracking fields on ``TurnContext``
    directly; they do not invoke any LLM client.
    """

    # Per-turn LLM call limit assumed by these tests (see class docstring).
    MAX_PER_TURN = 4

    def test_turn_context_has_budget_fields(self):
        """Verify TurnContext has required budget tracking fields."""
        ctx = TurnContext(session_id="test", message="test")
        assert hasattr(ctx, "llm_call_count"), "Missing llm_call_count field"
        assert hasattr(ctx, "budget_exceeded"), "Missing budget_exceeded field"
        assert hasattr(ctx, "session"), "Missing session field"
        assert ctx.llm_call_count == 0, "llm_call_count should start at 0"
        # Identity check: the flag must be the bool False, not merely falsy.
        assert ctx.budget_exceeded is False, "budget_exceeded should start False"

    def test_turn_budget_increments(self):
        """Verify the counter increments by one per call up to the turn limit."""
        ctx = TurnContext(session_id="test", message="test")
        for expected in range(1, self.MAX_PER_TURN + 1):
            ctx.llm_call_count += 1
            assert ctx.llm_call_count == expected

    def test_turn_budget_exceeded_flag_set(self):
        """Verify budget_exceeded flag can be set."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.llm_call_count = self.MAX_PER_TURN
        ctx.budget_exceeded = True
        # NOTE(review): this only checks the attribute is writable; it does not
        # exercise the production code path that would set the flag.
        assert ctx.budget_exceeded is True
class TestSessionBudgetEnforcement:
    """Tests for session-level budget enforcement (30 calls max per session).

    Session-wide usage is tracked in ``ctx.session["session_llm_calls"]``.
    """

    # Session-wide LLM call limit assumed by these tests (see class docstring).
    MAX_PER_SESSION = 30

    def test_session_field_exists(self):
        """Verify TurnContext has session field for session budget tracking."""
        ctx = TurnContext(session_id="test", message="test")
        assert hasattr(ctx, "session"), "Missing session field"
        assert isinstance(ctx.session, dict), "session should be a dict"

    def test_session_budget_tracking(self):
        """Verify session_llm_calls accumulates across turns in the session dict."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.session = {"session_llm_calls": 0}
        turns, calls_per_turn = 5, 4
        # Simulate 5 turns of 4 calls each; the turn index itself is unused.
        for _ in range(turns):
            ctx.session["session_llm_calls"] += calls_per_turn
        assert ctx.session["session_llm_calls"] == turns * calls_per_turn

    def test_session_budget_limit(self):
        """Verify the session check allows calls below 30 and blocks at 30."""
        ctx = TurnContext(session_id="test", message="test")
        # One below the limit: a new call is still permitted.
        ctx.session = {"session_llm_calls": self.MAX_PER_SESSION - 1}
        assert ctx.session["session_llm_calls"] < self.MAX_PER_SESSION
        # At the limit: any new call should be blocked.
        ctx.session = {"session_llm_calls": self.MAX_PER_SESSION}
        # NOTE(review): this re-states the `>=` predicate rather than calling
        # the production budget check.
        assert ctx.session["session_llm_calls"] >= self.MAX_PER_SESSION
class TestBudgetEnforcementIntegration:
    """Budget-check tests modelled on the enforcement presumably done in LLMClient.

    NOTE(review): despite the class name, these tests re-implement the budget
    check inline and never call LLMClient. Wire them to the real client to get
    genuine integration coverage.

    The tests are plain synchronous functions. The originals were ``async``
    with ``@pytest.mark.asyncio`` but awaited nothing, so they needlessly
    depended on the pytest-asyncio plugin to run at all.
    """

    # Limits assumed by these tests (4 calls per turn, 30 per session).
    MAX_PER_TURN = 4
    MAX_PER_SESSION = 30

    def test_turn_budget_blocks_at_limit(self):
        """Verify a context already at the turn limit gets flagged as over budget."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.llm_call_count = self.MAX_PER_TURN  # Already at limit
        # Simulate the budget check logic
        if ctx.llm_call_count >= self.MAX_PER_TURN:
            ctx.budget_exceeded = True
        assert ctx.budget_exceeded is True
        assert ctx.llm_call_count >= self.MAX_PER_TURN

    def test_session_budget_blocks_at_limit(self):
        """Verify a context already at the session limit gets flagged as over budget."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.session = {"session_llm_calls": self.MAX_PER_SESSION}  # At limit
        # Simulate the session budget check logic
        session_calls = ctx.session.get("session_llm_calls", 0)
        if session_calls >= self.MAX_PER_SESSION:
            ctx.budget_exceeded = True
        assert ctx.budget_exceeded is True

    def test_budget_allows_calls_under_limit(self):
        """Verify a fresh context under both limits is not flagged."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.session = {"session_llm_calls": 10}  # Under limit
        # Under turn limit
        assert ctx.llm_call_count < self.MAX_PER_TURN
        # Under session limit
        assert ctx.session["session_llm_calls"] < self.MAX_PER_SESSION
        # Budget should not be exceeded
        assert ctx.budget_exceeded is False
class TestPersonaStability:
    """Tests for the persona consistency lock flag on TurnContext."""

    def test_persona_locked_flag_exists(self):
        """Verify TurnContext has a persona_locked field that starts False."""
        ctx = TurnContext(session_id="test", message="test")
        assert hasattr(ctx, "persona_locked"), "Missing persona_locked field"
        # Identity check: the flag must be the bool False, not merely falsy.
        assert ctx.persona_locked is False, "persona_locked should start False"

    def test_persona_lock_prevents_reselection(self):
        """Verify persona cannot be reselected after locking."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.persona_locked = True
        # Simulate persona selection logic: select only while unlocked.
        should_select_persona = not ctx.persona_locked
        assert should_select_persona is False
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly
    # exits non-zero when any test fails (pytest.main returns the code
    # instead of exiting).
    raise SystemExit(pytest.main([__file__, "-v"]))