| """Load models to use them as a narrator and a common-sense oracle in the PAYADOR pipeline.""" |
| import google.generativeai as genai |
| import requests |
| import os |
|
|
|
|
class GeminiModel:
    """Thin wrapper around a Google Gemini ``GenerativeModel``.

    Requests go through :meth:`prompt_model`, which attaches the safety
    settings built in ``__init__`` (every listed category at ``BLOCK_NONE``).
    """

    def __init__(self, api_key_file: str, model_name: str = "gemini-pro") -> None:
        """Configure the Gemini client and instantiate the requested model.

        Args:
            api_key_file: Despite its name, this is the NAME of an environment
                variable. Its value, read with ``os.getenv``, is handed to
                ``genai.configure`` as the API key. If the variable is unset,
                ``None`` is passed.
            model_name: Gemini model identifier; defaults to ``"gemini-pro"``.
        """
        # NOTE(review): HARM_CATEGORY_DANGEROUS appears to be a legacy (PaLM)
        # category; confirm the installed SDK/API version accepts it for
        # Gemini models.
        unblocked_categories = (
            "HARM_CATEGORY_DANGEROUS",
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
        )
        # One {"category", "threshold"} dict per category, in the order above.
        self.safety_settings = [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in unblocked_categories
        ]

        api_key = os.getenv(api_key_file)
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name)

    def prompt_model(self, prompt: str) -> str:
        """Send *prompt* to Gemini and return the ``text`` of the response.

        The instance's ``safety_settings`` are passed with every request.
        """
        response = self.model.generate_content(
            prompt, safety_settings=self.safety_settings
        )
        return response.text
|
|
|
|
def prompt_HF_API(
    prompt: str,
    model: str = "microsoft/Phi-3-mini-4k-instruct",
    api_key_file: str = "HF_API_key",
    timeout: "float | tuple[float, float] | None" = None,
):
    """Send *prompt* to the Hugging Face Inference API and return the generated text.

    Args:
        prompt: Text sent as the ``inputs`` field of the JSON payload.
        model: Hugging Face model id, appended to the Inference API URL.
        api_key_file: Path to a file whose first line is the access token
            (read with ``get_api_key``).
        timeout: Forwarded to ``requests.post``: seconds, or a
            ``(connect, read)`` tuple. ``None`` (the default) waits
            indefinitely, exactly as before this parameter existed.

    Returns:
        The ``generated_text`` field of the first element of the JSON response.

    Raises:
        requests.HTTPError: If the API answers with a non-2xx status. The
            message includes the status code, URL and response body.
    """
    api_url = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {get_api_key(api_key_file)}"}
    payload = {"inputs": prompt}

    response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    if not response.ok:
        # Fail loudly with the status and body instead of letting an error
        # payload surface as an opaque lookup error on output[0] below.
        raise requests.HTTPError(
            f"{response.status_code} error from {api_url}: {response.text}",
            response=response,
        )

    output = response.json()
    return output[0]["generated_text"]
|
|
def get_api_key(path: str) -> str:
    """Read an API key from the first line of a text file.

    Args:
        path: Path to a file whose first line holds the key.

    Returns:
        The first line with surrounding whitespace removed, including the
        trailing newline kept by ``readline``. Returns ``""`` for an empty file.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    with open(path, encoding="utf-8") as key_file:
        # The key is embedded in an "Authorization: Bearer ..." header;
        # requests rejects header values with a newline or leading
        # whitespace, so strip what readline() leaves behind.
        return key_file.readline().strip()