# tests/test_budget_enforcement.py
"""
Production Hardening: Budget Enforcement Tests

Tests to verify LLM call budget limits are enforced at turn and session levels.

NOTE(review): the "integration" tests below do not instantiate or call
LLMClient; they run ``_apply_budget_check``, a local restatement of the
budget rule. They pin the expected rule, not the client's implementation.
TODO: exercise LLMClient directly once its budget-check entry point is
reachable from tests.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from dataclasses import dataclass, field
from typing import Dict, Optional

# Import the actual context
from app.core.context import TurnContext

# Budget limits exercised by these tests: 4 LLM calls per turn, 30 per session.
# NOTE(review): presumably mirrors the limits enforced by LLMClient -- if the
# app exposes them as constants, import them here instead of duplicating.
MAX_LLM_CALLS_PER_TURN = 4
MAX_LLM_CALLS_PER_SESSION = 30


# Mock BudgetExceeded exception (not currently raised or caught by any test
# in this module).
class BudgetExceeded(Exception):
    pass


def _apply_budget_check(
    ctx,
    max_per_turn: int = MAX_LLM_CALLS_PER_TURN,
    max_per_session: int = MAX_LLM_CALLS_PER_SESSION,
) -> bool:
    """Simulate the LLM budget gate on *ctx* and report whether it tripped.

    Sets ``ctx.budget_exceeded = True`` when either ``ctx.llm_call_count`` has
    reached *max_per_turn* or ``ctx.session["session_llm_calls"]`` (missing
    key treated as 0) has reached *max_per_session*. The flag is never reset
    to False here, so a previously tripped budget stays tripped.

    Returns:
        True if this call tripped the budget, False otherwise.
    """
    # `or {}` guards against a context whose session was never populated.
    session_calls = (ctx.session or {}).get("session_llm_calls", 0)
    if ctx.llm_call_count >= max_per_turn or session_calls >= max_per_session:
        ctx.budget_exceeded = True
        return True
    return False


class TestTurnBudgetEnforcement:
    """Tests for turn-level budget enforcement (4 calls max per turn)."""

    def test_turn_context_has_budget_fields(self):
        """Verify TurnContext has required budget tracking fields."""
        ctx = TurnContext(session_id="test", message="test")
        assert hasattr(ctx, "llm_call_count"), "Missing llm_call_count field"
        assert hasattr(ctx, "budget_exceeded"), "Missing budget_exceeded field"
        assert hasattr(ctx, "session"), "Missing session field"
        assert ctx.llm_call_count == 0, "llm_call_count should start at 0"
        assert ctx.budget_exceeded is False, "budget_exceeded should start False"

    def test_turn_budget_increments(self):
        """Verify counter increments correctly."""
        ctx = TurnContext(session_id="test", message="test")
        for expected in range(1, MAX_LLM_CALLS_PER_TURN + 1):
            ctx.llm_call_count += 1
            assert ctx.llm_call_count == expected

    def test_turn_budget_exceeded_flag_set(self):
        """Verify budget_exceeded is set once the turn limit is reached."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.llm_call_count = MAX_LLM_CALLS_PER_TURN
        assert _apply_budget_check(ctx) is True
        assert ctx.budget_exceeded is True


class TestSessionBudgetEnforcement:
    """Tests for session-level budget enforcement (30 calls max per session)."""

    def test_session_field_exists(self):
        """Verify TurnContext has session field for session budget tracking."""
        ctx = TurnContext(session_id="test", message="test")
        assert hasattr(ctx, "session"), "Missing session field"
        assert isinstance(ctx.session, dict), "session should be a dict"

    def test_session_budget_tracking(self):
        """Verify session_llm_calls is tracked in session dict."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.session = {"session_llm_calls": 0}
        turns = 5
        # Simulate 5 turns that each use the full per-turn budget.
        for _ in range(turns):
            ctx.session["session_llm_calls"] += MAX_LLM_CALLS_PER_TURN
        assert ctx.session["session_llm_calls"] == turns * MAX_LLM_CALLS_PER_TURN

    def test_session_budget_limit(self):
        """Verify session budget is enforced at 30 calls."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.session = {"session_llm_calls": MAX_LLM_CALLS_PER_SESSION}
        assert ctx.session["session_llm_calls"] >= MAX_LLM_CALLS_PER_SESSION
        # At the limit, any new call should be blocked.
        assert _apply_budget_check(ctx) is True
        assert ctx.budget_exceeded is True


class TestBudgetEnforcementIntegration:
    """Budget-gate tests using a local simulation of the LLMClient check.

    NOTE(review): these tests do not call LLMClient; see the module docstring.
    """

    @pytest.mark.asyncio
    async def test_turn_budget_blocks_at_limit(self):
        """Verify the simulated gate blocks calls once the turn budget is reached."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.llm_call_count = MAX_LLM_CALLS_PER_TURN  # Already at limit
        _apply_budget_check(ctx)
        assert ctx.budget_exceeded is True
        assert ctx.llm_call_count >= MAX_LLM_CALLS_PER_TURN

    @pytest.mark.asyncio
    async def test_session_budget_blocks_at_limit(self):
        """Verify the simulated gate blocks calls once the session budget is reached."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.session = {"session_llm_calls": MAX_LLM_CALLS_PER_SESSION}  # At limit
        _apply_budget_check(ctx)
        assert ctx.budget_exceeded is True

    @pytest.mark.asyncio
    async def test_budget_allows_calls_under_limit(self):
        """Verify the simulated gate allows calls when under both budgets."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.session = {"session_llm_calls": 10}  # Under limit
        # Under turn limit
        assert ctx.llm_call_count < MAX_LLM_CALLS_PER_TURN
        # Under session limit
        assert ctx.session["session_llm_calls"] < MAX_LLM_CALLS_PER_SESSION
        # Running the check must not trip the budget.
        assert _apply_budget_check(ctx) is False
        assert ctx.budget_exceeded is False


class TestPersonaStability:
    """Tests for persona consistency lock."""

    def test_persona_locked_flag_exists(self):
        """Verify TurnContext has persona_locked field."""
        ctx = TurnContext(session_id="test", message="test")
        assert hasattr(ctx, "persona_locked"), "Missing persona_locked field"
        assert ctx.persona_locked is False, "persona_locked should start False"

    def test_persona_lock_prevents_reselection(self):
        """Verify persona cannot be reselected after locking."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.persona_locked = True
        # Simulate persona selection logic
        should_select_persona = not ctx.persona_locked
        assert should_select_persona is False


if __name__ == "__main__":
    pytest.main([__file__, "-v"])