import { useRef, useState, useCallback, createElement, type ReactNode } from "react";
import {
  AutoProcessor,
  Gemma4ForConditionalGeneration,
  TextStreamer,
  load_image,
} from "@huggingface/transformers";
import { LLMContext, type LoadingStatus, type DetectionResult } from "./LLMContext";
import { parseJsonDetection } from "../utils/detection-parser";

const MODEL_ID = "onnx-community/gemma-4-E2B-it-ONNX";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type AnyModel = any;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type AnyProcessor = any;

/** Quantization variants this provider can request for the model weights. */
type Dtype = "q4f16" | "q4";

/**
 * Probes the WebGPU adapter and picks the dtype to load the model with.
 *
 * @returns `"q4f16"` when the adapter reports the `shader-f16` feature;
 *   `"q4"` when it does not, when no adapter is returned, or when the probe
 *   throws (the failure is logged as a warning, never rethrown).
 */
async function pickDtype(): Promise<Dtype> {
  try {
    const adapter = await navigator.gpu?.requestAdapter({
      powerPreference: "high-performance",
    });
    if (adapter) {
      const features = [...adapter.features];
      console.log("[Gemma4] WebGPU adapter features:", features);
      console.log("[Gemma4] WebGPU adapter limits:", adapter.limits);
      if (adapter.features.has("shader-f16")) {
        console.log("[Gemma4] shader-f16 is AVAILABLE — using q4f16 (fast)");
        return "q4f16";
      }
      console.log("[Gemma4] shader-f16 NOT available — falling back to q4 (slow)");
    }
  } catch (err) {
    console.warn("[Gemma4] Adapter probe failed:", err);
  }
  return "q4";
}

/**
 * Context provider that owns the model/processor instances and exposes
 * loading state, generation state and the `loadModel` / `detect` / `stop`
 * actions through `LLMContext`.
 *
 * @param children - Subtree rendered inside the provider.
 */
