| import asyncio |
| import httpx |
| import uuid |
| from datetime import datetime |
| from typing import Optional, List, Literal |
| from fastapi import FastAPI, HTTPException, BackgroundTasks |
| from pydantic import BaseModel, Field |
| import logging |
| import os |
|
|
| |
# Configure root logging at INFO once, at import time, then obtain this
# module's named logger for use throughout the service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
|
# The ASGI application; title/description/version feed the generated OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="OpenAI Compatible Image Generation API",
    description="OpenAI-compatible API for image generation using Captions backend",
)
|
|
| |
# Base URL of the upstream Captions image-generation API; the submit and
# status paths are appended to it.
CAPTIONS_BASE_URL = "https://core.captions-web-api.xyz/proxy/v1/gen-ai/image"

# Bearer token for the upstream API. It is read ONLY from the environment:
# a credential committed to source is readable by anyone with repository
# access. Any token that was previously embedded here should be considered
# leaked and revoked. Surrounding whitespace is stripped so a token pasted
# with a trailing newline still forms a valid Authorization header.
BEARER_TOKEN = os.getenv("CAPTIONS_BEARER_TOKEN", "").strip()
if not BEARER_TOKEN:
    logging.getLogger(__name__).warning(
        "CAPTIONS_BEARER_TOKEN is not set; upstream requests will be unauthenticated"
    )
|
|
| |
# Request body for POST /v1/images/generations, mirroring OpenAI's schema.
# (Comments rather than a docstring: pydantic would publish a class docstring
# as the schema description.)
class ImageGenerationRequest(BaseModel):
    prompt: str = Field(default=..., description="A text description of the desired image(s)")
    model: Optional[str] = Field(
        default="dall-e-3",
        description="The model to use for image generation",
    )
    n: Optional[int] = Field(
        default=1,
        ge=1,
        le=10,
        description="Number of images to generate",
    )
    quality: Optional[Literal["standard", "hd"]] = Field(
        default="standard",
        description="Quality of the image",
    )
    response_format: Optional[Literal["url", "b64_json"]] = Field(
        default="url",
        description="Response format",
    )
    size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] = Field(
        default="1024x1024",
        description="Size of the generated images",
    )
    style: Optional[Literal["vivid", "natural"]] = Field(
        default="vivid",
        description="Style of the generated images",
    )
    user: Optional[str] = Field(
        default=None,
        description="A unique identifier representing your end-user",
    )
|
|
| |
# One generated image in the response; each field is optional and
# defaults to None.
class ImageData(BaseModel):
    url: Optional[str] = Field(default=None)
    b64_json: Optional[str] = Field(default=None)
    revised_prompt: Optional[str] = Field(default=None)
|
|
# Response body: a creation timestamp (int) plus the list of generated images.
class ImageGenerationResponse(BaseModel):
    created: int = Field(default=...)
    data: List[ImageData] = Field(default=...)
|
|
| |
# Shape of the upstream submit payload. NOTE(review): not referenced by the
# visible code; submit_image_generation builds an equivalent dict by hand.
class CaptionsSubmitRequest(BaseModel):
    modelId: str = Field(default="openai-gpt-4o-image")
    prompt: str
    aspectRatio: int = Field(default=2)
    magicPrompt: bool = Field(default=False)
    optimisticProjectId: str
|
|
# Shape of the upstream status payload. NOTE(review): not referenced by the
# visible code.
class CaptionsStatusRequest(BaseModel):
    operationId: str = Field(default=...)
|
|
| |
# Module-level dict. NOTE(review): nothing in the visible code reads or writes it.
operations_store = dict()
|
|
_SIZE_TO_ASPECT_RATIO = {
    # Square sizes.
    "256x256": 1,
    "512x512": 1,
    "1024x1024": 1,
    # Wide.
    "1792x1024": 2,
    # Tall.
    "1024x1792": 3,
}


def get_aspect_ratio_from_size(size: str) -> int:
    """Map an OpenAI ``WxH`` size string to a Captions ``aspectRatio`` code.

    Unknown sizes fall back to 1.
    """
    try:
        return _SIZE_TO_ASPECT_RATIO[size]
    except KeyError:
        return 1
