File size: 5,698 Bytes
1838600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# tests/test_budget_enforcement.py
"""
Production Hardening: Budget Enforcement Tests
Tests to verify LLM call budget limits are enforced at turn and session levels.
"""

import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from dataclasses import dataclass, field
from typing import Dict, Optional

# Import the actual context
from app.core.context import TurnContext

# Mock BudgetExceeded exception
class BudgetExceeded(Exception):
    """Test-local stand-in for a budget-exceeded exception."""


class TestTurnBudgetEnforcement:
    """Tests for turn-level budget enforcement (4 calls max per turn)."""

    # Per-turn LLM call limit stated in the class docstring.
    MAX_PER_TURN = 4

    def test_turn_context_has_budget_fields(self):
        """Verify TurnContext has required budget tracking fields."""
        ctx = TurnContext(session_id="test", message="test")

        assert hasattr(ctx, "llm_call_count"), "Missing llm_call_count field"
        assert hasattr(ctx, "budget_exceeded"), "Missing budget_exceeded field"
        assert hasattr(ctx, "session"), "Missing session field"

        assert ctx.llm_call_count == 0, "llm_call_count should start at 0"
        # `is False` rather than `== False`: a non-bool initial value such as 0
        # would otherwise pass this check silently.
        assert ctx.budget_exceeded is False, "budget_exceeded should start False"

    def test_turn_budget_increments(self):
        """Verify the per-turn counter increments by one up to the turn limit."""
        ctx = TurnContext(session_id="test", message="test")

        for expected in range(1, self.MAX_PER_TURN + 1):
            ctx.llm_call_count += 1
            assert ctx.llm_call_count == expected

    def test_turn_budget_exceeded_flag_set(self):
        """Verify budget_exceeded flag can be set once the turn limit is reached."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.llm_call_count = self.MAX_PER_TURN
        # NOTE(review): the flag is set by the test itself, so this only checks
        # that the attribute is writable, not that enforcement code sets it.
        ctx.budget_exceeded = True

        assert ctx.budget_exceeded is True


class TestSessionBudgetEnforcement:
    """Tests for session-level budget enforcement (30 calls max per session)."""

    # Per-session LLM call limit stated in the class docstring.
    MAX_PER_SESSION = 30

    def test_session_field_exists(self):
        """Verify TurnContext has session field for session budget tracking."""
        ctx = TurnContext(session_id="test", message="test")
        assert hasattr(ctx, "session"), "Missing session field"
        assert isinstance(ctx.session, dict), "session should be a dict"

    def test_session_budget_tracking(self):
        """Verify session_llm_calls accumulates across turns in the session dict."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.session = {"session_llm_calls": 0}

        turns = 5
        calls_per_turn = 4
        # Simulate several turns, each adding a fixed number of calls.
        for _ in range(turns):
            ctx.session["session_llm_calls"] += calls_per_turn

        assert ctx.session["session_llm_calls"] == turns * calls_per_turn

    def test_session_budget_limit(self):
        """Verify a session counter at the limit is recognised as exhausted."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.session = {"session_llm_calls": self.MAX_PER_SESSION}

        # At the limit, any new call should be blocked.
        # NOTE(review): this only re-reads the value assigned above; it does not
        # exercise the code that actually blocks calls at the session limit.
        assert ctx.session["session_llm_calls"] >= self.MAX_PER_SESSION


class TestBudgetEnforcementIntegration:
    """Budget-check tests for the turn and session limits.

    NOTE(review): these tests simulate the limit checks inline rather than
    calling LLMClient, so they will not catch regressions in the client itself.
    """

    # Limits shared by every test in this class.
    MAX_PER_TURN = 4
    MAX_PER_SESSION = 30

    # The tests below await nothing, so they are plain synchronous functions;
    # this avoids depending on the pytest-asyncio plugin being installed.

    def test_turn_budget_blocks_at_limit(self):
        """Verify the simulated turn check trips when the turn limit is reached."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.llm_call_count = self.MAX_PER_TURN  # Already at limit

        # Simulate the budget check logic
        if ctx.llm_call_count >= self.MAX_PER_TURN:
            ctx.budget_exceeded = True

        assert ctx.budget_exceeded is True
        assert ctx.llm_call_count >= self.MAX_PER_TURN

    def test_session_budget_blocks_at_limit(self):
        """Verify the simulated session check trips when the session limit is reached."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.session = {"session_llm_calls": self.MAX_PER_SESSION}  # At limit

        # Simulate the session budget check logic
        session_calls = ctx.session.get("session_llm_calls", 0)
        if session_calls >= self.MAX_PER_SESSION:
            ctx.budget_exceeded = True

        assert ctx.budget_exceeded is True

    def test_budget_allows_calls_under_limit(self):
        """Verify calls are allowed when under both turn and session budgets."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.session = {"session_llm_calls": 10}  # Under limit

        # Under turn limit
        assert ctx.llm_call_count < self.MAX_PER_TURN

        # Under session limit
        assert ctx.session["session_llm_calls"] < self.MAX_PER_SESSION

        # Budget should not be exceeded
        assert ctx.budget_exceeded is False


class TestPersonaStability:
    """Tests for persona consistency lock."""

    def test_persona_locked_flag_exists(self):
        """Verify TurnContext has persona_locked field, initially False."""
        ctx = TurnContext(session_id="test", message="test")
        assert hasattr(ctx, "persona_locked"), "Missing persona_locked field"
        # `is False` so a non-bool falsy initial value (e.g. 0) is rejected.
        assert ctx.persona_locked is False, "persona_locked should start False"

    def test_persona_lock_prevents_reselection(self):
        """Verify persona cannot be reselected after locking."""
        ctx = TurnContext(session_id="test", message="test")
        ctx.persona_locked = True

        # Simulate persona selection logic
        should_select_persona = not ctx.persona_locked

        assert should_select_persona is False


if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when any test fails (the return value was previously discarded).
    raise SystemExit(pytest.main([__file__, "-v"]))