Spaces:
Sleeping
Sleeping
File size: 4,157 Bytes
da0c238 d50f94e ad16e06 d50f94e da0c238 d50f94e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 | from __future__ import annotations
import json
import logging
import time
from typing import Any, Dict, List
logger = logging.getLogger(__name__)
class OpenAIClient:
    """Small wrapper around the ``openai`` package with retrying chat helpers.

    Both SDK generations are supported: when the installed package exposes
    ``openai.OpenAI`` a client object is created (``_mode == "client"``),
    otherwise the module-level legacy API is used (``_mode == "legacy"``).
    """

    # Sampling temperature sent with every chat completion request.
    TEMPERATURE = 0.2
    # First pause after a non-rate-limit failure; doubled after each such failure.
    INITIAL_BACKOFF_SECONDS = 1.0
    # User indicated ~3 req/min; wait a full 60s after a rate-limit error to be safe.
    RATE_LIMIT_WAIT_SECONDS = 60.0

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        """Import ``openai`` lazily and configure it with *api_key*.

        Args:
            api_key: API key handed to the SDK.
            model: Model name sent with every request.

        Raises:
            RuntimeError: If the ``openai`` package cannot be imported.
        """
        try:
            import openai  # type: ignore
        except Exception as exc:  # pragma: no cover - import guard
            raise RuntimeError(
                "openai package is required. Install with `pip install openai`."
            ) from exc

        self._raw = openai
        self.model = model

        # Handle both v1 and legacy SDK initializers.
        if hasattr(openai, "OpenAI"):
            self.client = openai.OpenAI(api_key=api_key)
            self._mode = "client"
        else:
            openai.api_key = api_key
            self.client = openai
            self._mode = "legacy"

    def _complete(self, messages: List[Dict[str, str]]) -> str:
        """Send *messages* once through whichever SDK flavour is active.

        Returns the first choice's message content, or ``""`` when the
        content is empty/``None``. Any SDK exception propagates unchanged.
        """
        if self._mode == "client":
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.TEMPERATURE,
            )
            return resp.choices[0].message.content or ""

        # Legacy fallback: module-level API returning a dict-like response.
        resp = self.client.ChatCompletion.create(
            model=self.model,
            messages=messages,
            temperature=self.TEMPERATURE,
        )
        return resp["choices"][0]["message"]["content"] or ""

    def chat(self, prompt: str, *, max_retries: int = 3) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Args:
            prompt: Text sent as the only (user) message.
            max_retries: Maximum number of attempts. Values below 1 are
                treated as 1 so the request is always tried at least once.

        Returns:
            The reply content, or ``""`` if the model returned no content.

        Raises:
            RuntimeError: If every attempt failed; the last underlying
                exception is attached as ``__cause__``.
        """
        messages = [{"role": "user", "content": prompt}]
        attempts = max(1, max_retries)
        delay = self.INITIAL_BACKOFF_SECONDS
        last_error: Exception | None = None

        for attempt in range(1, attempts + 1):
            try:
                return self._complete(messages)
            except Exception as exc:  # pragma: no cover - network call
                last_error = exc
                is_last = attempt == attempts

                if _is_rate_limit_error(exc):
                    wait_time = self.RATE_LIMIT_WAIT_SECONDS
                    if is_last:
                        logger.warning(
                            "OpenAI rate limit encountered (attempt %s). "
                            "No retries left",
                            attempt,
                        )
                    else:
                        logger.warning(
                            "OpenAI rate limit encountered (attempt %s). "
                            "Waiting %.1fs",
                            attempt,
                            wait_time,
                        )
                else:
                    logger.warning(
                        "OpenAI call failed (attempt %s): %s", attempt, exc
                    )
                    wait_time = delay
                    delay *= 2

                # Sleeping is only useful when another attempt follows;
                # otherwise fail immediately instead of stalling the caller.
                if not is_last:
                    time.sleep(wait_time)

        raise RuntimeError(
            f"OpenAI call failed after retries: {last_error}"
        ) from last_error  # pragma: no cover - network call

    def chat_json(self, prompt: str, *, max_retries: int = 3) -> Dict[str, Any]:
        """Send *prompt* and return the reply parsed as JSON.

        If the first reply cannot be parsed, the model is asked once to
        repair its own output and that second reply is parsed instead.

        Args:
            prompt: Text sent as the user message.
            max_retries: Forwarded to :meth:`chat` for each of the (at most
                two) requests.

        Raises:
            ValueError: If the repaired reply still cannot be parsed.
            RuntimeError: Propagated from :meth:`chat` when a request fails.
        """
        raw = self.chat(prompt, max_retries=max_retries)
        parsed = _safe_json_parse(raw)
        if parsed is not None:
            return parsed

        # Ask model to repair the JSON if parsing failed.
        repair_prompt = (
            "The previous response was invalid JSON. "
            "Return ONLY valid JSON that fixes it without adding new facts.\n"
            f"Original response:\n{raw}"
        )
        repaired_raw = self.chat(repair_prompt, max_retries=max_retries)
        repaired = _safe_json_parse(repaired_raw)
        if repaired is None:
            raise ValueError("Model did not return valid JSON after repair attempt")
        return repaired
def _safe_json_parse(text: str) -> Dict[str, Any] | None:
# Attempt direct parse
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Try to extract JSON substring if wrapped in text.
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1 and end > start:
snippet = text[start : end + 1]
try:
return json.loads(snippet)
except json.JSONDecodeError:
return None
return None
def _is_rate_limit_error(exc: Exception) -> bool:
# Works with both new and legacy SDK exceptions.
msg = str(exc).lower()
if "rate limit" in msg or "rate_limit" in msg:
return True
if hasattr(exc, "status_code") and getattr(exc, "status_code") == 429:
return True
return False
|