# Author: ramgovindv — commit "Update inference.py" (ff6e4b2, verified)
import os
import json
import re
import logging
import warnings
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
# Keep third-party chatter off the console: drop every UserWarning, let
# huggingface_hub log only at ERROR level or above, and set
# GGML_PYTHON_VERBOSE=0 in this process's environment.
warnings.simplefilter("ignore", category=UserWarning)
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
os.environ.update({"GGML_PYTHON_VERBOSE": "0"})
class HealthFunctionLM:
    """
    A specialized wrapper for Llama-3.2 GGUF models to perform
    health-related function calling.

    The model is prompted to reply with a single JSON object of the form
    ``{"name": "function_name", "parameters": {...}}``. ``query`` returns
    that parsed object when one can be recovered from the reply
    (``type == "function_call"``), otherwise the raw reply text
    (``type == "text"``).
    """

    def __init__(self, repo_id: str, filename: str, n_ctx: int = 2048):
        """
        Download a GGUF model file from the Hugging Face Hub and load it.

        Args:
            repo_id: Hub repository that contains the GGUF file.
            filename: Name of the GGUF file inside ``repo_id``.
            n_ctx: Context window size passed to ``Llama``.
        """
        # Download model automatically from Hugging Face; the returned value
        # is the local path handed to Llama below.
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        self.llm = Llama(
            model_path=model_path,
            n_ctx=n_ctx,
            # os.cpu_count() may return None, hence the fallback of 4.
            n_threads=os.cpu_count() or 4,
            verbose=False,
        )

    def _build_prompt(self, query: str) -> str:
        """Return the instruction prompt asking the model for a JSON call."""
        return (
            "You are an API generator. Return ONLY a JSON object.\n"
            "Format: {\"name\": \"function_name\", \"parameters\": {\"key\": \"value\"}}\n\n"
            f"User query: {query}\n\n"
            "JSON:"
        )

    def _extract_json(self, text: str):
        """
        Return the first JSON object embedded in ``text``, or ``None``.

        Works when the object is wrapped in markdown code fences, surrounded
        by prose, or followed by extra text / a second object. Any
        ``<think>...</think>`` sections are removed first so braces inside
        them are never parsed.

        Args:
            text: Raw model reply (``None`` or empty yields ``None``).

        Returns:
            The decoded ``dict``, or ``None`` if no JSON object starts at the
            first ``{`` of the (think-stripped) text.
        """
        if not text:
            return None
        # Remove <think> tags if present
        text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
        start = text.find("{")
        if start == -1:
            return None
        try:
            # raw_decode parses exactly one JSON value beginning at `start`
            # and reports where it ends, so trailing prose, closing fences or
            # stray "}" characters after the object no longer break parsing
            # (a greedy "{.*}" match would swallow them and fail).
            obj, _end = json.JSONDecoder().raw_decode(text, start)
        except json.JSONDecodeError:
            return None
        # A value that starts with "{" always decodes to a dict.
        return obj

    def query(self, user_query: str, temperature: float = 0.1) -> dict:
        """
        Ask the model to turn ``user_query`` into a function call.

        Args:
            user_query: Free-text user request.
            temperature: Sampling temperature forwarded to
                ``create_chat_completion`` (default 0.1, as before).

        Returns:
            ``{"query": user_query, "type": "function_call", "data": <dict>}``
            when a non-empty JSON object could be extracted from the reply,
            otherwise
            ``{"query": user_query, "type": "text", "data": {"content": <str>}}``.
        """
        prompt = self._build_prompt(user_query)
        output = self.llm.create_chat_completion(
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
        )
        message = output["choices"][0]["message"]
        # "content" may be present but set to None; .get(key, "") would then
        # return None and .strip() would raise, so coerce falsy values to "".
        content = (message.get("content") or "").strip()
        # Try to parse function call
        function_data = self._extract_json(content)
        if function_data:
            return {
                "query": user_query,
                "type": "function_call",
                "data": function_data,
            }
        return {
            "query": user_query,
            "type": "text",
            "data": {"content": content},
        }
# --- Example Usage ---
# Running this file directly downloads the fine-tuned GGUF model, sends one
# sample health question and pretty-prints the structured result.
if __name__ == "__main__":
    demo_model = HealthFunctionLM(
        repo_id="ramgovindv/health_function_call_llama3.2_3b_gguf",
        filename="Llama-3.2-3B-Instruct.Q4_K_M.gguf",
    )
    sample_question = "I am feeling very dizzy for couple of days. what could be the reason"
    result = demo_model.query(sample_question)
    print(json.dumps(result, indent=2))