# PYTORCH_ENABLE_MPS_FALLBACK=1 uvicorn main:app --host 0.0.0.0 --port 8888 --reload
# PYTORCH_ENABLE_MPS_FALLBACK=1 gunicorn main:app -b 0.0.0.0:8000 -w 4 -k uvicorn.workers.UvicornWorker
"""FastAPI service exposing local LLM text generation (llama.cpp) and voice-prompted TTS.

Endpoints:
    GET  /device      -- report the ``device`` exported by the ``tts`` module and the UTC time.
    POST /generate    -- produce the next assistant message for a chat history.
    POST /synthesize  -- synthesize speech from a reference WAV upload plus JSON parameters.
"""
import io
import json
import logging
import os
import re
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from time import gmtime

from dotenv import load_dotenv
from fastapi import FastAPI, Response, Body, UploadFile, HTTPException
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from scipy.io import wavfile
from starlette.middleware.cors import CORSMiddleware

from tts import synthesize, device

load_dotenv(verbose=False)

LOGGING_DIRECTORY = os.getenv('LOGGING_DIRECTORY', 'logs')
# exist_ok=True avoids a FileExistsError race when several gunicorn workers
# import this module at the same time.
os.makedirs(LOGGING_DIRECTORY, exist_ok=True)

file_handler = logging.FileHandler(os.path.join(LOGGING_DIRECTORY, 'api.log'), mode='a', encoding='utf-8')
formatter = logging.Formatter(fmt='%(asctime)s.%(msecs)03dZ - %(levelname)s - %(message)s',
                              datefmt='%Y-%m-%dT%H:%M:%S')
# Render timestamps in UTC so the literal 'Z' suffix in the format is truthful.
formatter.converter = gmtime
file_handler.setFormatter(formatter)

# Attach to gunicorn's error logger so our messages share its stream.
# (Under plain uvicorn, logging.getLogger('uvicorn') is the equivalent logger.)
logger = logging.getLogger('gunicorn.error')
logger.addHandler(file_handler)

# 'Llama' selects the Llama 3 chat template; any other value (or unset) selects
# the Gemma-style template in _build_prompt.
llm_prompt_format = os.getenv('LLM_PROMPT_FORMAT', None)
# Local GGUF path; when unset, lifespan() downloads the model from Hugging Face.
model_path = os.environ.get('LLAMACPP_PATH', None)

# Patterns that pull assistant turns back out of the echoed completion text.
# With re.DOTALL, '(.+?)' may span newlines; it stops at the end-of-turn token or
# at the end of the string (for a final turn that has no end token).
# The '|' characters in Llama 3 special tokens are regex metacharacters and must
# be escaped, otherwise they act as alternation and truncate the captured text.
_LLAMA3_ASSISTANT_PATTERN = re.compile(
    r'<\|start_header_id\|>assistant<\|end_header_id\|>\n\n(.+?)(?:<\|eot_id\|>|$)', re.DOTALL)
# NOTE(review): assumes a Gemma-style model using <start_of_turn>/<end_of_turn>
# delimiters — confirm against the deployed GGUF's chat template.
_GEMMA_ASSISTANT_PATTERN = re.compile(
    r'<start_of_turn>model\n(.+?)(?:<end_of_turn>|$)', re.DOTALL)

# MIME types clients commonly send for WAV uploads.
_WAV_CONTENT_TYPES = frozenset({'audio/wav', 'audio/x-wav', 'audio/wave'})


def _utc_timestamp() -> int:
    """Return the current time as integer seconds since the Unix epoch."""
    return int(datetime.now(timezone.utc).timestamp())


