"""
Standalone inference for Ukiyo-e Haiku VLM.

Loads SigLIP (vision) + custom projector + Qwen2.5-3B-Instruct + LoRA adapter
and generates a haiku for an input image.

Usage:
    python inference.py path/to/image.jpg
    python inference.py path/to/image.jpg --device cpu   # if no CUDA
    python inference.py path/to/image.jpg --qlora        # 4-bit base for low-VRAM GPUs
"""

import argparse
from pathlib import Path

import torch
import torch.nn as nn
from peft import LoraConfig, PeftModel, get_peft_model
from PIL import Image
from transformers import AutoModel, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

VISION_NAME = "google/siglip-base-patch16-224"
LM_NAME = "Qwen/Qwen2.5-3B-Instruct"
IMG_TOKEN = "<|image|>"
N_IMG_TOKENS = 196  # 14x14 patches at patch_size=16, image_size=224


class HaikuVLM(nn.Module):
    """SigLIP (frozen) + 2-layer MLP projector + Qwen2.5-3B + LoRA."""

    def __init__(self, use_qlora: bool = False) -> None:
        super().__init__()
        # Only SigLIP's vision tower is kept; it is frozen and in eval mode.
        self.vision = AutoModel.from_pretrained(VISION_NAME).vision_model
        for p in self.vision.parameters():
            p.requires_grad = False
        self.vision.eval()

        if use_qlora:
            from transformers import BitsAndBytesConfig

            bnb = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            lm = AutoModelForCausalLM.from_pretrained(LM_NAME, quantization_config=bnb, torch_dtype=torch.bfloat16)
        else:
            lm = AutoModelForCausalLM.from_pretrained(LM_NAME, torch_dtype=torch.bfloat16)
        for p in lm.parameters():
            p.requires_grad = False

        # The saved adapter is later loaded into this "default" adapter (see
        # `load`), so r / target_modules presumably must match the training
        # config -- confirm against the training script if changing them.
        lora_cfg = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            task_type="CAUSAL_LM",
        )
        self.lm = get_peft_model(lm, lora_cfg)

        v_dim = self.vision.config.hidden_size
        l_dim = self.lm.config.hidden_size
        self.projector = nn.Sequential(
            nn.Linear(v_dim, l_dim),
            nn.GELU(),
            nn.Linear(l_dim, l_dim),
        )
        # Id of IMG_TOKEN in the extended tokenizer; assigned by the caller.
        self.img_token_id: int | None = None

    def encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Project vision-encoder features into the LM embedding space.

        Runs the frozen vision tower without gradients, feeds its
        ``last_hidden_state`` through the projector, and returns the result
        cast to the dtype of the LM's input embeddings so it can be written
        directly into them. The last dimension of the output is the LM hidden
        size.
        """
        with torch.no_grad():
            feats = self.vision(pixel_values=pixel_values).last_hidden_state
        # Cast to whatever precision the projector is in (not only float32).
        feats = feats.to(self.projector[0].weight.dtype)
        embed_dtype = self.lm.get_input_embeddings().weight.dtype
        return self.projector(feats).to(embed_dtype)

    def load(self, in_dir: str) -> None:
        """Load the LoRA adapter and ``projector.pt`` state dict from ``in_dir``."""
        self.lm.load_adapter(in_dir, adapter_name="default")
        # The projector file is a plain state dict, so refuse arbitrary pickled
        # objects: the checkpoint directory may come from an untrusted source.
        state = torch.load(Path(in_dir) / "projector.pt", map_location="cpu", weights_only=True)
        self.projector.load_state_dict(state)


def build_tokenizer_with_image_token():
    """Return ``(tokenizer, img_token_id)`` with IMG_TOKEN added as a special token."""
    tok = AutoTokenizer.from_pretrained(LM_NAME)
    tok.add_tokens([IMG_TOKEN], special_tokens=True)
    img_token_id = tok.convert_tokens_to_ids(IMG_TOKEN)
    return tok, img_token_id


@torch.no_grad()
def generate_haiku(model: HaikuVLM, tokenizer, processor, image_path: str, device: str, prompt: str,
                   max_new_tokens: int = 80) -> str:
    """Generate a haiku for the image at ``image_path`` using greedy decoding.

    The user message is N_IMG_TOKENS copies of IMG_TOKEN followed by
    ``prompt``. After embedding, every IMG_TOKEN position is overwritten with
    one projected image embedding, and the LM generates from those embeddings.

    Raises:
        ValueError: if ``model.img_token_id`` is unset, or if the number of
            image placeholder tokens in the tokenized prompt differs from the
            number of image embeddings produced by ``model.encode_image``.
    """
    if model.img_token_id is None:
        raise ValueError("model.img_token_id is not set; assign it from build_tokenizer_with_image_token()")

    # Context manager closes the file handle that Image.open keeps open.
    with Image.open(image_path) as im:
        img = im.convert("RGB")
    pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(device)

    img_ph = IMG_TOKEN * N_IMG_TOKENS
    messages = [{"role": "user", "content": f"{img_ph}{prompt}"}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)

    text_embeds = model.lm.get_input_embeddings()(input_ids)
    img_embeds = model.encode_image(pixel_values)
    flat_img = img_embeds.reshape(-1, img_embeds.shape[-1])

    img_mask = input_ids == model.img_token_id
    n_slots = int(img_mask.sum())
    if n_slots != flat_img.shape[0]:
        raise ValueError(
            f"prompt has {n_slots} image placeholder tokens but the vision encoder "
            f"produced {flat_img.shape[0]} embeddings"
        )
    text_embeds[img_mask] = flat_img.to(text_embeds.dtype)

    out = model.lm.generate(
        inputs_embeds=text_embeds,
        # Single unpadded sequence: every position is attended to.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # NOTE(review): assumes generate() called with inputs_embeds only returns
    # the newly generated tokens (no prompt echo) -- confirm on the pinned
    # transformers version.
    return tokenizer.decode(out[0], skip_special_tokens=True).strip()


def main():
    """Parse CLI arguments, load the model + checkpoint, and print a haiku."""
    p = argparse.ArgumentParser()
    p.add_argument("image", help="Path to image file")
    p.add_argument("--ckpt", default=".", help="Path to checkpoint dir (default: current dir)")
    p.add_argument("--lang", default="en", choices=["en", "jp"])
    p.add_argument("--device", default=None)
    p.add_argument("--qlora", action="store_true", help="Use 4-bit base for low-VRAM GPUs (e.g. GTX 1650)")
    args = p.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    if args.lang == "jp":
        prompt = "この浮世絵を5-7-5の俳句で描写してください。"
    else:
        prompt = "Describe this ukiyo-e print as a 5-7-5 haiku."

    print("Loading model (this can take a couple of minutes the first time)...")
    tok, img_token_id = build_tokenizer_with_image_token()
    proc = AutoProcessor.from_pretrained(VISION_NAME)
    # NOTE(review): with --qlora the 4-bit LM is placed on a device by
    # transformers at load time; confirm that moving it again via .to() is
    # supported by the installed bitsandbytes version.
    model = HaikuVLM(use_qlora=args.qlora).to(device)
    # Grow the embedding matrix so the newly added IMG_TOKEN id is valid.
    model.lm.resize_token_embeddings(len(tok))
    model.img_token_id = img_token_id
    model.load(args.ckpt)
    model.eval()

    print(f"\nImage: {args.image}")
    print(f"Prompt: {prompt}\n")
    print(generate_haiku(model, tok, proc, args.image, device=device, prompt=prompt))


if __name__ == "__main__":
    main()