"""Gemma 4 multimodal helpers: model loading, generation, and JSON extraction.

The module lazily loads the model named by ``MODEL_ID`` together with its
processor, exposes ``generate_response`` for prompt (+ optional image)
completion, and ``run_visual_extraction`` for structured shelf / invoice
extraction that always returns a dictionary.
"""

# Import environment loaders and deep learning frameworks
import os
import json

from dotenv import load_dotenv
from PIL import Image
import torch
from transformers import AutoModelForMultimodalLM, AutoProcessor

# Import spaces wrapper conditionally for Hugging Face ZeroGPU compatibility
try:
    import spaces

    has_spaces = True
except ImportError:
    has_spaces = False


def gpu_decorator(fn):
    """Decorator to conditionally apply spaces.GPU on Hugging Face Spaces.

    Returns ``spaces.GPU(fn)`` when the ``spaces`` package was importable,
    otherwise returns ``fn`` unchanged.
    """
    if has_spaces:
        return spaces.GPU(fn)
    return fn


# Load token variables from local environment file
load_dotenv()

# Specify the Hugging Face hub repository target for Gemma 4
MODEL_ID = "google/gemma-4-12B-it"

# Global cache elements that keep load state persistent across calls
_model = None
_processor = None
_device = None

# Per-device keyword arguments for ``from_pretrained``: bfloat16 on the
# accelerators, float32 pinned to the CPU otherwise.
_LOAD_OPTIONS = {
    "cuda": {"dtype": torch.bfloat16, "device_map": "auto"},
    "mps": {
        "dtype": torch.bfloat16,
        "low_cpu_mem_usage": True,
        "device_map": "auto",
    },
    "cpu": {
        "dtype": torch.float32,
        "low_cpu_mem_usage": True,
        "device_map": "cpu",
    },
}

_SHELF_PROMPT = """You are an expert retail inventory scanner. Analyze this photo of a supermarket shelf.
Identify each visible product package, detect the price written on the shelf label tag, and count the visible items on the shelf.
Return a JSON object with a single key "shelf_detections" containing a list of objects with these keys:
- "product_name_raw": The exact name of the product as seen on the package or shelf label.
- "visible_quantity_estimate": An integer of how many units are currently visible. 0 if it is an empty slot.
- "shelf_price_detected": The price on the shelf label as a decimal number (or null if not visible).
- "shelf_label_present": A boolean (true if a price label/tag is present under the item, false otherwise).
- "low_stock_signal": A boolean (true if the item is low on stock, i.e., 3 or fewer items left, or empty space).
- "empty_space_signal": A boolean (true if the slot is completely empty or out of stock).
- "confidence": A float between 0.0 and 1.0 representing your confidence in identifying this item.
- "evidence": A brief description of the item's location on the shelf (e.g. "top shelf, middle").
Return ONLY valid raw JSON. Do not write any conversational text before or after the JSON. Do not include markdown code block formatting (like ```json). Ensure the output is valid JSON."""

_INVOICE_PROMPT = """You are an expert accounting assistant. Analyze this wholesaler invoice/receipt.
Extract each line item representing a product delivery.
Return a JSON object with a single key "invoice_lines" containing a list of objects with these keys:
- "product_name_raw": The exact name of the product listed on the invoice.
- "quantity": The integer quantity delivered.
- "unit_cost": The unit cost as a decimal number.
- "total_cost": The total cost for this line item as a decimal number.
- "confidence": A float between 0.0 and 1.0 representing your confidence in this extraction.
Also include a key "invoice_meta" with:
- "supplier_name": The name of the wholesale supplier.
- "invoice_date": The date of the invoice (YYYY-MM-DD format if possible).
- "invoice_number": The invoice or receipt reference number.
Return ONLY valid raw JSON. Do not write any conversational text before or after the JSON. Do not include markdown code block formatting (like ```json). Ensure the output is valid JSON."""


def get_device() -> str:
    """Determine, cache, and return the best available device name.

    Returns:
        ``"cuda"`` when CUDA is available, else ``"mps"`` when the MPS
        backend is available, else ``"cpu"``.
    """
    global _device
    if _device is None:
        if torch.cuda.is_available():
            _device = "cuda"
        elif torch.backends.mps.is_available():
            _device = "mps"
        else:
            _device = "cpu"
    return _device


def load_gemma_model():
    """Load and cache the model and processor identified by ``MODEL_ID``.

    Returns:
        A ``(model, processor)`` tuple; subsequent calls return the cached
        objects without reloading.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset, empty, or still the template
            placeholder value.
    """
    global _model, _processor

    # Avoid reload if model and processor are already warm in memory
    if _model is not None and _processor is not None:
        return _model, _processor

    token = os.getenv("HF_TOKEN")
    if not token or token == "hf_YOUR_WRITE_TOKEN_HERE":
        raise ValueError(
            "HF_TOKEN not set or invalid in the .env file. "
            "Please add your Hugging Face write token."
        )

    device = get_device()
    print(f"Loading {MODEL_ID} on device: {device}...")

    # Anything that is not a known accelerator uses the CPU options.
    load_options = _LOAD_OPTIONS.get(device, _LOAD_OPTIONS["cpu"])

    # Load into locals first so a failure never leaves the cache half-filled.
    processor = AutoProcessor.from_pretrained(MODEL_ID, token=token)
    model = AutoModelForMultimodalLM.from_pretrained(
        MODEL_ID, token=token, **load_options
    )
    _processor, _model = processor, model

    print("Model loaded successfully.")
    return _model, _processor


