takumi123xxx's picture
Upload folder using huggingface_hub
effa90f verified
raw
history blame
4.2 kB
"""
Hugging Face Inference Endpoint用のカスタムハンドラー
DeepSeek-OCR LoRAモデル用
"""
from typing import Dict, List, Any
from PIL import Image
import io
import base64
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
class EndpointHandler:
def __init__(self, path=""):
"""
Inference Endpointの初期化
Args:
path: モデルのパス(自動的に設定される)
"""
# ベースモデル(DeepSeek-OCR)のロード
base_model_name = "deepseek-ai/deepseek-vl-1.3b-chat"
self.tokenizer = AutoTokenizer.from_pretrained(
base_model_name,
trust_remote_code=True
)
# ベースモデルをロード
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
trust_remote_code=True,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
# LoRAアダプターを適用
self.model = PeftModel.from_pretrained(
base_model,
path,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
# 評価モードに設定
self.model.eval()
# GPUが利用可能な場合は移動
if torch.cuda.is_available():
self.model = self.model.cuda()
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
推論を実行
Args:
data: 入力データ
{
"inputs": "base64エンコードされた画像文字列",
"prompt": "カレンダーで丸印がついている日付を全て抽出してください。数字のみをカンマ区切りで出力してください。"
}
Returns:
推論結果
"""
# 入力データを取得
inputs = data.pop("inputs", data)
prompt = data.pop("prompt", "カレンダーで丸印がついている日付を全て抽出してください。数字のみをカンマ区切りで出力してください。")
# Base64デコード
if isinstance(inputs, str):
if inputs.startswith("data:image"):
# data:image/png;base64,... の形式
inputs = inputs.split(",")[1]
image_bytes = base64.b64decode(inputs)
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
elif isinstance(inputs, dict) and "image" in inputs:
image = Image.open(io.BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")
else:
return [{"error": "Invalid input format"}]
# 画像を処理
try:
# モデルの入力形式に変換
conversation = [
{
"role": "User",
"content": f"<image>\n{prompt}",
"images": [image]
},
{
"role": "Assistant",
"content": ""
}
]
# プロンプトを準備
prepare_inputs = self.model.prepare_inputs_for_generation(
conversation,
tokenizer=self.tokenizer
)
# 推論実行
with torch.no_grad():
outputs = self.model.generate(
**prepare_inputs,
max_new_tokens=512,
temperature=0.1,
do_sample=False,
pad_token_id=self.tokenizer.eos_token_id
)
# 結果をデコード
answer = self.tokenizer.decode(
outputs[0][len(prepare_inputs["input_ids"][0]):],
skip_special_tokens=True
)
return [{"generated_text": answer.strip()}]
except Exception as e:
return [{"error": f"Inference error: {str(e)}"}]