""" CheckAI Audio Detector — Production-hardened FastAPI backend. Runs on Hugging Face Spaces free tier (2 vCPU, 16 GB RAM). Provides ensemble AI-audio detection via Wav2Vec2 + spectral fingerprinting. """ # NOTE: Do NOT add `from __future__ import annotations` here. # It breaks Pydantic 2.x + FastAPI endpoint signature resolution # (PydanticUndefinedAnnotation: name 'AnalysisRequest' is not defined). # Python 3.10 natively supports `X | Y`, `list[X]`, `dict[K, V]` syntax # without that import. import asyncio import hashlib import hmac import io import ipaddress import logging import os import tempfile import time import uuid import warnings # Silence the per-decode warning storm. Apple Music / iTunes previews # arrive as AAC / M4A which `soundfile` cannot decode directly, so # librosa always falls back to `audioread` (ffmpeg). The fallback # itself is correct — but it emits two `UserWarning` / `FutureWarning` # lines per decode. With 150+ decodes per day they swamp the log, # making genuine errors invisible. The decoding behaviour is # unchanged; only the warnings are suppressed. 
# Suppress librosa's soundfile -> audioread fallback warnings (rationale in
# the comment above). `message` is a regex matched against the start of the
# warning text.
warnings.filterwarnings("ignore", message="PySoundFile failed", category=UserWarning)
warnings.filterwarnings(
    "ignore",
    message=".*audioread_load.*",
    category=FutureWarning,
)

from collections import defaultdict
from typing import Any, Optional
from urllib.parse import urlparse

import librosa
import numpy as np
import requests
import torch
from fastapi import (
    Depends,
    FastAPI,
    File,
    Header,
    HTTPException,
    Request,
    UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
from starlette.middleware.base import BaseHTTPMiddleware
from transformers import pipeline

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# MODEL_ID is env-configurable so we can A/B test candidate detectors without
# redeploying. Verified so far:
#   - "MelodyMachine/Deepfake-audio-detection-V2" → BROKEN (constant ~1.0 on
#     both real music and AI music; do not use)
#   - "mo-thecreator/Deepfake-audio-detection"    → to evaluate (speech-trained)
# NOTE(review): the fallback below is the model listed above as BROKEN, so any
# deployment that does not set MODEL_ID runs it. Confirm MODEL_ID is set in
# the Space settings, or change the default once a replacement is verified.
MODEL_ID = os.getenv("MODEL_ID", "MelodyMachine/Deepfake-audio-detection-V2")

# Which pipeline labels count as "this is AI"? Comma-separated, case-insensitive.
# Some HF models use LABEL_0 / LABEL_1 instead of semantic names — check the
# model's config.json and set this accordingly.
_ai_labels_raw = os.getenv("MODEL_AI_LABELS", "fake,ai,synthetic,spoof,label_1")
AI_LABELS = {s.strip().lower() for s in _ai_labels_raw.split(",") if s.strip()}

# Shared secret compared against the X-Api-Key header by `verify_api_key`.
_API_KEY_FALLBACK = "your-fallback-test-key"
API_KEY = os.getenv("DETECTOR_API_KEY", _API_KEY_FALLBACK)
if API_KEY in ("", _API_KEY_FALLBACK):
    # The fallback is committed in source, so anyone who has read it can
    # authenticate; an empty key accepts an empty X-Api-Key header. Kept
    # non-fatal so existing setups still boot, but it must be loud in logs.
    logger.warning(
        "DETECTOR_API_KEY is unset or empty — the API is protected only by a "
        "publicly known key. Set DETECTOR_API_KEY in the deployment secrets."
    )

MAX_FILE_SIZE = 5 * 1024 * 1024  # 5 MB
MIN_AUDIO_DURATION = 1.0  # seconds
MAX_AUDIO_DURATION = 30.0  # seconds
MAX_GLOBAL_CONCURRENCY = 2
MAX_QUEUE_SIZE = 10  # not referenced by the code visible in this section
DAILY_LIMIT = 50

# Hosts that /analyze may download from; enforced by `validate_audio_url`.
ALLOWED_DOMAINS = [
    "p.scdn.co",
    "i.scdn.co",
    "audio-ak-spotify-com.akamaized.net",
    "music.apple.com",
    "audio-ssl.itunes.apple.com",
    "m.apple-music.com",
    "googlevideo.com",
    "youtube.com",
    "firebasestorage.googleapis.com",
    "cloudflare-ipfs.com",
]

ALLOWED_AUDIO_MIMETYPES = {
    "audio/mpeg",
    "audio/mp4",
    "audio/aac",
    "audio/x-m4a",
    "audio/wav",
    "audio/x-wav",
    "audio/ogg",
    "audio/flac",
    "audio/webm",
    "application/octet-stream",  # fallback — some clients don't set MIME
}

_raw_origins = os.getenv("ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:7860")
ALLOWED_ORIGINS = [o.strip() for o in _raw_origins.split(",") if o.strip()]

# ---------------------------------------------------------------------------
# Rate-limiter (slowapi)
# ---------------------------------------------------------------------------
# Per-client-IP limiter backing the `@limiter.limit(...)` route decorators.
limiter = Limiter(key_func=get_remote_address)

# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
app = FastAPI(title="CheckAI Audio Detector")
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=False,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["Content-Type", "X-Api-Key", "X-Request-ID"],
)

# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
_startup_time: float = 0.0
# Wav2Vec2 pipeline; stays None until loaded elsewhere (callers check truthiness).
classifier: Any = None

# Flips to True the first time the model has been exercised after
# startup. Used by /health to lazily run a warm-up forward pass —
# avoids paying the ~1 s "first inference is slow" tax on the very
# first real /analyze request after a HF Spaces cold-start.
_model_warmed: bool = False