@gpu_decorator
def generate_response(
    prompt: str,
    image: Image.Image | None = None,
    max_new_tokens: int = 1024,
    temperature: float = 0.4,
) -> str:
    """Generate a text response for a prompt and an optional image.

    Args:
        prompt: User text placed in a single ``user`` chat turn.
        image: Optional image; when given it is placed before the text in
            the chat turn and passed to the processor.
        max_new_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature. Values ``<= 0`` disable sampling.

    Returns:
        The decoded completion (prompt tokens removed), stripped of
        surrounding whitespace.
    """
    model, processor = load_gemma_model()

    # Build the single user turn: image first (if any), then the text.
    content = []
    if image is not None:
        content.append({"type": "image", "image": image})
    content.append({"type": "text", "text": prompt})
    messages = [{"role": "user", "content": content}]

    # Format user inputs into the model's native chat syntax
    text_prompt = processor.apply_chat_template(
        messages, add_generation_prompt=True
    )

    processor_kwargs = {"text": text_prompt, "return_tensors": "pt"}
    if image is not None:
        processor_kwargs["images"] = image
    inputs = processor(**processor_kwargs)

    # The model is loaded with a device_map, so follow the model's own device
    # rather than the coarse "cuda"/"mps"/"cpu" string.
    target_device = model.device
    inputs = {name: value.to(target_device) for name, value in inputs.items()}

    # Feature tensors must match the dtype of the model weights.
    for feature_key in ("pixel_values", "audio_values"):
        if feature_key in inputs:
            inputs[feature_key] = inputs[feature_key].to(model.dtype)

    # Only forward `temperature` when sampling; it is a sampling-only flag.
    do_sample = temperature > 0.0
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
    }
    if do_sample:
        generation_kwargs["temperature"] = temperature

    with torch.no_grad():
        generated_ids = model.generate(**inputs, **generation_kwargs)

    # Drop the prompt tokens so only the completion is decoded.
    input_len = inputs["input_ids"].shape[1]
    response_ids = generated_ids[0][input_len:]
    response_text = processor.decode(response_ids, skip_special_tokens=True)
    return response_text.strip()


def _extract_json_text(response_text: str) -> str:
    """Isolate the outermost ``{...}`` span of a model response.

    Slicing from the first ``{`` to the last ``}`` discards markdown fences
    and any conversational text around the JSON object. When no such span
    exists the stripped text is returned unchanged.
    """
    cleaned = response_text.strip()
    start = cleaned.find("{")
    end = cleaned.rfind("}")
    if start != -1 and end > start:
        return cleaned[start : end + 1]
    return cleaned


def _empty_extraction(doc_type: str) -> dict:
    """Return a fresh empty result structure for the given document type."""
    if doc_type == "shelf":
        return {"shelf_detections": []}
    return {
        "invoice_lines": [],
        "invoice_meta": {
            "supplier_name": "Unknown",
            "invoice_date": "",
            "invoice_number": "",
        },
    }


def run_visual_extraction(image, doc_type: str) -> dict:
    """Analyze a shelf or invoice photo and return structured extraction.

    Args:
        image: Image forwarded to ``generate_response``.
        doc_type: Either ``'shelf'`` or ``'invoice'``.

    Returns:
        The parsed JSON object produced by the model. If generation fails,
        the response is not valid JSON, or the JSON is not an object, an
        empty structure for ``doc_type`` is returned instead.

    Raises:
        ValueError: If ``doc_type`` is neither ``'shelf'`` nor ``'invoice'``.
    """
    if doc_type == "shelf":
        prompt = _SHELF_PROMPT
    elif doc_type == "invoice":
        prompt = _INVOICE_PROMPT
    else:
        raise ValueError("Invalid doc_type. Must be 'shelf' or 'invoice'.")

    # Deliberate best-effort boundary: any failure yields the empty structure.
    try:
        response_text = generate_response(
            prompt, image=image, max_new_tokens=1536
        )
        data = json.loads(_extract_json_text(response_text))
        if not isinstance(data, dict):
            raise ValueError(
                f"expected a JSON object, got {type(data).__name__}"
            )
        return data
    except Exception as e:
        print(f"Error during VLM extraction or JSON parsing: {e}")
        return _empty_extraction(doc_type)


# Test loading mechanism when running the script directly
if __name__ == "__main__":
    try:
        # Load the model and print the resolved hardware configuration
        model, processor = load_gemma_model()
        print("Device map:", getattr(model, "hf_device_map", "CPU"))
    except Exception as e:
        # Print failure trace
        print("Error during test loading:", e)