def _build_prompt(messages: list[dict[str, str]]) -> tuple[str, re.Pattern]:
    """Render a chat history into a raw completion prompt for the configured format.

    Each message needs 'role' and 'content' keys; messages whose role is not
    'system', 'user' or 'assistant' are skipped.

    Returns:
        (prompt, pattern): ``prompt`` ends with the assistant-turn prefix so the
        model continues as the assistant; it is ``''`` when no message had a
        recognized role. ``pattern`` extracts assistant turns from the echoed output.
    """
    if llm_prompt_format == 'Llama':
        template = '<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>'
        role_map = {'system': 'system', 'user': 'user', 'assistant': 'assistant'}
        generation_prefix = '<|start_header_id|>assistant<|end_header_id|>\n\n'
        pattern = _LLAMA3_ASSISTANT_PATTERN
    else:
        # System messages are sent as user turns in this format.
        template = '<start_of_turn>{role}\n{content}<end_of_turn>\n'
        role_map = {'system': 'user', 'user': 'user', 'assistant': 'model'}
        generation_prefix = '<start_of_turn>model\n'
        pattern = _GEMMA_ASSISTANT_PATTERN

    turns = [template.format(role=role_map[message['role']], content=message['content'])
             for message in messages if message['role'] in role_map]
    if not turns:
        return '', pattern
    return ''.join(turns) + generation_prefix, pattern


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Warm up the TTS pipeline and resolve the LLM weights before serving.

    For each language sub-directory of ``data/``, every ``.wav`` file is passed
    once through ``synthesize`` together with that directory's ``prompt.txt``
    and ``input.txt``; the result is discarded. If ``LLAMACPP_PATH`` was not set,
    the file named by ``LLAMACPP_FILENAME`` in repo ``LLAMACPP_REPO_ID`` is
    downloaded into ``./models`` and ``model_path`` is updated.
    """
    global model_path
    base_directory = 'data'
    for language in os.listdir(base_directory):
        path = os.path.join(base_directory, language)
        if not os.path.isdir(path):
            continue
        for filename in os.listdir(path):
            if os.path.splitext(filename)[1].lower() != '.wav':
                continue
            with open(os.path.join(path, filename), mode='rb') as f, \
                    io.BytesIO() as wave_bytes, \
                    open(os.path.join(path, 'prompt.txt'), 'r', encoding='utf-8') as prompt_file, \
                    open(os.path.join(path, 'input.txt'), 'r', encoding='utf-8') as input_file:
                wave_bytes.write(f.read())
                wave_bytes.seek(0)
                synthesize(prompt_wave=wave_bytes, prompt_text=prompt_file.read(), prompt_language=language,
                           input_text=input_file.read(), input_language=language, top_p=1, temperature=1)
    if model_path is None:
        model_path = hf_hub_download(repo_id=os.environ['LLAMACPP_REPO_ID'],
                                     filename=os.environ['LLAMACPP_FILENAME'], local_dir='./models')
    yield


app = FastAPI(lifespan=lifespan)
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests — confirm this is intended.
app.add_middleware(CORSMiddleware, allow_origins=['*'], allow_credentials=True,
                   allow_methods=['*'], allow_headers=['*'])


@app.get("/device")
async def read_device():
    """Return the ``device`` exported by ``tts`` and the current UTC Unix timestamp."""
    return {'device': str(device), 'timestamp': _utc_timestamp()}


@app.post("/generate", status_code=201)
def create_generated_text(messages: list[dict[str, str]] = Body(...), temperature: float = Body(default=1.0)):
    """Generate the next assistant message for a chat history.

    Args:
        messages: chat turns, each with 'role' ('system'/'user'/'assistant') and 'content'.
        temperature: sampling temperature passed to llama.cpp.

    Returns:
        ``{'choices': [{'role': 'assistant', 'content': ...}], 'timestamp': int}``.

    Raises:
        HTTPException(400): no message had a recognized role.
    """
    input_text, pattern = _build_prompt(messages)
    if not input_text:
        raise HTTPException(status_code=400)

    # The model is loaded per request and always released afterwards.
    llm = Llama(model_path=model_path, n_ctx=8192, n_gpu_layers=-1, n_batch=32, verbose=False)
    choices = []
    try:
        # echo=True returns prompt + completion, so the last assistant turn matched
        # by the pattern is the newly generated one.
        completion = llm(input_text, max_tokens=2048, temperature=temperature, top_p=0.95, echo=True)
        for choice in completion['choices']:
            matches = pattern.findall(choice['text'])
            if matches:
                choices.append({'role': 'assistant', 'content': matches[-1]})
    finally:
        llm.close()
    return {'choices': choices, 'timestamp': _utc_timestamp()}


@app.post("/synthesize", status_code=201)
def create_uploaded_file(file: UploadFile, data = Body(...)):
    """Synthesize speech in the voice of an uploaded WAV prompt.

    Args:
        file: reference audio; must be uploaded with a WAV content type.
        data: JSON string with required 'language' and 'input' keys and optional
            'prompt' (transcript of the reference audio), 'top_p' and 'temperature'
            (both default 1.0).

    Returns:
        The synthesized audio as an ``audio/wav`` response. (Because a Response
        object is returned directly, its own default status code is used rather
        than the decorator's 201.)

    Raises:
        HTTPException(400): wrong content type, malformed ``data``, or any
            failure during synthesis (the error is logged to the service logger).
    """
    if file.content_type not in _WAV_CONTENT_TYPES:
        raise HTTPException(status_code=400)
    try:
        data = json.loads(data)
        with io.BytesIO() as prompt_wave_bytes, io.BytesIO() as output_wave_bytes:
            prompt_wave_bytes.write(file.file.read())
            prompt_wave_bytes.seek(0)
            output, sample_rate = synthesize(prompt_wave=prompt_wave_bytes,
                                             prompt_text=data.get('prompt'),
                                             prompt_language=data['language'],
                                             input_text=data['input'],
                                             input_language=data['language'],
                                             top_p=data.get('top_p', 1.0),
                                             temperature=data.get('temperature', 1.0))
            wavfile.write(output_wave_bytes, sample_rate, output)
            return Response(content=output_wave_bytes.getvalue(), media_type="audio/wav")
    except Exception as e:
        # Use the configured logger (not the root logger) so failures reach api.log.
        logger.exception('Synthesis failed: %s', e)
        raise HTTPException(status_code=400, detail=str(e)) from e