|
|
async def submit_image_generation(prompt: str, size: str = "1024x1024") -> str:
    """Submit an image-generation job to the Captions API.

    Args:
        prompt: Text prompt forwarded verbatim to the upstream model.
        size: OpenAI-style size string, mapped to a Captions aspect ratio by
            ``get_aspect_ratio_from_size``.

    Returns:
        The upstream ``operationId``, used to poll for completion.

    Raises:
        HTTPException: 500 if the service is unreachable, reports
            ``success`` false, or fails unexpectedly; 502 if it answers with
            an HTTP error status, a non-JSON body, or a body lacking
            ``data.operationId``.
    """
    headers = {
        "accept": "application/json, text/plain, */*",
        "authorization": f"Bearer {BEARER_TOKEN}",
        "content-type": "application/json",
        "origin": "https://desktop.captions.ai",
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        "x-app-version": "1.0.0",
        "x-captions-user-timezone": "UTC",
        # Fresh random 32-char hex ID (a UUID4 without dashes) per request.
        "x-device-id": uuid.uuid4().hex,
    }

    payload = {
        "modelId": "openai-gpt-4o-image",
        "prompt": prompt,
        "aspectRatio": get_aspect_ratio_from_size(size),
        "magicPrompt": False,
        "optimisticProjectId": f"API-{uuid.uuid4()}",
    }

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                f"{CAPTIONS_BASE_URL}/generate/submit",
                headers=headers,
                json=payload,
                timeout=30.0,
            )
            response.raise_for_status()
            result = response.json()
        except httpx.HTTPStatusError as e:
            # The upstream answered, but with an error status (e.g. 401 for an
            # expired bearer token) — surface it as a gateway error.
            logger.error(
                "Captions submit returned HTTP %s: %s",
                e.response.status_code,
                e.response.text[:500],
            )
            raise HTTPException(
                status_code=502,
                detail=f"Image generation service returned HTTP {e.response.status_code}",
            ) from e
        except httpx.RequestError as e:
            logger.error("Request error: %s", e)
            raise HTTPException(
                status_code=500,
                detail="Failed to connect to image generation service",
            ) from e
        except ValueError as e:
            # response.json() raises a ValueError subclass on a non-JSON body.
            logger.error("Captions submit returned a non-JSON body: %s", e)
            raise HTTPException(
                status_code=502,
                detail="Malformed response from image generation service",
            ) from e
        except Exception as e:
            logger.exception("Unexpected error: %s", e)
            raise HTTPException(status_code=500, detail="Internal server error") from e

    # Validation happens outside the try block so these specific
    # HTTPExceptions reach the caller instead of being re-labelled by the
    # catch-all handler above.
    if not isinstance(result, dict) or not result.get("success"):
        raise HTTPException(status_code=500, detail="Failed to submit image generation")

    data = result.get("data")
    operation_id = data.get("operationId") if isinstance(data, dict) else None
    if not operation_id:
        logger.error("Captions submit response lacks data.operationId: %r", result)
        raise HTTPException(
            status_code=502,
            detail="Malformed response from image generation service",
        )

    logger.info("Image generation submitted with operation ID: %s", operation_id)
    return operation_id
|
|
async def check_generation_status(operation_id: str) -> dict:
    """Fetch the current status of an upstream image-generation operation.

    Args:
        operation_id: The ``operationId`` returned by ``submit_image_generation``.

    Returns:
        The ``data`` object of the upstream status response (read by
        ``wait_for_completion`` for its ``state`` / ``complete`` keys).

    Raises:
        HTTPException: 500 if the service is unreachable, reports
            ``success`` false, or fails unexpectedly; 502 if it answers with
            an HTTP error status, a non-JSON body, or a ``data`` field that
            is not an object.
    """
    headers = {
        "accept": "application/json, text/plain, */*",
        "authorization": f"Bearer {BEARER_TOKEN}",
        "content-type": "application/json",
        "origin": "https://desktop.captions.ai",
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        "x-app-version": "1.0.0",
        "x-captions-user-timezone": "UTC",
        # Fresh random 32-char hex ID (a UUID4 without dashes) per request.
        "x-device-id": uuid.uuid4().hex,
    }

    payload = {"operationId": operation_id}

    async with httpx.AsyncClient() as client:
        try:
            response = await client.post(
                f"{CAPTIONS_BASE_URL}/generate/status",
                headers=headers,
                json=payload,
                timeout=30.0,
            )
            response.raise_for_status()
            result = response.json()
        except httpx.HTTPStatusError as e:
            # The upstream answered with an error status — a gateway error.
            logger.error(
                "Captions status returned HTTP %s: %s",
                e.response.status_code,
                e.response.text[:500],
            )
            raise HTTPException(
                status_code=502,
                detail=f"Status service returned HTTP {e.response.status_code}",
            ) from e
        except httpx.RequestError as e:
            logger.error("Request error: %s", e)
            raise HTTPException(
                status_code=500,
                detail="Failed to connect to status service",
            ) from e
        except ValueError as e:
            # response.json() raises a ValueError subclass on a non-JSON body.
            logger.error("Captions status returned a non-JSON body: %s", e)
            raise HTTPException(
                status_code=502,
                detail="Malformed response from status service",
            ) from e
        except Exception as e:
            logger.exception("Unexpected error: %s", e)
            raise HTTPException(status_code=500, detail="Internal server error") from e

    # Validation happens outside the try block so these specific
    # HTTPExceptions reach the caller instead of being re-labelled by the
    # catch-all handler above.
    if not isinstance(result, dict) or not result.get("success"):
        raise HTTPException(status_code=500, detail="Failed to check generation status")

    data = result.get("data")
    if not isinstance(data, dict):
        # Callers call .get() on the returned value, so require an object.
        logger.error("Captions status response has no data object: %r", result)
        raise HTTPException(
            status_code=502,
            detail="Malformed response from status service",
        )
    return data