# ---------------------------------------------------------------------------
# Music-native AI detector (lofcz/ai-music-detector, MIT, June 2025).
#
# Logistic regression on a 3,585-dim "fakeprint" spectral feature
# vector. Detects characteristic frequency-domain artefacts left by
# transposed-convolution layers in neural vocoders (Suno ≤ v5, Udio ≤
# v1.5). Reported 99.88 % accuracy on a 17,866-sample held-out set.
#
# Why it complements the existing wav2vec2 path:
#   * Wav2Vec2 (the existing primary model) is speech-trained — weak
#     signal on music. The product positioning explicitly admits this
#     in the About copy.
#   * Music-native model is strongest exactly where wav2vec2 is
#     weakest, and weakest on speech where wav2vec2 is strong.
#   * Combined as a 3-model soft-voting ensemble, two-of-three
#     agreement is required for high-confidence AI verdicts. If two
#     models say "AI" and the third says "human", the weighted score
#     reflects the disagreement and the user sees a moderate-
#     confidence verdict instead of an over-confident one.
#
# Memory cost: weights are 3,585 × 4 bytes (~14 KB). Bias is 4 bytes.
# Inference is a single dot product, sub-millisecond.
# --------------------------------------------------------------------------- music_classifier_weights: np.ndarray | None = None # shape (1, 3585) music_classifier_bias: np.ndarray | None = None # shape (1,) MUSIC_MODEL_REPO = "lofcz/ai-music-detector" MUSIC_MODEL_MIN_DURATION = 5.0 # seconds — model unreliable below this # Per-IP concurrency semaphores (max 1 concurrent request per IP) _ip_semaphores: dict[str, asyncio.Semaphore] = defaultdict(lambda: asyncio.Semaphore(1)) # Global inference semaphore _global_semaphore: asyncio.Semaphore = asyncio.Semaphore(MAX_GLOBAL_CONCURRENCY) # (Deleted: `_request_queue` and `_queue_depth` — declared but never read, # left readers with the false impression that a request queue existed. The # old "queued" SSE event computed `MAX_GLOBAL_CONCURRENCY - _global_semaphore._value` # which is always 0 when the semaphore is fully held — i.e. the position # emitted to the client was fake. Build 12 emits a real queue position # from the semaphore state at acquire time.) # In-flight deduplication: hash -> asyncio.Future _inflight: dict[str, asyncio.Future] = {} _inflight_lock = asyncio.Lock() # Daily rate-limit tracking: IP -> (count, reset_timestamp) _daily_counts: dict[str, tuple[int, float]] = {} _daily_lock = asyncio.Lock() # Temp files to clean on shutdown _temp_files: set[str] = set() # --------------------------------------------------------------------------- # Data models # --------------------------------------------------------------------------- class AnalysisRequest(BaseModel): preview_url: str class AnalysisResult(BaseModel): is_ai: bool confidence: float details: dict # (Deleted: `QueueEntry` dataclass — companion to the dead `_request_queue`. # Backpressure now lives in `_acquire_global_slot` which times out and # returns 503 after GLOBAL_SEMAPHORE_TIMEOUT seconds.) 
# ---------------------------------------------------------------------------
# Request-ID middleware
# ---------------------------------------------------------------------------


class RequestIDMiddleware(BaseHTTPMiddleware):
    """Attach a fresh UUID4 to every request.

    The id is exposed to handlers as `request.state.request_id` and echoed to
    the client in the `X-Request-ID` response header, so client reports can be
    matched to server log lines. Any incoming X-Request-ID header is ignored.
    """

    async def dispatch(self, request: Request, call_next):
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id
        response = await call_next(request)
        response.headers["X-Request-ID"] = request_id
        return response


app.add_middleware(RequestIDMiddleware)


# ---------------------------------------------------------------------------
# Daily rate-limit helpers
# ---------------------------------------------------------------------------

# Once the table tracks more IPs than this, expired windows are swept on the
# next check. An expired entry is reset on its next access anyway, so dropping
# it never changes a verdict — it only stops `_daily_counts` from keeping one
# tuple per distinct client IP for the lifetime of the process.
_DAILY_COUNTS_SWEEP_THRESHOLD = 10_000


async def _check_daily_limit(ip: str) -> tuple[int, float]:
    """Charge one request against *ip*'s 24-hour quota.

    A window opens on the first request after the previous window expired and
    lasts 86,400 s.

    Returns:
        (remaining, reset_timestamp): requests left after this one, and the
        epoch time at which the current window resets.

    Raises:
        HTTPException: 429 (with Retry-After) once DAILY_LIMIT requests have
            been charged in the current window.
    """
    now = time.time()
    async with _daily_lock:
        if len(_daily_counts) > _DAILY_COUNTS_SWEEP_THRESHOLD:
            expired = [key for key, (_, ts) in _daily_counts.items() if now > ts]
            for key in expired:
                del _daily_counts[key]

        count, reset_ts = _daily_counts.get(ip, (0, 0.0))
        if now > reset_ts:
            count = 0
            reset_ts = now + 86400
        if count >= DAILY_LIMIT:
            retry_after = int(reset_ts - now)
            raise HTTPException(
                status_code=429,
                detail={
                    "error": "rate_limited",
                    "retry_after": retry_after,
                    "daily_remaining": 0,
                    "message": f"Daily limit of {DAILY_LIMIT} requests exceeded. Resets in {retry_after}s.",
                },
                headers={"Retry-After": str(retry_after)},
            )
        count += 1
        _daily_counts[ip] = (count, reset_ts)
        return DAILY_LIMIT - count, reset_ts


# ---------------------------------------------------------------------------
# Security helpers
# ---------------------------------------------------------------------------
def verify_api_key(x_api_key: Optional[str] = Header(None)):
    """FastAPI dependency: require an `X-Api-Key` header equal to API_KEY.

    Uses `hmac.compare_digest` so response timing does not reveal how much of
    the key matched.

    Returns:
        The presented key on success.

    Raises:
        HTTPException: 403 when the header is missing or does not match.
    """
    if x_api_key is None:
        raise HTTPException(status_code=403, detail="Unauthorized: Missing API Key")
    if not hmac.compare_digest(x_api_key.encode(), API_KEY.encode()):
        raise HTTPException(status_code=403, detail="Unauthorized: Invalid API Key")
    return x_api_key


def _is_private_ip(hostname: str) -> bool:
    """Return True if *hostname* is an IP literal in a non-public range.

    Only literal addresses are inspected; DNS names are NOT resolved, so a
    name pointing at an internal address is not caught here. The exact-host
    allowlist in `validate_audio_url` is the control that covers names.
    """
    try:
        addr = ipaddress.ip_address(hostname)
    except ValueError:
        # Not an IP literal (a DNS name, or empty).
        return False
    return (
        addr.is_private
        or addr.is_loopback
        or addr.is_reserved
        or addr.is_link_local
        or addr.is_multicast
        or addr.is_unspecified  # 0.0.0.0 / :: reach the local host on many stacks
    )


def _host_is_allowed(hostname: str) -> bool:
    """True if *hostname* equals an ALLOWED_DOMAINS entry or is a subdomain of one.

    A label-boundary match is required: "www.youtube.com" is allowed for
    "youtube.com", but "p.scdn.co.evil.com" and "evil-youtube.com" are not
    (a plain substring test would accept both).
    """
    host = hostname.lower().rstrip(".")  # "p.scdn.co." is the same host as "p.scdn.co"
    return any(host == domain or host.endswith("." + domain) for domain in ALLOWED_DOMAINS)


def validate_audio_url(url: str) -> None:
    """Reject URLs the downloader must not fetch.

    Raises:
        HTTPException: 400 for a non-HTTPS scheme, embedded credentials, an
            internal IP literal, or a host outside ALLOWED_DOMAINS.
    """
    parsed = urlparse(url)

    # Scheme must be https
    if parsed.scheme != "https":
        raise HTTPException(status_code=400, detail="Only HTTPS URLs are allowed")

    # Reject URLs with embedded credentials
    if parsed.username or parsed.password:
        raise HTTPException(status_code=400, detail="URLs with credentials are not allowed")

    # `hostname` (unlike `netloc`) excludes the port and userinfo and is lowercased.
    hostname = parsed.hostname or ""

    # SSRF: reject private IPs
    if _is_private_ip(hostname):
        raise HTTPException(status_code=400, detail="Private/internal addresses are not allowed")

    # Domain allowlist
    if not _host_is_allowed(hostname):
        raise HTTPException(status_code=400, detail="Invalid audio source domain")


