"""Gemma 4 model loading layer.

Lazily loads the Gemma 4 12B instruction-tuned multimodal model and its
processor from the Hugging Face Hub, caching both at module level so later
calls reuse the same objects instead of re-initializing them.
"""

import os

import torch
from transformers import AutoModelForMultimodalLM, AutoProcessor

# Specify the target gated Hugging Face repository ID
MODEL_ID = "google/gemma-4-12B-it"

# Token values that are placeholders rather than real credentials: the
# template default, and the example value shown in the local error message.
_PLACEHOLDER_TOKENS = frozenset({"hf_YOUR_WRITE_TOKEN_HERE", "hf_your_write_token"})

# Per-device keyword arguments for ``from_pretrained``: bfloat16 on CUDA/MPS,
# float32 on CPU (same settings as the original per-device branches).
_LOAD_KWARGS = {
    "cuda": {"dtype": torch.bfloat16, "device_map": "auto"},
    "mps": {"dtype": torch.bfloat16, "low_cpu_mem_usage": True, "device_map": "auto"},
    "cpu": {"dtype": torch.float32, "low_cpu_mem_usage": True, "device_map": "cpu"},
}

# Module-level caches so the device, model and processor are resolved only once.
_model = None
_processor = None
_device = None


def get_device() -> str:
    """Return the preferred torch device name, caching it after the first call.

    Prefers ``"cuda"``, then Apple ``"mps"``, and falls back to ``"cpu"``.
    """
    global _device

    if _device is not None:
        return _device

    if torch.cuda.is_available():
        _device = "cuda"
    elif torch.backends.mps.is_available():
        _device = "mps"
    else:
        _device = "cpu"
    return _device


def load_gemma_model():
    """Load and cache the Gemma 4 12B-it model and processor.

    The first successful call loads both objects from ``MODEL_ID``,
    authenticating with the ``HF_TOKEN`` environment variable. Later calls
    return the cached pair without reloading.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset, blank, or a known placeholder.
            The message depends on ``config.IS_SPACES``.
    """
    global _model, _processor

    if _model is not None and _processor is not None:
        return _model, _processor

    # Strip surrounding whitespace: a secret pasted with a trailing newline or
    # spaces would otherwise pass this check and fail authentication later.
    token = (os.getenv("HF_TOKEN") or "").strip()
    if not token or token in _PLACEHOLDER_TOKENS:
        # Deferred import: ``config`` is only consulted on this error path.
        from config import IS_SPACES

        if IS_SPACES:
            raise ValueError(
                "HF_TOKEN environment variable is missing or invalid in the Space. "
                "Please add a Repository Secret named 'HF_TOKEN' containing your Hugging Face "
                "write token in your Space Settings (Settings -> Variables and secrets -> New secret)."
            )
        raise ValueError(
            "HF_TOKEN is not set or holds a default placeholder in your .env file. "
            "Please open the .env file and set: HF_TOKEN=hf_your_write_token"
        )

    device = get_device()
    print(f"Loading model {MODEL_ID} on device target: {device}...")

    # Load into locals first and publish to the module cache only once both
    # succeed, so a failed model load never leaves a half-populated cache.
    processor = AutoProcessor.from_pretrained(MODEL_ID, token=token)
    model = AutoModelForMultimodalLM.from_pretrained(
        MODEL_ID,
        token=token,
        **_LOAD_KWARGS[device],
    )

    _model, _processor = model, processor
    print("Model initialized successfully.")
    return _model, _processor