File size: 2,724 Bytes
9429f10
9c26840
 
9429f10
 
 
ff6e4b2
9429f10
ff6e4b2
9429f10
 
 
 
9c26840
ff6e4b2
 
 
 
 
 
 
 
 
 
9c26840
ff6e4b2
 
9c26840
 
9429f10
ff6e4b2
 
 
 
 
9c26840
 
ff6e4b2
 
 
 
 
9429f10
 
 
 
ff6e4b2
9429f10
 
9c26840
ff6e4b2
9c26840
ff6e4b2
 
 
 
 
 
 
 
 
 
 
 
9429f10
 
 
ff6e4b2
9429f10
9c26840
 
 
9429f10
ff6e4b2
9c26840
 
ff6e4b2
9429f10
ff6e4b2
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import json
import re
import logging
import warnings
from llama_cpp import Llama
from huggingface_hub import hf_hub_download

# Keep third-party libraries quiet at import time of this module.
# NOTE(review): this env var is set after `llama_cpp` has already been imported
# above; confirm the bindings read it lazily, otherwise it has no effect.
os.environ.update(GGML_PYTHON_VERBOSE="0")
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)  # only errors from the hub client
warnings.filterwarnings(action="ignore", category=UserWarning)

class HealthFunctionLM:
    """
    A specialized wrapper for Llama-3.2 GGUF models to perform
    health-related function calling.

    The GGUF file is fetched with ``hf_hub_download`` and loaded with
    ``llama_cpp.Llama``. :meth:`query` prompts the model for a single JSON
    object shaped like ``{"name": ..., "parameters": {...}}`` and reports the
    reply either as a parsed function call or as plain text.
    """

    # Removes <think>...</think> sections (non-greedy, spanning newlines) so
    # braces inside such reasoning text are not mistaken for the JSON payload.
    _THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)

    def __init__(self, repo_id: str, filename: str, n_ctx: int = 2048):
        """
        Download the model file from the Hugging Face Hub and load it.

        Args:
            repo_id: Hub repository id that holds the GGUF file.
            filename: Name of the GGUF file inside that repository.
            n_ctx: Context window size passed to ``Llama``.
        """
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)

        self.llm = Llama(
            model_path=model_path,
            n_ctx=n_ctx,
            # os.cpu_count() may return None; fall back to 4 threads then.
            n_threads=os.cpu_count() or 4,
            verbose=False,
        )

    def _build_prompt(self, query: str) -> str:
        """Wrap *query* in an instruction asking for a bare JSON function call."""
        return (
            "You are an API generator. Return ONLY a JSON object.\n"
            "Format: {\"name\": \"function_name\", \"parameters\": {\"key\": \"value\"}}\n\n"
            f"User query: {query}\n\n"
            "JSON:"
        )

    def _extract_json(self, text: str):
        """
        Return the first JSON object embedded in *text*, or None.

        ``<think>...</think>`` sections are removed first. Decoding starts at
        the first "{" and stops where that object ends, so surrounding prose,
        markdown code fences, and trailing text (even text containing "}")
        do not prevent the object from being parsed.

        Returns:
            The decoded ``dict``, or None if there is no "{" or the text
            starting at the first "{" is not valid JSON.
        """
        text = self._THINK_RE.sub("", text)
        start = text.find("{")
        if start == -1:
            return None
        try:
            # raw_decode ignores anything after the decoded object, unlike
            # json.loads, which rejects trailing "extra data".
            obj, _ = json.JSONDecoder().raw_decode(text[start:])
        except json.JSONDecodeError:
            return None
        return obj

    def query(self, user_query: str, *, temperature: float = 0.1) -> dict:
        """
        Send *user_query* to the model and classify its reply.

        Args:
            user_query: Free-text user request.
            temperature: Sampling temperature passed to the model
                (default 0.1).

        Returns:
            ``{"query": user_query, "type": "function_call", "data": <dict>}``
            when the reply contains a non-empty JSON object; otherwise
            ``{"query": user_query, "type": "text", "data": {"content": <str>}}``
            with the stripped reply text.
        """
        prompt = self._build_prompt(user_query)
        output = self.llm.create_chat_completion(
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
        )

        message = output["choices"][0]["message"]
        # "content" can be present but None (typed Optional[str]), which a
        # plain .get("content", "") default would not catch before .strip().
        content = (message.get("content") or "").strip()

        function_data = self._extract_json(content)
        if function_data:
            return {
                "query": user_query,
                "type": "function_call",
                "data": function_data,
            }

        return {
            "query": user_query,
            "type": "text",
            "data": {"content": content},
        }

# --- Example Usage ---
if __name__ == "__main__":
    # Load the Q4_K_M GGUF build from the Hub and run a single sample query,
    # printing the classified result as pretty JSON.
    demo_lm = HealthFunctionLM(
        filename="Llama-3.2-3B-Instruct.Q4_K_M.gguf",
        repo_id="ramgovindv/health_function_call_llama3.2_3b_gguf",
    )

    sample_query = "I am feeling very dizzy for couple of days. what could be the reason"
    outcome = demo_lm.query(sample_query)
    print(json.dumps(outcome, indent=2))