def validate_upload_content_type(content_type: str | None) -> None:
    """Reject uploads whose declared media type is not an accepted audio type.

    A missing/empty Content-Type is allowed. Parameters after ";" (e.g.
    "audio/mpeg; codecs=mp3") and letter case are ignored when matching, as
    media types are case-insensitive.

    Raises:
        HTTPException: 400 for an unsupported media type.
    """
    if not content_type:
        return
    media_type = content_type.split(";", 1)[0].strip().lower()
    if media_type not in ALLOWED_AUDIO_MIMETYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported content type: {content_type}. Expected an audio file.",
        )


# ---------------------------------------------------------------------------
# Global concurrency guard
#
# **Build 12 fix.** The previous `async with _global_semaphore:` had no
# timeout, so when MAX_GLOBAL_CONCURRENCY=2 was saturated the third (and
# subsequent) requests blocked indefinitely with no client signal.
# The Apple reviewer's request hit this: their connection was accepted,
# parked behind earlier reviewers' inferences, and eventually reset by
# the OS — manifesting as "the backend never finalises".
#
# Now: every acquire site uses `_acquire_global_slot()` which waits at
# most GLOBAL_SEMAPHORE_TIMEOUT seconds and raises a 503 with
# Retry-After if the queue isn't drained in time. The Flutter client
# already handles 503 + Retry-After via its existing retry path.
# ---------------------------------------------------------------------------
GLOBAL_SEMAPHORE_TIMEOUT = 20.0


async def _acquire_global_slot(request_id: str) -> None:
    """Acquire the global inference semaphore with bounded wait.

    Caller MUST `_global_semaphore.release()` when done (use try/finally).

    Raises:
        HTTPException: 503 with Retry-After: 30 when no slot frees up within
            GLOBAL_SEMAPHORE_TIMEOUT seconds.
    """
    try:
        await asyncio.wait_for(_global_semaphore.acquire(), timeout=GLOBAL_SEMAPHORE_TIMEOUT)
    except asyncio.TimeoutError:
        logger.warning(
            "[%s] global-semaphore wait timed out after %.0fs — returning 503",
            request_id,
            GLOBAL_SEMAPHORE_TIMEOUT,
        )
        raise HTTPException(
            status_code=503,
            detail="Server is busy. Please retry in 30 seconds.",
            headers={"Retry-After": "30"},
        )


# ---------------------------------------------------------------------------
# Per-IP concurrency guard
# ---------------------------------------------------------------------------
async def _acquire_ip_slot(ip: str) -> None:
    """Take *ip*'s single concurrency slot, or fail fast.

    Never waits: a second concurrent request from the same IP is rejected
    immediately instead of queueing.

    Raises:
        HTTPException: 429 with Retry-After: 10 when the slot is held.
    """
    sem = _ip_semaphores[ip]
    if sem.locked():
        raise HTTPException(
            status_code=429,
            detail="A request from your IP is already being processed. Please wait.",
            headers={"Retry-After": "10"},
        )
    # The semaphore is free, so this acquire completes without suspending.
    await sem.acquire()


def _release_ip_slot(ip: str) -> None:
    """Release *ip*'s slot taken by `_acquire_ip_slot` and forget the IP.

    Popping the entry keeps `_ip_semaphores` from holding one Semaphore per
    distinct client IP forever; the next request from this IP simply gets a
    fresh one. Releasing only a held semaphore matters because
    `asyncio.Semaphore` (unlike `BoundedSemaphore`) does not raise on
    over-release — it silently raises the permit count to 2, which would let
    one IP run two requests at once. A stray second release is now a no-op.
    """
    sem = _ip_semaphores.pop(ip, None)
    if sem is not None and sem.locked():
        sem.release()


# ---------------------------------------------------------------------------
# Content hashing for deduplication
# ---------------------------------------------------------------------------
def _hash_content(data: bytes) -> str:
    """SHA-256 hex digest of raw uploaded bytes (dedup key for uploads)."""
    return hashlib.sha256(data).hexdigest()


def _hash_url(url: str) -> str:
    """SHA-256 hex digest of a URL string (dedup key for URL analysis)."""
    return hashlib.sha256(url.encode()).hexdigest()


# ---------------------------------------------------------------------------
# Audio loading — robust AAC/M4A/MP3/WAV via librosa + audioread + ffmpeg
# ---------------------------------------------------------------------------
# librosa.load() → tries soundfile first (WAV/FLAC/OGG), falls back to
# audioread (which shells out to ffmpeg) for AAC/M4A/MP3.
# IMPORTANT: audioread requires a FILE PATH on disk, not a BytesIO —
# so we must materialize any in-memory bytes to a temp file first.
# ---------------------------------------------------------------------------
def _load_audio(
    *,
    data: bytes | None = None,
    path: str | None = None,
    sr: int = 16000,
) -> tuple[np.ndarray, int]:
    """Decode audio from either a byte blob or a file path.

    Uses a temp file for byte inputs so audioread+ffmpeg can handle
    containers (AAC/M4A/MP3) that soundfile cannot decode directly. The temp
    file is always deleted, whether decoding succeeds or not.

    Returns:
        (audio_array, sample_rate) as mono float32, resampled to *sr*.

    Raises:
        ValueError: if neither *data* nor *path* is given. Decoder errors
            from librosa propagate unchanged.
    """
    if data is not None:
        tmp_fd, tmp_path = tempfile.mkstemp(suffix=".audio")
        try:
            # fdopen takes ownership of the descriptor, so it is closed even if
            # the write fails (a bare os.write + os.close leaked it), and the
            # buffered writer writes every byte rather than possibly fewer.
            with os.fdopen(tmp_fd, "wb") as fh:
                fh.write(data)
            return librosa.load(tmp_path, sr=sr, mono=True)
        finally:
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    if path is not None:
        return librosa.load(path, sr=sr, mono=True)

    raise ValueError("_load_audio requires either 'data' or 'path'")


# ---------------------------------------------------------------------------
# Audio analysis core
# ---------------------------------------------------------------------------
def detect_frequency_fingerprints(audio_data: np.ndarray, sr: int) -> float:
    """Score high-frequency energy as a crude "digital haze" heuristic.

    Takes the mean STFT magnitude over bins in the 8–16 kHz band, divides it
    by the mean magnitude over all bins, and maps ratio * 10 into [0, 1].

    NOTE(review): every caller visible in this file loads audio at
    sr=16000, whose Nyquist limit is 8 kHz, so the 8–16 kHz mask selects only
    the single bin at exactly 8 kHz. The score therefore reflects that one
    bin, not the intended band; measuring the real band needs a higher load
    sample rate plus recalibration of the weights in `_finalise_score`.

    Returns:
        A float in [0, 1]; 0.0 when no STFT bin falls inside the band (any
        sr below 16000), instead of propagating NaN from an empty mean.
    """
    D = np.abs(librosa.stft(audio_data))
    freqs = librosa.fft_frequencies(sr=sr)

    high_freq_mask = (freqs >= 8000) & (freqs <= 16000)
    if not np.any(high_freq_mask):
        return 0.0

    high_freq_energy = np.mean(D[high_freq_mask, :])
    total_energy = np.mean(D)

    ratio = high_freq_energy / (total_energy + 1e-6)
    score = min(ratio * 10, 1.0)
    return float(score)


