ramgovindv commited on
Commit
ff6e4b2
·
verified ·
1 Parent(s): 9429f10

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +48 -144
inference.py CHANGED
@@ -3,179 +3,83 @@ import json
3
  import re
4
  import logging
5
  import warnings
6
- from typing import Optional, Dict, Any
7
-
8
  from llama_cpp import Llama
 
9
 
10
-
11
- # ---------------------------
12
- # Silence noisy logs
13
- # ---------------------------
14
  warnings.filterwarnings("ignore", category=UserWarning)
15
  logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
16
-
17
  os.environ["GGML_PYTHON_VERBOSE"] = "0"
18
- os.environ["LLAMA_CPP_LIB_VERBOSE"] = "0"
19
-
20
 
21
- # ---------------------------
22
- # Core Class
23
- # ---------------------------
24
  class HealthFunctionLM:
25
- def __init__(
26
- self,
27
- repo_id: Optional[str] = None,
28
- filename: Optional[str] = None,
29
- n_ctx: int = 2048,
30
- n_threads: int = 4,
31
- temperature: float = 0.1,
32
- ):
33
- """
34
- Initialize model
35
-
36
- Args:
37
- repo_id: Hugging Face repo ID
38
- filename: GGUF file name
39
- n_ctx: context length
40
- n_threads: CPU threads
41
- temperature: sampling temperature
42
- """
43
-
44
- # Defaults (easy mode)
45
- self.repo_id = repo_id or "ramgovindv/health_function_call_llama3.2_3b_gguf"
46
- self.filename = filename or "Llama-3.2-3B-Instruct.Q4_K_M.gguf"
47
- self.temperature = temperature
48
-
49
- if not self.filename.endswith(".gguf"):
50
- raise ValueError("Only GGUF models are supported")
51
-
52
- self.llm = Llama.from_pretrained(
53
- repo_id=self.repo_id,
54
- filename=self.filename,
55
  n_ctx=n_ctx,
56
- n_threads=n_threads,
57
- chat_format=None,
58
- verbose=False,
59
  )
60
 
61
- # ---------------------------
62
- # Prompt Builder
63
- # ---------------------------
64
  def _build_prompt(self, query: str) -> str:
65
- return f"""
66
- You are an API generator.
67
-
68
- Return ONLY valid JSON in this format:
69
- {{
70
- "name": "function_name",
71
- "parameters": {{}},
72
- "reasoning": "optional short explanation"
73
- }}
74
-
75
- User query:
76
- {query}
77
-
78
- JSON:
79
- """
80
-
81
- # ---------------------------
82
- # Model Call
83
- # ---------------------------
84
- def _generate(self, prompt: str) -> str:
85
- response = self.llm.create_chat_completion(
86
- messages=[{"role": "user", "content": prompt}],
87
- temperature=self.temperature,
88
  )
89
 
90
- return response["choices"][0]["message"]["content"].strip()
91
-
92
- # ---------------------------
93
- # Parsing Utilities
94
- # ---------------------------
95
- def _safe_json_load(self, text: str) -> Optional[Dict[str, Any]]:
96
- """
97
- Try strict JSON parsing first.
98
- Fallback: extract first JSON block.
99
- """
100
- try:
101
- return json.loads(text)
102
- except Exception:
103
- pass
104
-
105
- # fallback: extract JSON substring
106
  match = re.search(r"\{.*\}", text, re.DOTALL)
107
  if match:
108
  try:
109
  return json.loads(match.group(0))
110
- except Exception:
111
  return None
112
-
113
  return None
114
 
115
- # ---------------------------
116
- # Public Query API
117
- # ---------------------------
118
- def query(self, user_query: str) -> Dict[str, Any]:
119
  prompt = self._build_prompt(user_query)
120
- raw_output = self._generate(prompt)
121
-
122
- parsed = self._safe_json_load(raw_output)
123
-
124
- if parsed:
 
 
 
 
 
 
 
125
  return {
126
  "query": user_query,
127
  "type": "function_call",
128
- "data": {
129
- "name": parsed.get("name"),
130
- "parameters": parsed.get("parameters", {}),
131
- "reasoning": parsed.get("reasoning"),
132
- },
133
- "raw": raw_output,
134
  }
135
 