|
|
async def wait_for_completion(
    operation_id: str,
    max_wait_time: int = 300,
    poll_interval: float = 2.0,
) -> dict:
    """Poll ``check_generation_status`` until the operation completes.

    Args:
        operation_id: The ``operationId`` returned by ``submit_image_generation``.
        max_wait_time: Overall time budget in seconds before giving up.
        poll_interval: Seconds to wait between status checks.

    Returns:
        The ``complete`` object of the status data once ``state == 2``.

    Raises:
        HTTPException: 504 once ``max_wait_time`` has elapsed; 502 if a
            completed status carries no ``complete`` object; plus anything
            ``check_generation_status`` raises.
    """
    # The event-loop clock is monotonic, unlike datetime.now(), so wall-clock
    # adjustments (NTP, DST) cannot shorten or stretch the timeout.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait_time

    while True:
        status_data = await check_generation_status(operation_id)

        # NOTE(review): only state == 2 is treated as terminal. Other upstream
        # states (including any failure state) are not distinguished, so a
        # failed job keeps polling until the timeout — confirm the state codes.
        if status_data.get("state") == 2:
            complete = status_data.get("complete")
            if complete is None:
                logger.error("Operation %s finished without a 'complete' payload", operation_id)
                raise HTTPException(
                    status_code=502,
                    detail="Malformed response from status service",
                )
            return complete

        remaining = deadline - loop.time()
        if remaining <= 0:
            # 504 (Gateway Timeout): the *upstream* was too slow. 408 would
            # blame the client and invites automatic client-side retries,
            # each of which would start a new generation.
            raise HTTPException(status_code=504, detail="Image generation timeout")

        # Never sleep past the deadline.
        await asyncio.sleep(min(poll_interval, remaining))
|
|
async def _download_image_as_b64(url: str) -> str:
    """Download the image at *url* and return its bytes base64-encoded (ASCII)."""
    import base64

    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(url, timeout=60.0, follow_redirects=True)
            response.raise_for_status()
        except httpx.HTTPError as e:
            logger.error("Failed to download generated image: %s", e)
            raise HTTPException(
                status_code=502,
                detail="Failed to download generated image",
            ) from e
    return base64.b64encode(response.content).decode("ascii")


async def _generate_single_image(request: ImageGenerationRequest) -> ImageData:
    """Run one submit -> poll cycle upstream and package the result as ImageData."""
    operation_id = await submit_image_generation(request.prompt, request.size)
    completion_data = await wait_for_completion(operation_id)

    image_url = completion_data.get("assetResolvedUrl")
    if not image_url:
        # Without a URL the response would contain an image entry with no data.
        logger.error("Operation %s completed without assetResolvedUrl", operation_id)
        raise HTTPException(
            status_code=502,
            detail="Image generation service returned no image URL",
        )

    logger.info("Image generation completed successfully for operation: %s", operation_id)
    if request.response_format == "b64_json":
        return ImageData(
            b64_json=await _download_image_as_b64(image_url),
            revised_prompt=request.prompt,
        )
    return ImageData(url=image_url, revised_prompt=request.prompt)


@app.post("/v1/images/generations", response_model=ImageGenerationResponse)
async def create_image(request: ImageGenerationRequest):
    """
    Creates an image given a text prompt.
    Compatible with OpenAI's image generation API.
    """
    # FastAPI publishes the docstring above as the endpoint description, so
    # implementation notes are kept in comments.
    try:
        logger.info("Received image generation request: %s", request.prompt)

        # Honour `n`: one independent upstream job per requested image, run
        # concurrently. If any job fails (or this request is cancelled), cancel
        # the others so they do not keep polling the upstream in the background.
        tasks = [
            asyncio.ensure_future(_generate_single_image(request))
            for _ in range(request.n or 1)
        ]
        try:
            images = await asyncio.gather(*tasks)
        except BaseException:
            for task in tasks:
                task.cancel()
            raise

        return ImageGenerationResponse(
            created=int(datetime.now().timestamp()),
            data=list(images),
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Unexpected error in image generation: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
|
|
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Liveness probe: a constant status plus the server's local ISO timestamp.
    # (The docstring doubles as the OpenAPI description, so it is kept terse.)
    checked_at = datetime.now().isoformat()
    return dict(status="healthy", timestamp=checked_at)
|
|
@app.get("/")
async def root():
    """Root endpoint with API information"""
    # Static discovery document: service name, version and the main routes.
    available_routes = {
        "image_generation": "/v1/images/generations",
        "health": "/health",
        "docs": "/docs",
    }
    return {
        "message": "OpenAI Compatible Image Generation API",
        "version": "1.0.0",
        "endpoints": available_routes,
    }
|
|
if __name__ == "__main__":
    import uvicorn

    # Bind address and port can be overridden from the environment without
    # editing the file; the defaults are the previous hard-coded values.
    uvicorn.run(
        app,
        host=os.getenv("API_HOST", "0.0.0.0"),
        port=int(os.getenv("API_PORT", "8000")),
    )