# ---------------------------------------------------------------------------
# Music-native fakeprint extractor (lofcz/ai-music-detector contract).
#
# Translation of `FakeprintExtractor.compute_fakeprint_from_spectrum`
# from the upstream GPU/torchaudio implementation
# (`src/python/extract_fakeprints.py`) into a CPU-only librosa+scipy
# pipeline that fits the existing _load_audio path. Math is bit-for-
# bit equivalent; only the front-end (STFT) library differs.
#
# Constants come straight from the model card's preprocessing_config.json
# and MUST stay in sync — feeding the model a feature vector built
# with different parameters silently destroys accuracy.
# ---------------------------------------------------------------------------
_MUSIC_N_FFT = 8192
_MUSIC_HOP = _MUSIC_N_FFT // 2  # 4096, matches torchaudio Spectrogram default
_MUSIC_FREQ_MIN = 1000.0
_MUSIC_FREQ_MAX = 8000.0
_MUSIC_HULL_AREA = 10
_MUSIC_MAX_DB = 5.0
_MUSIC_MIN_DB = -45.0

# Smallest positive float. A live music-model inference is floored to this so
# it can never equal 0.0, the value `_run_music_detector` uses for "no signal"
# and which `_finalise_score` treats as "music model did not vote".
_MUSIC_SCORE_FLOOR = float(np.finfo(np.float64).tiny)


