raj999's picture
updated
ad16e06
Raw
History Blame
4.16 kB
from __future__ import annotations
import json
import logging
import time
from typing import Any, Dict, List
logger = logging.getLogger(__name__)
class OpenAIClient:
    """Wrapper over the ``openai`` SDK for single-turn chat completions.

    Supports both the v1 client API (``openai.OpenAI``) and the legacy
    module-level API (``openai.ChatCompletion``), with simple retry handling.
    """

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        """Bind the client to *api_key* and *model*.

        Raises:
            RuntimeError: if the ``openai`` package cannot be imported.
        """
        try:
            import openai  # type: ignore
        except Exception as exc:  # pragma: no cover - import guard
            raise RuntimeError(
                "openai package is required. Install with `pip install openai`."
            ) from exc
        self._raw = openai
        self.model = model
        # Handle both v1 and legacy SDK initializers: v1 exposes an ``OpenAI``
        # client class, older releases configure a module-level ``api_key``.
        if hasattr(openai, "OpenAI"):
            self.client = openai.OpenAI(api_key=api_key)
            self._mode = "client"
        else:
            openai.api_key = api_key
            self.client = openai
            self._mode = "legacy"

    def _complete(self, messages: List[Dict[str, str]]) -> str:
        """Issue one chat-completion request; return the reply text ("" if empty)."""
        if self._mode == "client":
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.2,
            )
            return resp.choices[0].message.content or ""
        # Legacy fallback: the response is indexed like a dict.
        resp = self.client.ChatCompletion.create(
            model=self.model,
            messages=messages,
            temperature=0.2,
        )
        return resp["choices"][0]["message"]["content"] or ""

    def chat(self, prompt: str, *, max_retries: int = 3) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Args:
            prompt: User message content.
            max_retries: Total number of attempts. Values below 1 are treated
                as 1, so at least one request is always made.

        Returns:
            The reply text, or "" when the model returned no content.

        Raises:
            RuntimeError: if every attempt failed. The last underlying
                exception is chained as ``__cause__``.
        """
        messages = [{"role": "user", "content": prompt}]
        attempts = max(1, max_retries)
        delay = 1.0
        last_error: Exception | None = None
        for attempt in range(attempts):
            try:
                return self._complete(messages)
            except Exception as exc:  # pragma: no cover - network call
                last_error = exc
                if attempt + 1 >= attempts:
                    # Final attempt: no request follows, so waiting is pointless.
                    logger.warning(
                        "OpenAI call failed (attempt %s): %s", attempt + 1, exc
                    )
                    break
                if _is_rate_limit_error(exc):
                    # User indicated ~3 req/min; wait a full 60s to be safe.
                    wait_time = 60.0
                    logger.warning(
                        "OpenAI rate limit encountered (attempt %s). Waiting %.1fs",
                        attempt + 1,
                        wait_time,
                    )
                    time.sleep(wait_time)
                else:
                    logger.warning(
                        "OpenAI call failed (attempt %s): %s", attempt + 1, exc
                    )
                    # Exponential backoff for non-rate-limit failures.
                    time.sleep(delay)
                    delay *= 2
        raise RuntimeError(
            f"OpenAI call failed after retries: {last_error}"
        ) from last_error  # pragma: no cover - network call

    def chat_json(self, prompt: str, *, max_retries: int = 3) -> Dict[str, Any]:
        """Call :meth:`chat` and parse the reply with ``_safe_json_parse``.

        If the first reply cannot be parsed, the model is asked once to repair
        it. The repair call uses the same ``max_retries`` budget.

        Raises:
            ValueError: if neither the reply nor the repaired reply parses.
            RuntimeError: propagated from :meth:`chat`.
        """
        raw = self.chat(prompt, max_retries=max_retries)
        parsed = _safe_json_parse(raw)
        if parsed is not None:
            return parsed
        # Ask model to repair the JSON if parsing failed.
        repair_prompt = (
            "The previous response was invalid JSON. "
            "Return ONLY valid JSON that fixes it without adding new facts.\n"
            f"Original response:\n{raw}"
        )
        repaired_raw = self.chat(repair_prompt, max_retries=max_retries)
        repaired = _safe_json_parse(repaired_raw)
        if repaired is None:
            raise ValueError("Model did not return valid JSON after repair attempt")
        return repaired
def _safe_json_parse(text: str) -> Dict[str, Any] | None:
    """Best-effort extraction of a JSON object from model output.

    Tries, in order:
      1. the whole text;
      2. the span from the first ``{`` to the last ``}``, which handles code
         fences and surrounding prose;
      3. the first complete object that starts at the first ``{``, which
         handles trailing prose that itself contains braces.

    Returns:
        The decoded dict, or ``None`` if no JSON object could be recovered.
        Text that is valid JSON but not an object (list, number, string,
        null) yields ``None``, so callers typed for a dict never receive
        another type.
    """
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        return parsed if isinstance(parsed, dict) else None

    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        # The slice starts with '{' and ends with '}', so a successful
        # decode is necessarily a dict.
        return json.loads(text[start : end + 1])
    except json.JSONDecodeError:
        pass
    try:
        obj, _ = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    return obj
def _is_rate_limit_error(exc: Exception) -> bool:
    """Heuristically decide whether *exc* is an OpenAI rate-limit error.

    Does not import SDK exception classes. Instead it checks, in order:
      * a class named ``RateLimitError`` anywhere in the exception's MRO;
      * HTTP 429 on ``status_code`` (v1 SDK) or ``http_status`` (legacy SDK);
      * "rate limit" / "rate_limit" in the message text.
    """
    if any(cls.__name__ == "RateLimitError" for cls in type(exc).__mro__):
        return True
    for attr in ("status_code", "http_status"):
        if getattr(exc, attr, None) == 429:
            return True
    msg = str(exc).lower()
    return "rate limit" in msg or "rate_limit" in msg