"""Scenario test for GroqClient failover behaviour.

Three levels are exercised against a mocked ``post`` on the shared client in
``app.core.llm_client`` (so no network access or real API keys are needed):

1. Key rotation after a 429 whose body mentions RPM.
2. Model fallback after a 429 whose body mentions the daily limit (RPD).
3. Two ``generate`` calls launched together via ``asyncio.gather``.
"""
import asyncio
import json
import time
import sys
import os
from unittest.mock import MagicMock, AsyncMock, patch

# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")

from app.core.llm_client import GroqClient
from app.core.model_registry import model_registry, Capability


async def test_deep_failover():
    """Run the three failover levels and fail loudly if any check fails.

    Progress is printed level by level. The closing "ALL TESTS PASSED"
    banner is printed only when every level passed.

    Raises:
        AssertionError: if one or more levels reported a failure.
    """
    # NOTE(review): the glyph sequences in the banners below look like
    # mis-encoded emoji; they are kept as found -- confirm the intended
    # characters before repairing them.
    print("\nšŸš€ [SCENARIO] TESTING DEEP FAILOVER & MODEL SWITCHING")
    print("=" * 60)

    # 1. Setup Client with multiple Mock Keys
    client = GroqClient()
    client.api_keys = ["mock_key_1", "mock_key_2", "mock_key_3"]

    # Primary Model: openai/gpt-oss-20b (Strict)
    primary_model = "openai/gpt-oss-20b"

    # Labels of the levels whose check failed; inspected at the very end.
    failures = []

    def reset_client_state():
        """Point the client back at the first key and clear every cooldown."""
        client.current_key_idx = 0
        client.api_key = "mock_key_1"
        client.model_cooldowns = {}
        client.key_cooldowns = {k: 0.0 for k in client.api_keys}

    def report(passed, label):
        """Print the verdict line for one level and remember failures."""
        if passed:
            print(f" āœ… {label} VERIFIED")
        else:
            print(f" āŒ {label} FAILED")
            failures.append(label)

    def create_success_response():
        """Build a mocked 200 response carrying a small chat completion."""
        resp = MagicMock()
        resp.status_code = 200
        resp.json.return_value = {
            "choices": [
                {
                    "message": {
                        "content": '{"is_scam": true, "confidence": 0.9}',
                        "role": "assistant",
                    }
                }
            ],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5},
        }
        resp.headers = {
            "x-ratelimit-remaining-tokens": "5000",
            "x-ratelimit-remaining-requests": "100",
        }
        return resp

    def create_429_response(text):
        """Build a mocked 429 response with the given body text."""
        resp = MagicMock()
        resp.status_code = 429
        resp.text = text
        # Tiny retry-after so any wait derived from it stays negligible.
        resp.headers = {"retry-after": "0.01"}
        return resp

    def create_429_rpm_response():
        """Mock for 429 (Transient - Key Rotation)."""
        return create_429_response("Rate limit reached (RPM)")

    def create_429_rpd_response():
        """Mock for 429 (Daily Limit - Model Fallback)."""
        return create_429_response("Requests per day exceeded (RPD)")

    reset_client_state()

    # Mock the shared httpx client inside llm_client
    with patch(
        "app.core.llm_client._shared_client.post", new_callable=AsyncMock
    ) as mock_post:
        # --- TEST 1: KEY ROTATION (RPM LIMIT) ---
        print("\n[LEVEL 1] Testing Key Rotation (RPM)...")

        # Fail on Key 1, Succeed on Key 2
        mock_post.side_effect = [
            create_429_rpm_response(),
            create_success_response(),
        ]

        await client.generate("test prompt", model=primary_model)

        print(" -> Initial Key: mock_key_1")
        print(f" -> Current Key Index: {client.current_key_idx} (Expected: 1)")
        print(f" -> Current Key ID: {client.api_key}")
        report(client.current_key_idx == 1, "KEY ROTATION")

        # --- TEST 2: MODEL FALLBACK (DAILY LIMIT) ---
        print("\n[LEVEL 2] Testing Model Fallback (Daily Limit)...")
        reset_client_state()

        # Fail with Daily Limit on gpt-oss-20b, Succeed on Llama Backup
        mock_post.side_effect = [
            create_429_rpd_response(),
            create_success_response(),
        ]

        await client.generate(
            "test prompt", model=primary_model, role="STRUCTURED_OUTPUT"
        )

        # Check second call's model: the most recent post() is the retry.
        calls = mock_post.call_args_list
        second_call_json = calls[-1].kwargs['json']
        fallback_model = second_call_json['model']

        print(f" -> Primary Model hit Daily Limit: {primary_model}")
        print(f" -> System Fallback to: {fallback_model}")
        report(fallback_model != primary_model, "MODEL FALLBACK")

        # --- TEST 3: PARALLEL AGENT RESILIENCE ---
        print("\n[LEVEL 3] Testing Parallel Agent Resilience...")
        # Full reset (key, key cooldowns AND model cooldowns) so this level
        # does not start from whatever state Level 2 left behind.
        reset_client_state()

        # Simulate Orchestrator running two agents
        # Agent A: Instant Success
        # Agent B: Fails 429 then Success on Fallback
        mock_post.side_effect = [
            create_success_response(),   # Agent A
            create_429_rpd_response(),   # Agent B (fail)
            create_success_response(),   # Agent B (fallback success)
        ]

        print(" -> Launching Detection & Extraction in parallel...")
        results = await asyncio.gather(
            client.generate("Trigger Detection", model=primary_model),
            client.generate("Trigger Extraction", model=primary_model),
        )

        print(f" -> Received {len(results)} responses from parallel agents.")
        # gather() raises if either call raises, so reaching this point with
        # two results means both agents returned.
        report(len(results) == 2, "PARALLEL FAILOVER")

    print("\n" + "=" * 60)
    if failures:
        # Do not print the success banner when any level failed.
        raise AssertionError(
            "Deep failover checks failed: " + ", ".join(failures)
        )
    print("šŸ† DEEP FAILOVER SUITE: ALL TESTS PASSED")


if __name__ == "__main__":
    asyncio.run(test_deep_failover())