def _extract_fakeprint(audio_data: np.ndarray, sr: int) -> np.ndarray:
    """Compute the 3,585-dim spectral fakeprint per the lofcz contract.

    Pipeline (matches upstream exactly):
      1. STFT with n_fft=8192, hop=n_fft/2, power=2.
      2. Convert to dB: 10 * log10(clip(power, 1e-10, 1e6)).
      3. Average across time → 1-D mean spectrum.
      4. Apply 1–8 kHz frequency mask (3,585 bins at sr=16 kHz).
      5. Compute lower hull via 1-D minimum filter (window=10).
      6. Clip hull to [min_db, ∞).
      7. Residue = clip(spectrum − hull, 0, max_db).
      8. Normalise by max + 1e-6.

    Raises:
        ValueError: if sr != 16000 (the bin count in step 4 depends on it).
    """
    if sr != 16000:
        # _load_audio always resamples to 16 kHz, but defensive check.
        raise ValueError(f"fakeprint requires sr=16000, got {sr}")

    from scipy.ndimage import minimum_filter1d

    # Step 1 — power spectrum (|STFT|^2, like torchaudio.Spectrogram(power=2)).
    power = (
        np.abs(librosa.stft(audio_data, n_fft=_MUSIC_N_FFT, hop_length=_MUSIC_HOP, center=True))
        ** 2
    )

    # Step 2 — dB scale, matching upstream's torch.log10(clamp(...)) exactly.
    spec_db = 10.0 * np.log10(np.clip(power, 1e-10, 1e6))

    # Step 3 — time-averaged spectrum.
    mean_spectrum = spec_db.mean(axis=1)

    # Step 4 — 1–8 kHz mask. At sr=16000, n_fft=8192 → 4097 bins, mask = 3585.
    freq_bins = np.linspace(0.0, sr / 2.0, _MUSIC_N_FFT // 2 + 1)
    freq_mask = (freq_bins >= _MUSIC_FREQ_MIN) & (freq_bins <= _MUSIC_FREQ_MAX)
    band = mean_spectrum[freq_mask]

    # Step 5 — lower hull via minimum filter (isolates the slowly-varying
    # melodic content so the residue captures only the fast peaks that
    # are characteristic of neural-vocoder artefacts).
    hull = minimum_filter1d(band, size=_MUSIC_HULL_AREA, mode="nearest")

    # Step 6 — clip hull from below.
    hull = np.clip(hull, _MUSIC_MIN_DB, None)

    # Step 7 — residue: peaks above the hull, clipped to [0, max_db]
    # (one clip; identical to clipping at 0 and then at [0, max_db]).
    residue = np.clip(band - hull, 0.0, _MUSIC_MAX_DB)

    # Step 8 — peak-normalise to [0, 1] for stable model input.
    fakeprint = residue / (np.max(residue) + 1e-6)
    return fakeprint.astype(np.float32)


def _load_music_classifier() -> tuple[np.ndarray, np.ndarray]:
    """Download + load the lofcz/ai-music-detector logistic regression.

    Returns (weights, bias). Weights shape (1, 3585), bias shape (1,).
    Raises any underlying network/IO error so the caller can decide
    between failing startup vs continuing with a 2-model ensemble.
    """
    from huggingface_hub import hf_hub_download
    from safetensors.numpy import load_file

    weights_path = hf_hub_download(repo_id=MUSIC_MODEL_REPO, filename="model.safetensors")
    weights = load_file(weights_path)
    return weights["weights"], weights["bias"]


def _run_music_detector(audio_data: np.ndarray, sr: int) -> float:
    """Return P(audio is AI-generated) per the music-native model.

    Returns exactly 0.0 ("no signal from this model") when:
      * the model failed to load at startup,
      * the audio is shorter than MUSIC_MODEL_MIN_DURATION (model accuracy
        degrades below ~5 s of input), or
      * feature extraction / inference raised.
    A successful inference is floored at _MUSIC_SCORE_FLOOR, so it never
    collides with that sentinel. `_finalise_score` drops this model's weight
    when it sees 0.0.
    """
    if music_classifier_weights is None or music_classifier_bias is None:
        return 0.0

    duration = len(audio_data) / sr
    if duration < MUSIC_MODEL_MIN_DURATION:
        return 0.0

    try:
        fakeprint = _extract_fakeprint(audio_data, sr)
        # Logistic regression: σ(x · wᵀ + b)
        logit = float(
            np.dot(fakeprint, music_classifier_weights.T)[0] + music_classifier_bias[0]
        )
        # Numerically stable sigmoid: exp() only ever sees a non-positive argument.
        if logit >= 0.0:
            prob = 1.0 / (1.0 + np.exp(-logit))
        else:
            exp_logit = np.exp(logit)
            prob = exp_logit / (1.0 + exp_logit)
        return float(max(prob, _MUSIC_SCORE_FLOOR))
    except Exception as e:
        logger.warning("[music-detector] inference failed: %s", e)
        return 0.0


def _score_wav2vec2(audio_data: np.ndarray) -> float:
    """Return the Wav2Vec2 pipeline's score for the first AI-labelled class.

    Labels are matched case-insensitively against AI_LABELS. Returns 0.0
    when the classifier is not loaded or no output label is an AI label.
    """
    if not classifier:
        return 0.0
    results = classifier(audio_data)
    logger.info("[model] raw output: %s", results)
    for res in results:
        if res["label"].lower() in AI_LABELS:
            return res["score"]
    return 0.0


def run_detection_pipeline(audio_data: np.ndarray, sr: int) -> AnalysisResult:
    """Shared detection logic for both URL and upload endpoints.

    Runs all three models sequentially and hands the raw scores to
    `_finalise_score` for ensemble weighting. The music model is invoked
    unconditionally — `_run_music_detector` short-circuits cleanly to
    0.0 when unavailable or audio is too short, and `_finalise_score`
    rebalances the weights accordingly.
    """
    # Model A: Wav2Vec2 sequence analysis (speech-deepfake fine-tune)
    model_score = _score_wav2vec2(audio_data)
    # Model B: Music-native fakeprint detector
    music_model_score = _run_music_detector(audio_data, sr)
    # Model C: Frequency-domain fingerprinting (spectral haze heuristic)
    fingerprint_score = detect_frequency_fingerprints(audio_data, sr)
    return _finalise_score(model_score, fingerprint_score, music_model_score)


def _validate_audio_duration(audio_data: np.ndarray, sr: int) -> None:
    """Raise HTTPException 400 unless the clip lasts MIN..MAX_AUDIO_DURATION s."""
    duration = len(audio_data) / sr
    if duration < MIN_AUDIO_DURATION:
        raise HTTPException(
            status_code=400,
            detail=f"Audio too short ({duration:.1f}s). Minimum is {MIN_AUDIO_DURATION}s.",
        )
    if duration > MAX_AUDIO_DURATION:
        raise HTTPException(
            status_code=400,
            detail=f"Audio too long ({duration:.1f}s). Maximum is {MAX_AUDIO_DURATION}s.",
        )


# ---------------------------------------------------------------------------
# Deduplication wrapper
# ---------------------------------------------------------------------------
async def _deduplicated_run(
    content_hash: str,
    process_fn,
) -> AnalysisResult:
    """Run *process_fn* in a worker thread, sharing the outcome across duplicates.

    The first caller for a hash (the "leader") runs the pipeline; callers
    arriving while it is in flight ("followers") await the leader's result
    or exception instead of running the pipeline again.

    Guarantees:
      * `_inflight_lock` is held only for the dict check/insert, never while
        awaiting — the old code awaited the Future *inside* the lock, which
        serialised every other request behind the running pipeline.
      * The shared Future is always resolved, even when the leader is
        cancelled (e.g. client disconnect). Previously a cancelled leader
        left the Future pending, so followers hung forever while holding
        the lock, and the leader's cleanup then blocked on that lock.
      * Followers await through `asyncio.shield`, so a follower being
        cancelled does not cancel the Future the others are waiting on.

    Raises:
        Whatever *process_fn* raises. Followers of a cancelled leader get an
        HTTPException 503 with Retry-After instead of hanging.
    """
    loop = asyncio.get_running_loop()
    async with _inflight_lock:
        existing = _inflight.get(content_hash)
        if existing is None:
            future: asyncio.Future = loop.create_future()
            _inflight[content_hash] = future

    if existing is not None:
        logger.info("Dedup hit for hash %s...", content_hash[:12])
        return await asyncio.shield(existing)

    try:
        result = await asyncio.to_thread(process_fn)
    except BaseException as exc:
        if not future.done():
            # Cancellation/KeyboardInterrupt must not be re-raised inside
            # followers' tasks; hand them a retryable 503 instead.
            shared = (
                exc
                if isinstance(exc, Exception)
                else HTTPException(
                    status_code=503,
                    detail="Server is busy. Please retry in 30 seconds.",
                    headers={"Retry-After": "30"},
                )
            )
            future.set_exception(shared)
            # Mark the exception retrieved so a run with no followers does not
            # log "Future exception was never retrieved".
            future.exception()
        raise
    else:
        if not future.done():
            future.set_result(result)
        return result
    finally:
        async with _inflight_lock:
            if _inflight.get(content_hash) is future:
                del _inflight[content_hash]


# ---------------------------------------------------------------------------
# Synchronous processing functions (called via asyncio.to_thread in dedup)
# ---------------------------------------------------------------------------
def _run_sync_url(raw_bytes: bytes) -> AnalysisResult:
    """Decode downloaded bytes, enforce the duration window, run the ensemble."""
    audio_data, sr = _load_audio(data=raw_bytes, sr=16000)
    _validate_audio_duration(audio_data, sr)
    return run_detection_pipeline(audio_data, sr)


def _run_sync_file(file_path: str) -> AnalysisResult:
    """Decode an on-disk upload, enforce the duration window, run the ensemble."""
    audio_data, sr = _load_audio(path=file_path, sr=16000)
    _validate_audio_duration(audio_data, sr)
    return run_detection_pipeline(audio_data, sr)


# ---------------------------------------------------------------------------
# Streaming SSE helpers
#
# **Build 12 changes (Apple resubmit):**
#
# 1. Every `yield` is logged via `_log_sse(...)` so the server-side log
#    records exactly what each client received. Previously the log
#    contained the model output but no `[sse] event=...` lines, which
#    made "the backend never finalises" impossible to diagnose.
#
# 2. Generators now check `request.is_disconnected()` between stages
#    and abort early when the Flutter client closes the SSE socket
#    (e.g. user backgrounds the app). The semaphore is then released
#    promptly instead of being held for the full inference duration.
#
# 3. Every successful terminal yield is followed by an explicit
#    `event: complete` frame. SSE clients waiting for a defined
#    end-of-stream marker now receive one. Clients that ignore unknown
#    events still close on socket EOF, so the change is
#    backwards-compatible.
#    NOTE(review): the `error` paths in the generators below return
#    without a `complete` frame; presumably the calling endpoint appends
#    one — confirm.
#
# 4. The duplicate AnalysisResult-building tail across URL and file
#    generators was extracted into `_finalise_score(...)` so future
#    tweaks live in one place. The post-decode stages themselves are
#    shared via `_stream_detection_stages(...)`.
#
# 5. The URL download runs in a worker thread (`_download_audio`), so a
#    slow source no longer blocks the event loop — and with it every
#    other request, SSE stream and disconnect probe.
# ---------------------------------------------------------------------------
def _log_sse(request_id: str, event_name: str, payload: dict | None = None) -> None:
    """One-line structured log per SSE emission.

    Filter server logs on `[sse] event=` to reconstruct what a client saw.
    """
    if payload is None:
        logger.info("[sse] event=%s req=%s", event_name, request_id)
    else:
        logger.info("[sse] event=%s req=%s payload=%s", event_name, request_id, payload)


def _finalise_score(
    model_score: float,
    fingerprint_score: float,
    music_model_score: float,
) -> AnalysisResult:
    """Build the final AnalysisResult — ensemble weights live in one place.

    Three-model soft voting:
      * `model_score`       — Wav2Vec2 (speech-deepfake fine-tune, weak on music)
      * `music_model_score` — lofcz/ai-music-detector (music-native, strong on Suno/Udio)
      * `fingerprint_score` — spectral haze heuristic (broad, signal-of-last-resort)

    When the music model did not vote (its score is exactly 0.0 because it
    failed to load, the audio was too short, or inference raised), its
    weight is rebalanced to the other two, so the ensemble degrades
    gracefully to the previous 2-model behaviour.

    Two-of-three agreement produces a confident verdict; one model
    disagreeing with the other two pulls the score toward the middle,
    which is the "lower-confidence on dissent" behaviour the product
    spec calls for.
    """
    # Requiring an actual score (not just loaded weights) fixes short clips:
    # previously a loaded model's 0.0 "no signal" still took 50 % of the
    # weight, capping the combined score at 0.3 + 0.2 = 0.5, so no clip under
    # MUSIC_MODEL_MIN_DURATION could ever be flagged as AI.
    music_active = music_classifier_weights is not None and music_model_score > 0.0

    if music_active:
        # Full 3-model ensemble. Weights chosen so:
        #   * music model carries the most weight (it's the one
        #     trained on the actual product target)
        #   * wav2vec2 keeps meaningful weight on the speech path
        #     (mic recordings of voice / spoken word)
        #   * spectral haze provides the third opinion at a smaller
        #     weight (a heuristic, not a model)
        combined_score = (
            model_score * 0.30
            + music_model_score * 0.50
            + fingerprint_score * 0.20
        )
        strategy = "3-model voting (Wav2Vec2 + music-native + spectral haze)"
    else:
        # Music model offline or skipped for this clip: prior 2-model split.
        combined_score = (model_score * 0.7) + (fingerprint_score * 0.3)
        strategy = "Wav2Vec2 + Spectral Haze Analysis (music model offline)"

    return AnalysisResult(
        is_ai=combined_score > 0.5,
        confidence=float(combined_score),
        details={
            "wav2vec2_score": float(model_score),
            "music_model_score": float(music_model_score),
            "fingerprint_score": float(fingerprint_score),
            "ensemble_strategy": strategy,
        },
    )


async def _is_client_gone(request: Request | None, request_id: str) -> bool:
    """Return True if the SSE consumer has closed the socket.

    Tolerates `request=None` for callers that synthesise a generator
    without a Starlette request bound to it.
    """
    if request is None:
        return False
    try:
        if await request.is_disconnected():
            logger.info("[sse] cancelled req=%s", request_id)
            return True
    except Exception:
        # If the disconnect probe itself errors we err on the side of
        # continuing — the surrounding generator's normal write will
        # raise if the socket is genuinely dead.
        return False
    return False


class _DownloadTooLarge(Exception):
    """Raised by `_download_audio` when the body exceeds MAX_FILE_SIZE."""


def _download_audio(url: str) -> bytes:
    """Stream *url* into memory, aborting once MAX_FILE_SIZE is exceeded.

    Blocking — call through `asyncio.to_thread` from async code.

    NOTE(review): requests follows redirects by default and only the initial
    URL has been checked by `validate_audio_url`, so a redirect from an
    allowlisted host to another host is not re-validated. Confirm whether
    legitimate preview URLs redirect before tightening this.

    Raises:
        _DownloadTooLarge: the body is larger than MAX_FILE_SIZE.
        requests.exceptions.RequestException: connection/timeout/HTTP error.
    """
    buffer = bytearray()
    with requests.get(url, stream=True, timeout=10) as resp:
        resp.raise_for_status()
        for chunk in resp.iter_content(chunk_size=8192):
            buffer.extend(chunk)
            if len(buffer) > MAX_FILE_SIZE:
                raise _DownloadTooLarge()
    return bytes(buffer)


async def _stream_detection_stages(
    audio_data: np.ndarray,
    sr: int,
    request_id: str,
    request: Request | None,
):
    """Yield the post-decode SSE events shared by URL and file analysis.

    Emits a single duration `error` event when the clip is outside
    MIN..MAX_AUDIO_DURATION. Otherwise it emits `processing` events for the
    wav2vec2, music_model and fingerprint stages, then `result` and
    `complete`. Stops without further events once the client disconnects.
    """
    import json

    duration = len(audio_data) / sr
    if duration < MIN_AUDIO_DURATION or duration > MAX_AUDIO_DURATION:
        _log_sse(request_id, "error", {"reason": "duration_out_of_range", "duration": duration})
        yield {
            "event": "error",
            "data": json.dumps({"message": f"Audio duration {duration:.1f}s out of range ({MIN_AUDIO_DURATION}-{MAX_AUDIO_DURATION}s)"}),
        }
        return

    if await _is_client_gone(request, request_id):
        return
    _log_sse(request_id, "processing", {"stage": "wav2vec2"})
    yield {"event": "processing", "data": json.dumps({"stage": "wav2vec2"})}
    model_score = await asyncio.to_thread(_score_wav2vec2, audio_data)

    if await _is_client_gone(request, request_id):
        return
    _log_sse(request_id, "processing", {"stage": "music_model"})
    yield {"event": "processing", "data": json.dumps({"stage": "music_model"})}
    music_model_score = await asyncio.to_thread(_run_music_detector, audio_data, sr)

    if await _is_client_gone(request, request_id):
        return
    _log_sse(request_id, "processing", {"stage": "fingerprint"})
    yield {"event": "processing", "data": json.dumps({"stage": "fingerprint"})}
    fingerprint_score = await asyncio.to_thread(detect_frequency_fingerprints, audio_data, sr)

    result = _finalise_score(model_score, fingerprint_score, music_model_score)
    _log_sse(request_id, "result", {"is_ai": result.is_ai, "confidence": result.confidence})
    yield {"event": "result", "data": result.model_dump_json()}
    _log_sse(request_id, "complete")
    yield {"event": "complete", "data": "{}"}


async def _stream_url_analysis(url: str, request_id: str, request: Request | None = None):
    """Generator for SSE events during URL analysis.

    `validate_audio_url` runs on the first iteration, so a rejected URL
    raises HTTPException before any event is produced. The download and
    decode happen in worker threads; the inference stages come from
    `_stream_detection_stages`.
    """
    import json

    validate_audio_url(url)

    _log_sse(request_id, "processing", {"stage": "downloading"})
    yield {"event": "processing", "data": json.dumps({"stage": "downloading"})}
    if await _is_client_gone(request, request_id):
        return

    try:
        raw_bytes = await asyncio.to_thread(_download_audio, url)
    except _DownloadTooLarge:
        _log_sse(request_id, "error", {"reason": "file_too_large"})
        yield {"event": "error", "data": json.dumps({"message": "Audio file too large (exceeds 5MB)"})}
        return
    except requests.exceptions.RequestException as e:
        logger.error(f"[{request_id}] Download failed: {e}")
        _log_sse(request_id, "error", {"reason": "download_failed"})
        yield {"event": "error", "data": json.dumps({"message": "Could not download audio from source"})}
        return

    if await _is_client_gone(request, request_id):
        return

    _log_sse(request_id, "processing", {"stage": "loading"})
    yield {"event": "processing", "data": json.dumps({"stage": "loading"})}
    try:
        audio_data, sr = await asyncio.to_thread(_load_audio, data=raw_bytes, sr=16000)
    except Exception as e:
        logger.error(f"[{request_id}] Audio load failed: {e}")
        _log_sse(request_id, "error", {"reason": "decode_failed"})
        yield {"event": "error", "data": json.dumps({"message": "Failed to decode audio"})}
        return

    async for event in _stream_detection_stages(audio_data, sr, request_id, request):
        yield event


async def _stream_file_analysis(file_path: str, request_id: str, request: Request | None = None):
    """Generator for SSE events during file analysis.

    Decodes *file_path* in a worker thread, then delegates the inference
    stages to `_stream_detection_stages`.
    """
    import json

    _log_sse(request_id, "processing", {"stage": "loading"})
    yield {"event": "processing", "data": json.dumps({"stage": "loading"})}
    try:
        audio_data, sr = await asyncio.to_thread(_load_audio, path=file_path, sr=16000)
    except Exception as e:
        logger.error(f"[{request_id}] Audio load failed: {e}")
        _log_sse(request_id, "error", {"reason": "decode_failed"})
        yield {"event": "error", "data": json.dumps({"message": "Failed to decode audio"})}
        return

    async for event in _stream_detection_stages(audio_data, sr, request_id, request):
        yield event


# ---------------------------------------------------------------------------
# Common pre-processing for all analyze endpoints
# ---------------------------------------------------------------------------
async def _pre_process(request: Request) -> tuple[str, str]:
    """Charge the caller's daily quota and resolve request identifiers.

    Per-IP concurrency is NOT checked here; endpoints call
    `_acquire_ip_slot` separately.

    Returns:
        (ip, request_id). The request id comes from RequestIDMiddleware,
        or is freshly generated if the middleware did not set one.

    Raises:
        HTTPException: 429 from `_check_daily_limit` when the quota is spent.
    """
    ip = get_remote_address(request)
    request_id = getattr(request.state, "request_id", str(uuid.uuid4()))
    await _check_daily_limit(ip)
    return ip, request_id


# =============================================================================
# URL-based analysis (original endpoint — backward compatible)
# =============================================================================
@app.post("/analyze", response_model=AnalysisResult)
@limiter.limit("5/minute")
async def analyze_track(
    request: Request,
    body: AnalysisRequest,
    api_key: str = Depends(verify_api_key),
):
    """Analyse a remote preview URL and return the ensemble verdict.

    Order: daily quota → per-IP slot → URL validation → download (worker
    thread) → global inference slot → pipeline, deduplicated on the URL hash.
    """
    ip, request_id = await _pre_process(request)
    await _acquire_ip_slot(ip)
    try:
        validate_audio_url(body.preview_url)

        try:
            raw_bytes = await asyncio.to_thread(_download_audio, body.preview_url)
        except _DownloadTooLarge:
            raise HTTPException(status_code=400, detail="Audio file too large (exceeds 5MB)") from None
        except requests.exceptions.RequestException as e:
            logger.error(f"[{request_id}] Download failed: {e}")
            raise HTTPException(status_code=400, detail="Could not download audio from source") from e

        content_hash = _hash_url(body.preview_url)

        await _acquire_global_slot(request_id)
        try:
            result = await _deduplicated_run(
                content_hash,
                lambda: _run_sync_url(raw_bytes),
            )
        finally:
            _global_semaphore.release()

        return result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[{request_id}] Internal processing error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal processing error")
    finally:
        _release_ip_slot(ip)


# =============================================================================
# File upload analysis (original endpoint — backward compatible)
#
# =============================================================================
async def _read_upload_capped(file: UploadFile) -> bytes:
    """Read an upload, rejecting empty or oversized bodies.

    At most ``MAX_FILE_SIZE + 1`` bytes are read. An oversized upload is
    rejected without pulling the whole body into memory.

    Raises:
        HTTPException(400): the upload is larger than MAX_FILE_SIZE or empty.
    """
    contents = await file.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(status_code=400, detail="Audio file too large (exceeds 5MB)")
    if not contents:
        raise HTTPException(status_code=400, detail="Empty audio file")
    return contents


def _discard_temp_file(path: str, request_id: str) -> None:
    """Delete ``path`` (if it still exists) and stop tracking it in ``_temp_files``.

    A failed delete is logged, and the path stays tracked so the shutdown
    hook can retry it.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"[{request_id}] Failed to clean up temp file: {e}")
        return
    _temp_files.discard(path)


@app.post("/analyze/upload", response_model=AnalysisResult)
@limiter.limit("5/minute")
async def analyze_upload(
    request: Request,
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key),
):
    """Analyse an uploaded audio file and return the ensemble result.

    The upload is written to a temp file for decoding and deleted afterwards.
    Identical content is deduplicated via ``_deduplicated_run``.

    Raises:
        HTTPException(400): empty or oversized upload (plus whatever
            ``validate_upload_content_type`` raises).
        HTTPException(500): any unexpected processing error.
    """
    ip, request_id = await _pre_process(request)
    await _acquire_ip_slot(ip)
    tmp_path = None
    try:
        validate_upload_content_type(file.content_type)
        contents = await _read_upload_capped(file)

        suffix = os.path.splitext(file.filename or "recording.m4a")[1] or ".m4a"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path before writing, so a failed write is still
            # cleaned up by the `finally` below.
            tmp_path = tmp.name
            _temp_files.add(tmp_path)
            tmp.write(contents)

        content_hash = _hash_content(contents)
        await _acquire_global_slot(request_id)
        try:
            result = await _deduplicated_run(
                content_hash,
                lambda: _run_sync_file(tmp_path),
            )
        finally:
            _global_semaphore.release()
        return result
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[{request_id}] Upload processing error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal processing error")
    finally:
        _release_ip_slot(ip)
        if tmp_path:
            _discard_temp_file(tmp_path, request_id)


# =============================================================================
# Streaming SSE endpoints
# =============================================================================
def _global_queue_depth() -> int:
    """Best-effort count of tasks blocked waiting on ``_global_semaphore``.

    Reads the semaphore's private ``_waiters`` attribute (presumably a deque,
    or None before first contention on newer CPythons — confirm on upgrade).
    It may briefly include waiters that were just cancelled. Returns 0 when
    the attribute is unavailable.
    """
    waiters = getattr(_global_semaphore, "_waiters", None)
    return len(waiters) if waiters else 0


def _build_queued_event(request_id: str) -> dict[str, str] | None:
    """Return an SSE ``queued`` event if every global slot is taken, else None.

    ``_value`` bottoms out at 0 when saturated, so ``MAX - _value`` cannot
    express a queue position. The position is derived from the number of
    tasks already waiting, plus this one.
    """
    import json

    if _global_semaphore._value > 0:
        return None
    position = _global_queue_depth() + 1
    _log_sse(request_id, "queued", {"position": position})
    return {"event": "queued", "data": json.dumps({"position": position})}


@app.post("/analyze/stream")
@limiter.limit("5/minute")
async def analyze_stream(
    request: Request,
    body: AnalysisRequest,
    api_key: str = Depends(verify_api_key),
):
    """SSE variant of ``/analyze``: streams progress, result and completion events.

    The per-IP slot is taken before the response starts and released in the
    generator's ``finally``. Errors are reported as ``error`` events followed
    by ``complete``.
    """
    import json

    ip, request_id = await _pre_process(request)
    await _acquire_ip_slot(ip)

    async def event_generator():
        slot_acquired = False
        try:
            queued = _build_queued_event(request_id)
            if queued is not None:
                yield queued

            await _acquire_global_slot(request_id)
            slot_acquired = True

            async for event in _stream_url_analysis(
                body.preview_url, request_id, request=request
            ):
                yield event
        except HTTPException as exc:
            _log_sse(request_id, "error", {"http_detail": str(exc.detail)})
            yield {"event": "error", "data": json.dumps({"message": str(exc.detail)})}
            yield {"event": "complete", "data": "{}"}
        except Exception as e:
            logger.error(f"[{request_id}] Stream error: {e}", exc_info=True)
            _log_sse(request_id, "error", {"reason": "internal"})
            yield {"event": "error", "data": json.dumps({"message": "Internal processing error"})}
            yield {"event": "complete", "data": "{}"}
        finally:
            if slot_acquired:
                _global_semaphore.release()
            _release_ip_slot(ip)

    return EventSourceResponse(event_generator())


@app.post("/analyze/upload/stream")
@limiter.limit("5/minute")
async def analyze_upload_stream(
    request: Request,
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key),
):
    """SSE variant of ``/analyze/upload``.

    The upload is validated and written to a temp file before streaming
    starts. Validation failures are raised as normal HTTP errors. The event
    generator then owns the IP slot and the temp file, and releases both
    when it finishes.
    """
    import json

    ip, request_id = await _pre_process(request)
    await _acquire_ip_slot(ip)

    # Anything up to handing the temp file to the generator can fail: bad
    # content type, empty/oversized body, disk error, or cancellation. The
    # generator's `finally` only runs once streaming starts, so clean up here
    # on failure. Otherwise the per-IP slot and the temp file leak.
    tmp_path = None
    try:
        validate_upload_content_type(file.content_type)
        contents = await _read_upload_capped(file)

        suffix = os.path.splitext(file.filename or "recording.m4a")[1] or ".m4a"
        tmp_fd, tmp_path = tempfile.mkstemp(suffix=suffix)
        _temp_files.add(tmp_path)
        # fdopen takes ownership of the fd; the buffered write loops until
        # every byte is written (a bare os.write may write partially).
        with os.fdopen(tmp_fd, "wb") as fh:
            fh.write(contents)
    except BaseException:
        _release_ip_slot(ip)
        if tmp_path is not None:
            _discard_temp_file(tmp_path, request_id)
        raise

    async def event_generator():
        slot_acquired = False
        try:
            queued = _build_queued_event(request_id)
            if queued is not None:
                yield queued

            await _acquire_global_slot(request_id)
            slot_acquired = True

            async for event in _stream_file_analysis(
                tmp_path, request_id, request=request
            ):
                yield event
        except HTTPException as exc:
            _log_sse(request_id, "error", {"http_detail": str(exc.detail)})
            yield {"event": "error", "data": json.dumps({"message": str(exc.detail)})}
            yield {"event": "complete", "data": "{}"}
        except Exception as e:
            logger.error(f"[{request_id}] Stream error: {e}", exc_info=True)
            _log_sse(request_id, "error", {"reason": "internal"})
            yield {"event": "error", "data": json.dumps({"message": "Internal processing error"})}
            yield {"event": "complete", "data": "{}"}
        finally:
            if slot_acquired:
                _global_semaphore.release()
            _release_ip_slot(ip)
            _discard_temp_file(tmp_path, request_id)

    return EventSourceResponse(event_generator())


# =============================================================================
# Queue status
# =============================================================================
@app.get("/queue/status")
async def queue_status():
    """Report global inference load and a rough wait estimate.

    ``estimated_wait_seconds`` assumes ~15 s per inference on the free tier,
    multiplied by the number of requests currently waiting for a slot. The
    previous formula, ``max(0, active - MAX) * 15``, was always 0 because
    ``active`` can never exceed the cap.
    """
    active = MAX_GLOBAL_CONCURRENCY - _global_semaphore._value
    queued = _global_queue_depth()
    return {
        "active_requests": active,
        "queued_requests": queued,
        "max_concurrency": MAX_GLOBAL_CONCURRENCY,
        "estimated_wait_seconds": queued * 15,
    }


# =============================================================================
# Health endpoints
# =============================================================================
@app.get("/")
def root_health():
    """Backward-compatible root health check."""
    return {"status": "online", "model": MODEL_ID}


@app.get("/health")
async def health_check():
    """Detailed health check with model status, queue depth, and uptime.

    **Build 12: model warm-up on first call.** HF Spaces auto-suspend after
    48 h of idle. The cold-start that follows takes 30–60 s while the
    transformers pipeline loads weights from disk. Without a warm-up step
    the very first real `/analyze` request after wake hits the loading
    window with the client's keep-alive timer already running — exactly
    the failure mode reproduced in the forensic logs (48.8 h gap between
    2026-04-25 13:08 and 2026-04-27 13:59 followed by a stuck request).

    `/health` now triggers a single 0.05-s zero-input inference the first
    time after startup so an external warm-keeper cron can keep the
    model resident without ever touching `/analyze`.
    """
    global _model_warmed

    uptime = time.time() - _startup_time if _startup_time else 0
    active = MAX_GLOBAL_CONCURRENCY - _global_semaphore._value

    if classifier is not None and not _model_warmed:
        try:
            # 0.05 s of silence at 16 kHz — minimal payload, real forward pass.
            silence = np.zeros(int(16000 * 0.05), dtype=np.float32)
            await asyncio.to_thread(classifier, silence)
            _model_warmed = True
            logger.info("[health] model warmed via /health")
        except Exception as e:
            # Don't fail the health check if warm-up errors — the SDK
            # is still up; we just log and continue.
            logger.warning(f"[health] model warm-up failed: {e}")

    return {
        "status": "online",
        "model": MODEL_ID,
        "model_loaded": classifier is not None,
        "model_warmed": _model_warmed,
        "music_model": MUSIC_MODEL_REPO,
        "music_model_loaded": music_classifier_weights is not None,
        "active_requests": active,
        "max_concurrency": MAX_GLOBAL_CONCURRENCY,
        "uptime_seconds": round(uptime, 1),
    }


# =============================================================================
# Lifecycle events
# =============================================================================
@app.on_event("startup")
async def on_startup():
    """Log configuration and load the wav2vec2 pipeline and music detector (best-effort)."""
    global classifier, _startup_time, music_classifier_weights, music_classifier_bias
    _startup_time = time.time()

    logger.info("=== CheckAI Backend Starting ===")
    logger.info(f"Model: {MODEL_ID}")
    logger.info(f"AI labels: {sorted(AI_LABELS)}")
    logger.info(f"Global concurrency: {MAX_GLOBAL_CONCURRENCY}")
    logger.info(f"Daily limit per IP: {DAILY_LIMIT}")
    logger.info(f"Allowed origins: {ALLOWED_ORIGINS}")

    if classifier is None:
        try:
            logger.info(f"Loading model: {MODEL_ID}")
            classifier = pipeline("audio-classification", model=MODEL_ID)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")

    # Music-native detector. Treated as best-effort: if the download
    # fails (HF down, network blip on cold-start) the ensemble
    # gracefully falls back to the wav2vec2 + fingerprint pair via
    # `_finalise_score`. Logging the failure makes the degraded mode
    # visible in Console.app, the existing diagnostic surface.
    if music_classifier_weights is None:
        try:
            logger.info(f"Loading music detector: {MUSIC_MODEL_REPO}")
            music_classifier_weights, music_classifier_bias = (
                _load_music_classifier()
            )
            logger.info(
                "Music detector loaded · weights=%s · bias=%s",
                music_classifier_weights.shape,
                music_classifier_bias.shape,
            )
        except Exception as e:
            logger.error(f"Error loading music detector — falling back: {e}")

    logger.info("=== Startup complete ===")


@app.on_event("shutdown")
async def on_shutdown():
    """Best-effort removal of any temp files still tracked in ``_temp_files``."""
    logger.info("=== CheckAI Backend Shutting Down ===")
    cleaned = 0
    for path in list(_temp_files):
        if os.path.exists(path):
            try:
                os.remove(path)
                cleaned += 1
            except OSError:
                pass
    _temp_files.clear()
    logger.info(f"Cleaned up {cleaned} temp file(s)")
    logger.info("=== Shutdown complete ===")


# =============================================================================
# Entrypoint
# =============================================================================
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting server on port {port}")
    uvicorn.run(app, host="0.0.0.0", port=port)