""" Hugging Face Inference Endpoint用のカスタムハンドラー DeepSeek-OCR LoRAモデル用 """ from typing import Dict, List, Any from PIL import Image import io import base64 import torch from transformers import AutoModelForCausalLM, AutoTokenizer from peft import PeftModel class EndpointHandler: def __init__(self, path=""): """ Inference Endpointの初期化 Args: path: モデルのパス(自動的に設定される) """ # ベースモデル(DeepSeek-OCR)のロード base_model_name = "deepseek-ai/deepseek-vl-1.3b-chat" self.tokenizer = AutoTokenizer.from_pretrained( base_model_name, trust_remote_code=True ) # ベースモデルをロード base_model = AutoModelForCausalLM.from_pretrained( base_model_name, trust_remote_code=True, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 ) # LoRAアダプターを適用 self.model = PeftModel.from_pretrained( base_model, path, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 ) # 評価モードに設定 self.model.eval() # GPUが利用可能な場合は移動 if torch.cuda.is_available(): self.model = self.model.cuda() def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: """ 推論を実行 Args: data: 入力データ { "inputs": "base64エンコードされた画像文字列", "prompt": "カレンダーで丸印がついている日付を全て抽出してください。数字のみをカンマ区切りで出力してください。" } Returns: 推論結果 """ # 入力データを取得 inputs = data.pop("inputs", data) prompt = data.pop("prompt", "カレンダーで丸印がついている日付を全て抽出してください。数字のみをカンマ区切りで出力してください。") # Base64デコード if isinstance(inputs, str): if inputs.startswith("data:image"): # data:image/png;base64,... の形式 inputs = inputs.split(",")[1] image_bytes = base64.b64decode(inputs) image = Image.open(io.BytesIO(image_bytes)).convert("RGB") elif isinstance(inputs, dict) and "image" in inputs: image = Image.open(io.BytesIO(base64.b64decode(inputs["image"]))).convert("RGB") else: return [{"error": "Invalid input format"}] # 画像を処理 try: # モデルの入力形式に変換 conversation = [ { "role": "User", "content": f"\n{prompt}", "images": [image] }, { "role": "Assistant", "content": "" } ] # プロンプトを準備 prepare_inputs = self.model.prepare_inputs_for_generation( conversation, tokenizer=self.tokenizer ) # 推論実行 with torch.no_grad(): outputs = self.model.generate( **prepare_inputs, max_new_tokens=512, temperature=0.1, do_sample=False, pad_token_id=self.tokenizer.eos_token_id ) # 結果をデコード answer = self.tokenizer.decode( outputs[0][len(prepare_inputs["input_ids"][0]):], skip_special_tokens=True ) return [{"generated_text": answer.strip()}] except Exception as e: return [{"error": f"Inference error: {str(e)}"}]