export function LLMProvider({ children }: { children: ReactNode }) {
  // Holds the in-flight (or settled) load promise so loadModel() runs once
  // and detect() can simply await it.
  const modelRef = useRef<Promise<AnyModel> | null>(null);
  const processorRef = useRef<AnyProcessor | null>(null);
  const [status, setStatus] = useState<LoadingStatus>({ state: "idle" });
  const [isGenerating, setIsGenerating] = useState(false);
  // Synchronous re-entrancy guard for detect(). It is written only inside
  // detect() (not mirrored from state during render) so two calls issued
  // before the next render cannot both pass the check.
  const isGeneratingRef = useRef(false);
  const [tps, setTps] = useState(0);
  const [dtype, setDtype] = useState<Dtype | null>(null);
  const [result, setResult] = useState<DetectionResult | null>(null);
  const abortRef = useRef(false);

  /**
   * Starts loading the processor and the model. Does nothing if a load is
   * already in flight or has completed. Progress and failures are reported
   * through `status`; after a failure the stored promise is cleared so the
   * call can be retried.
   */
  const loadModel = useCallback(() => {
    if (modelRef.current) return;

    const pending: Promise<AnyModel> = (async () => {
      setStatus({ state: "loading", message: "Loading processor..." });
      try {
        const pickedDtype = await pickDtype();
        setDtype(pickedDtype);
        console.log(`[Gemma4] Using dtype: ${pickedDtype}`);

        const processor = await AutoProcessor.from_pretrained(MODEL_ID);
        processorRef.current = processor;

        setStatus({ state: "loading", message: "Downloading model..." });
        const model = await Gemma4ForConditionalGeneration.from_pretrained(MODEL_ID, {
          dtype: pickedDtype,
          device: "webgpu",
          progress_callback: (p: { status: string; progress?: number }) => {
            // Use aggregated progress_total event for smooth, monotonic progress.
            if (p.status === "progress_total" && typeof p.progress === "number") {
              setStatus({
                state: "loading",
                progress: p.progress,
                message: "Loading model...",
              });
            }
          },
        });

        setStatus({ state: "ready" });
        return model;
      } catch (err) {
        console.error("[Gemma4] Model loading failed:", err);
        const msg = err instanceof Error ? err.message : String(err);
        setStatus({ state: "error", error: msg });
        modelRef.current = null;
        // Rethrow so a detect() that is already awaiting this promise fails too.
        throw err;
      }
    })();

    // The failure has already been logged and published via `status`. This
    // no-op handler only keeps a rejected load that nobody awaits from being
    // reported as an unhandled promise rejection; `pending` itself still rejects.
    void pending.catch(() => undefined);
    modelRef.current = pending;
  }, []);

  /**
   * Runs one detection pass over an image and streams the partial output
   * into `result` as tokens arrive.
   *
   * Returns early (with a console warning) when `loadModel()` has not been
   * called or when another detection is still running. Errors are not
   * thrown to the caller; they are written into `result` as an error text
   * with an empty detection list.
   *
   * @param imageUrl - Location of the image, passed to `load_image`.
   * @param prompt - Text instruction sent alongside the image.
   */
  const detect = useCallback(async (imageUrl: string, prompt: string) => {
    const pendingModel = modelRef.current;
    if (!pendingModel) {
      console.warn("[Gemma4] detect() called before model loaded");
      return;
    }
    if (isGeneratingRef.current) {
      console.warn("[Gemma4] detect() called while already generating");
      return;
    }
    // Claim the slot synchronously, before the first await.
    isGeneratingRef.current = true;

    console.log("[Gemma4] Starting detection...");
    setIsGenerating(true);
    setTps(0);
    setResult(null);
    abortRef.current = false;

    try {
      const model = await pendingModel;
      const processor = processorRef.current;
      if (!processor) {
        throw new Error("Processor is not loaded");
      }

      // Build multimodal message
      const messages = [
        {
          role: "user",
          content: [
            { type: "image" as const },
            { type: "text" as const, text: prompt },
          ],
        },
      ];

      // Apply chat template
      const text = processor.apply_chat_template(messages, {
        enable_thinking: false,
        add_generation_prompt: true,
      });

      // Load and process image
      const image = await load_image(imageUrl);
      const inputs = await processor(text, image, null, {
        add_special_tokens: false,
      });

      // Stream generation
      let outputText = "";
      let tokenCount = 0;
      let firstTokenTime = 0;

      const streamer = new TextStreamer(processor.tokenizer, {
        skip_prompt: true,
        skip_special_tokens: true,
        callback_function: (token: string) => {
          // After stop(), further text is dropped instead of being published.
          if (abortRef.current) return;
          outputText += token;
          const detections = parseJsonDetection(outputText);
          setResult({ text: outputText, detections });
        },
        token_callback_function: () => {
          tokenCount++;
          if (tokenCount === 1) {
            // The clock starts at the first token, so the rate below is
            // computed over the tokens that follow it.
            firstTokenTime = performance.now();
          } else {
            const elapsed = (performance.now() - firstTokenTime) / 1000;
            if (elapsed > 0) {
              // Rounded to one decimal place.
              setTps(Math.round(((tokenCount - 1) / elapsed) * 10) / 10);
            }
          }
        },
      });

      await model.generate({
        ...inputs,
        max_new_tokens: 128,
        do_sample: true,
        temperature: 1.0,
        top_p: 0.95,
        top_k: 64,
        streamer,
      });

      // Final parse
      const detections = parseJsonDetection(outputText);
      setResult({ text: outputText, detections });
      console.log(`[Gemma4] Detection done. ${detections.length} objects found.`);
    } catch (err) {
      console.error("[Gemma4] Detection error:", err);
      const rawMsg = err instanceof Error ? err.message : String(err);
      // Substrings this code treats as signs of GPU memory exhaustion.
      const isOOM =
        rawMsg.includes("OUT_OF_DEVICE_MEMORY") ||
        rawMsg.includes("Invalid Buffer") ||
        rawMsg.includes("Softmax");
      const friendly = isOOM
        ? "GPU ran out of memory. Enable chrome://flags/#enable-unsafe-webgpu and chrome://flags/#enable-vulkan, restart Chrome, and close other GPU-heavy tabs."
        : rawMsg;
      setResult({ text: `Error: ${friendly}`, detections: [] });
    } finally {
      // Always release the guard, whichever path ended the run.
      isGeneratingRef.current = false;
      setIsGenerating(false);
    }
  }, []);

  /**
   * Requests that the current detection stop publishing streamed text.
   * This only sets a flag read by the streamer callback; the `generate`
   * call itself is not interrupted here.
   */
  const stop = useCallback(() => {
    abortRef.current = true;
  }, []);

  // createElement is used instead of JSX so this module compiles as either
  // .ts or .tsx.
  // NOTE(review): the key set of the context value is reconstructed from the
  // state and callbacks defined above — confirm it matches the value type
  // declared in ./LLMContext.
  return createElement(
    LLMContext.Provider,
    {
      value: {
        status,
        isGenerating,
        tps,
        dtype,
        result,
        loadModel,
        detect,
        stop,
      },
    },
    children,
  );
}