File size: 4,157 Bytes
da0c238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d50f94e
ad16e06
 
d50f94e
 
 
 
 
 
 
 
 
 
 
da0c238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d50f94e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from __future__ import annotations

import json
import logging
import time
from typing import Any, Dict, List

logger = logging.getLogger(__name__)


class OpenAIClient:
    """Small wrapper around the ``openai`` package for single-prompt chat calls.

    Two SDK layouts are supported and chosen once at construction time:
    the v1 layout (``openai.OpenAI`` client objects) and the legacy layout
    (module-level ``openai.ChatCompletion``). Failed calls are retried with
    a pause between attempts.
    """

    # User indicated ~3 req/min; wait a full 60s after a rate limit to be safe.
    _RATE_LIMIT_WAIT_SECONDS = 60.0
    # First pause after any other failure; doubled after every failed attempt.
    _INITIAL_RETRY_DELAY_SECONDS = 1.0
    # Sampling temperature sent with every request.
    _TEMPERATURE = 0.2

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        """Import ``openai`` lazily and configure it with *api_key*.

        Args:
            api_key: Key handed to the SDK.
            model: Model name sent with every chat request.

        Raises:
            RuntimeError: If the ``openai`` package cannot be imported.
        """
        try:
            import openai  # type: ignore
        except Exception as exc:  # pragma: no cover - import guard
            raise RuntimeError(
                "openai package is required. Install with `pip install openai`."
            ) from exc

        self._raw = openai
        self.model = model

        # Handle both v1 and legacy SDK initializers.
        if hasattr(openai, "OpenAI"):
            self.client = openai.OpenAI(api_key=api_key)
            self._mode = "client"
        else:
            openai.api_key = api_key
            self.client = openai
            self._mode = "legacy"

    def _complete(self, messages: List[Dict[str, str]]) -> str:
        """Issue one chat completion request and return the reply text.

        Returns an empty string when the reply content is ``None`` or empty.
        """
        if self._mode == "client":
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self._TEMPERATURE,
            )
            return resp.choices[0].message.content or ""
        # Legacy fallback: module-level API returning a dict-like response.
        resp = self.client.ChatCompletion.create(
            model=self.model,
            messages=messages,
            temperature=self._TEMPERATURE,
        )
        return resp["choices"][0]["message"]["content"] or ""

    def chat(self, prompt: str, *, max_retries: int = 3) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Args:
            prompt: Text of the user message.
            max_retries: Total number of attempts. With a value below 1 no
                request is made and the call fails immediately.

        Returns:
            The reply content, or ``""`` when the reply has no content.

        Raises:
            RuntimeError: If every attempt failed; the last underlying
                exception is attached as ``__cause__``.
        """
        messages = [{"role": "user", "content": prompt}]
        delay = self._INITIAL_RETRY_DELAY_SECONDS
        last_error: Exception | None = None
        for attempt in range(max_retries):
            try:
                return self._complete(messages)
            except Exception as exc:  # pragma: no cover - network call
                last_error = exc
                if attempt + 1 >= max_retries:
                    # No attempt follows, so sleeping here would only delay
                    # the failure (by up to a minute on a rate limit).
                    logger.warning(
                        "OpenAI call failed (attempt %s): %s", attempt + 1, exc
                    )
                    break
                if _is_rate_limit_error(exc):
                    wait_time = self._RATE_LIMIT_WAIT_SECONDS
                    logger.warning(
                        "OpenAI rate limit encountered (attempt %s). Waiting %.1fs",
                        attempt + 1,
                        wait_time,
                    )
                    time.sleep(wait_time)
                else:
                    logger.warning(
                        "OpenAI call failed (attempt %s): %s", attempt + 1, exc
                    )
                    time.sleep(delay)
                # The backoff grows after every failure, rate-limited or not.
                delay *= 2
        raise RuntimeError(
            f"OpenAI call failed after retries: {last_error}"
        ) from last_error  # pragma: no cover - network call

    def chat_json(self, prompt: str, *, max_retries: int = 3) -> Dict[str, Any]:
        """Send *prompt* and parse the reply as JSON.

        If the first reply cannot be parsed, the model is asked once to
        repair it and the repaired reply is parsed instead.

        Args:
            prompt: Text of the user message.
            max_retries: Attempt budget passed to each underlying ``chat`` call.

        Returns:
            The parsed JSON value.

        Raises:
            ValueError: If the repaired reply still cannot be parsed.
            RuntimeError: Propagated from ``chat`` when the API call fails.
        """
        raw = self.chat(prompt, max_retries=max_retries)
        parsed = _safe_json_parse(raw)
        if parsed is not None:
            return parsed

        # Ask model to repair the JSON if parsing failed.
        repair_prompt = (
            "The previous response was invalid JSON. "
            "Return ONLY valid JSON that fixes it without adding new facts.\n"
            f"Original response:\n{raw}"
        )
        repaired_raw = self.chat(repair_prompt, max_retries=max_retries)
        repaired = _safe_json_parse(repaired_raw)
        if repaired is None:
            raise ValueError("Model did not return valid JSON after repair attempt")
        return repaired


def _safe_json_parse(text: str) -> Dict[str, Any] | None:
    # Attempt direct parse
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Try to extract JSON substring if wrapped in text.
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        snippet = text[start : end + 1]
        try:
            return json.loads(snippet)
        except json.JSONDecodeError:
            return None
    return None


def _is_rate_limit_error(exc: Exception) -> bool:
    # Works with both new and legacy SDK exceptions.
    msg = str(exc).lower()
    if "rate limit" in msg or "rate_limit" in msg:
        return True
    if hasattr(exc, "status_code") and getattr(exc, "status_code") == 429:
        return True
    return False