File size: 4,202 Bytes
effa90f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
Hugging Face Inference Endpoint用のカスタムハンドラー
DeepSeek-OCR LoRAモデル用
"""
from typing import Dict, List, Any
from PIL import Image
import io
import base64
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

class EndpointHandler:
    def __init__(self, path=""):
        """
        Inference Endpointの初期化
        
        Args:
            path: モデルのパス(自動的に設定される)
        """
        # ベースモデル(DeepSeek-OCR)のロード
        base_model_name = "deepseek-ai/deepseek-vl-1.3b-chat"
        
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_name,
            trust_remote_code=True
        )
        
        # ベースモデルをロード
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )
        
        # LoRAアダプターを適用
        self.model = PeftModel.from_pretrained(
            base_model,
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )
        
        # 評価モードに設定
        self.model.eval()
        
        # GPUが利用可能な場合は移動
        if torch.cuda.is_available():
            self.model = self.model.cuda()
    
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        推論を実行
        
        Args:
            data: 入力データ
                {
                    "inputs": "base64エンコードされた画像文字列",
                    "prompt": "カレンダーで丸印がついている日付を全て抽出してください。数字のみをカンマ区切りで出力してください。"
                }
        
        Returns:
            推論結果
        """
        # 入力データを取得
        inputs = data.pop("inputs", data)
        prompt = data.pop("prompt", "カレンダーで丸印がついている日付を全て抽出してください。数字のみをカンマ区切りで出力してください。")
        
        # Base64デコード
        if isinstance(inputs, str):
            if inputs.startswith("data:image"):
                # data:image/png;base64,... の形式
                inputs = inputs.split(",")[1]
            
            image_bytes = base64.b64decode(inputs)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        elif isinstance(inputs, dict) and "image" in inputs:
            image = Image.open(io.BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")
        else:
            return [{"error": "Invalid input format"}]
        
        # 画像を処理
        try:
            # モデルの入力形式に変換
            conversation = [
                {
                    "role": "User",
                    "content": f"<image>\n{prompt}",
                    "images": [image]
                },
                {
                    "role": "Assistant",
                    "content": ""
                }
            ]
            
            # プロンプトを準備
            prepare_inputs = self.model.prepare_inputs_for_generation(
                conversation,
                tokenizer=self.tokenizer
            )
            
            # 推論実行
            with torch.no_grad():
                outputs = self.model.generate(
                    **prepare_inputs,
                    max_new_tokens=512,
                    temperature=0.1,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            
            # 結果をデコード
            answer = self.tokenizer.decode(
                outputs[0][len(prepare_inputs["input_ids"][0]):],
                skip_special_tokens=True
            )
            
            return [{"generated_text": answer.strip()}]
            
        except Exception as e:
            return [{"error": f"Inference error: {str(e)}"}]