136
- # fallback (model messed up)
137
  return {
138
  "query": user_query,
139
  "type": "text",
140
- "data": {
141
- "content": raw_output,
142
- "reasoning": None,
143
- },
144
  }
145
 
146
-
147
- # ---------------------------
148
- # Simple Loader (User Entry)
149
- # ---------------------------
150
- def load_model(
151
- repo_id: Optional[str] = None,
152
- filename: Optional[str] = None,
153
- **kwargs
154
- ) -> HealthFunctionLM:
155
- """
156
- Easy entry point for users
157
-
158
- Example:
159
- model = load_model()
160
- model = load_model(repo_id="other/repo", filename="model.gguf")
161
- """
162
- return HealthFunctionLM(
163
- repo_id=repo_id,
164
- filename=filename,
165
- **kwargs
166
- )
167
-
168
-
169
- # ---------------------------
170
- # Optional CLI usage
171
- # ---------------------------
172
  if __name__ == "__main__":
173
- model = load_model()
174
-
175
- while True:
176
- q = input("\nEnter query (or 'exit'): ")
177
- if q.lower() == "exit":
178
- break
179
-
180
- result = model.query(q)
181
- print(json.dumps(result, indent=2))
 
3
  import re
4
  import logging
5
  import warnings
 
 
6
  from llama_cpp import Llama
7
+ from huggingface_hub import hf_hub_download
8
 
9
+ # Silence unnecessary logs
 
 
 
10
  warnings.filterwarnings("ignore", category=UserWarning)
11
  logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
 
12
  os.environ["GGML_PYTHON_VERBOSE"] = "0"
 
 
13
 
 
 
 
14
  class HealthFunctionLM:
15
+ """
16
+ A specialized wrapper for Llama-3.2 GGUF models to perform
17
+ health-related function calling.
18
+ """
19
+ def __init__(self, repo_id: str, filename: str, n_ctx: int = 2048):
20
+ # Download model automatically from Hugging Face
21
+ model_path = hf_hub_download(repo_id=repo_id, filename=filename)
22
+
23
+ self.llm = Llama(
24
+ model_path=model_path,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  n_ctx=n_ctx,
26
+ n_threads=os.cpu_count() or 4,
27
+ verbose=False
 
28
  )
29
 
 
 
 
30
  def _build_prompt(self, query: str) -> str:
31
+ return (
32
+ "You are an API generator. Return ONLY a JSON object.\n"
33
+ "Format: {\"name\": \"function_name\", \"parameters\": {\"key\": \"value\"}}\n\n"
34
+ f"User query: {query}\n\n"
35
+ "JSON:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
 
38
+ def _extract_json(self, text: str):
39
+ """Extracts JSON even if the model wraps it in markdown blocks."""
40
+ # Remove <think> tags if present
41
+ text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
42
+ # Find JSON block
 
 
 
 
 
 
 
 
 
 
 
43
  match = re.search(r"\{.*\}", text, re.DOTALL)
44
  if match:
45
  try:
46
  return json.loads(match.group(0))
47
+ except json.JSONDecodeError:
48
  return None
 
49
  return None
50
 
51
+ def query(self, user_query: str):
 
 
 
52
  prompt = self._build_prompt(user_query)
53
+ output = self.llm.create_chat_completion(
54
+ messages=[{"role": "user", "content": prompt}],
55
+ temperature=0.1
56
+ )
57
+
58
+ message = output["choices"][0]["message"]
59
+ content = message.get("content", "").strip()
60
+
61
+ # Try to parse function call
62
+ function_data = self._extract_json(content)
63
+
64
+ if function_data:
65
  return {
66
  "query": user_query,
67
  "type": "function_call",
68
+ "data": function_data
 
 
 
 
 
69
  }
70
 
 
71
  return {
72
  "query": user_query,
73
  "type": "text",
74
+ "data": {"content": content}
 
 
 
75
  }
76
 
77
+ # --- Example Usage ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  if __name__ == "__main__":
79
+ model = HealthFunctionLM(
80
+ repo_id="ramgovindv/health_function_call_llama3.2_3b_gguf",
81
+ filename="Llama-3.2-3B-Instruct.Q4_K_M.gguf"
82
+ )
83
+
84
+ res = model.query("I am feeling very dizzy for couple of days. what could be the reason")
85
+ print(json.dumps(res, indent=2))