| import time |
| import json |
| from typing import List, Literal |
|
|
| from fastapi import FastAPI |
| from fastapi.responses import StreamingResponse, JSONResponse |
| from pydantic import BaseModel |
|
|
| from huggingface_hub import InferenceClient |
|
|
# Inference client bound to the Mistral-7B-Instruct-v0.2 model; generate()
# sends every prompt through this single instance.
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2")

# ASGI application that the route decorators below register on.
app = FastAPI()
|
|
class Message(BaseModel):
    """One chat turn inside a request's ``messages`` list."""

    # Author of the turn. Only these two roles validate; any other role
    # (e.g. "system") is rejected when the request body is parsed.
    role: Literal["user", "assistant"]
    # Raw text of the turn.
    content: str
|
|
class Payload(BaseModel):
    """Request body for ``POST /chat/completions``.

    Field names follow the OpenAI chat-completions request; only the fields
    declared here are used by this server.
    """

    # True -> the endpoint streams raw text fragments;
    # False -> it returns a single JSON completion object.
    stream: bool = False
    # The only model id accepted (it is echoed back in the response).
    model: Literal["mistral-7b-instruct-v0.2"] = "mistral-7b-instruct-v0.2"
    # Conversation turns, rendered into the prompt in list order.
    messages: List[Message]
    # Sampling temperature.
    temperature: float = 0.9
    # NOTE(review): despite the OpenAI name, the endpoint forwards this value
    # as generate()'s repetition_penalty (a multiplicative penalty), not as an
    # additive frequency penalty -- confirm clients send values on that scale.
    frequency_penalty: float = 1.2
    # Nucleus-sampling probability mass.
    top_p: float = 0.9
|
|
async def stream(iter):
    """Adapt a blocking iterator into an async generator.

    Each ``next()`` call runs in a worker thread via ``asyncio.to_thread`` so
    that a slow iterator does not block the event loop.

    Args:
        iter: Any synchronous iterator (the name shadows the builtin; kept for
            interface compatibility).

    Yields:
        The iterator's items, in order, until it is exhausted.
    """
    import asyncio

    # StopIteration cannot be propagated through an asyncio Future (it is
    # converted to RuntimeError, or breaks the await on older Pythons), so
    # exhaustion is signalled with a private sentinel instead of an exception.
    exhausted = object()
    while True:
        value = await asyncio.to_thread(next, iter, exhausted)
        if value is exhausted:
            break
        yield value
|
|
def format_prompt(messages: "List[Message]"):
    """Render a conversation into the Mistral-Instruct prompt format.

    The prompt starts with ``<s>``. User turns are wrapped as
    ``[INST] ... [/INST]``; every other turn is appended as a model answer
    terminated by ``</s>``.

    Args:
        messages: Turns in conversation order. Each item may be a dict with
            ``"role"`` / ``"content"`` keys (as produced by
            ``Payload.model_dump()``) or an object exposing ``role`` /
            ``content`` attributes (e.g. a ``Message`` instance).

    Returns:
        The assembled prompt string (just ``"<s>"`` for an empty list).
    """
    parts = ["<s>"]
    for message in messages:
        # Callers pass dumped dicts, but the annotation promises Message
        # objects; support both instead of failing on attribute-style items.
        if isinstance(message, dict):
            role, content = message["role"], message["content"]
        else:
            role, content = message.role, message.content

        if role == "user":
            parts.append(f"[INST] {content} [/INST]")
        else:
            parts.append(f" {content}</s> ")
    return "".join(parts)
|
|
def make_chunk_obj(i, delta, fr):
    """Build one OpenAI-style ``chat.completion.chunk`` dictionary.

    Args:
        i: Value stored as the choice's ``index``.
        delta: Text placed in the choice's ``delta.content``.
        fr: Value stored as the choice's ``finish_reason`` (e.g. ``None``).

    Returns:
        A dict whose ``id`` is the current time in nanoseconds (as a string)
        and whose ``created`` is the current Unix time rounded to seconds.
    """
    chunk_id = str(time.time_ns())
    created_at = round(time.time())

    choice = {
        "index": i,
        "delta": {"content": delta},
        "finish_reason": fr,
    }

    return {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": created_at,
        "model": "mistral-7b-instruct-v0.2",
        "system_fingerprint": "wtf",
        "choices": [choice],
    }
|
|
def generate(
    messages,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
):
    """Stream the model's reply to ``messages`` as text fragments.

    Args:
        messages: Conversation turns, rendered with ``format_prompt``.
        temperature: Sampling temperature; values below 0.01 are raised to
            0.01 because sampling is always enabled (``do_sample=True``).
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling mass. Values >= 1.0 ("no nucleus filtering")
            are sent as ``None`` so the server uses its default.
        repetition_penalty: Multiplicative repetition penalty. Non-positive
            values are sent as ``None`` so the server uses its default.

    Yields:
        The text of each generated token, excluding the ``</s>`` end marker.
    """
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    # Text Generation Inference backends reject top_p >= 1.0 and
    # repetition_penalty <= 0 with a validation error. OpenAI-style clients
    # routinely send top_p=1 and frequency_penalty=0 (both meaning "off"),
    # so fall back to the server defaults instead of failing the request.
    if top_p >= 1.0:
        top_p = None
    if repetition_penalty is not None and repetition_penalty <= 0:
        repetition_penalty = None

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=None,
    )

    # Named token_stream so it does not shadow the module-level stream() helper.
    token_stream = client.text_generation(
        format_prompt(messages),
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    for response in token_stream:
        text = response.token.text
        if text == "</s>":
            # End-of-sequence marker is not part of the reply; skip it rather
            # than emitting an empty fragment.
            continue
        yield text
| |
| |
|
|
def generate_norm(*args, **kwargs) -> str:
    """Run ``generate`` to completion and return the whole reply as one string.

    Accepts exactly the same positional and keyword arguments as ``generate``.
    """
    # str.join is linear in the output size; repeated += is not guaranteed to be.
    return "".join(generate(*args, **kwargs))
|
|
@app.get('/')
async def index():
    """Root endpoint: return a fixed JSON greeting plus a hard-coded URL."""
    body = {
        "message": "hello",
        "url": "https://aweirddev-mistral-7b-instruct-v0-2-leicht.hf.space",
    }
    return JSONResponse(body)
|
|
@app.post('/chat/completions')
async def c_cmp(payload: Payload):
    """Handle an OpenAI-style chat completion request.

    Args:
        payload: Validated request body.

    Returns:
        If ``payload.stream`` is false, a ``JSONResponse`` shaped like an
        OpenAI ``chat.completion`` object whose single choice holds the full
        reply. Otherwise a ``StreamingResponse`` emitting the raw text
        fragments as they are generated (not SSE ``chat.completion.chunk``
        events).
    """
    import asyncio

    # Plain dicts, as consumed by format_prompt().
    messages = payload.model_dump()['messages']
    generation_args = (
        messages,
        payload.temperature,
        4096,  # max_new_tokens
        payload.top_p,
        payload.frequency_penalty,  # passed as generate()'s repetition_penalty
    )

    if not payload.stream:
        # Stamp id/created when the request arrives, before generation runs.
        completion_id = str(time.time_ns())
        created = round(time.time())
        # generate_norm() blocks until the whole reply is generated; calling it
        # directly here would stall the event loop (and every other request).
        # Run it in a worker thread instead.
        content = await asyncio.to_thread(generate_norm, *generation_args)
        return JSONResponse(
            {
                "id": completion_id,
                "object": "chat.completion",
                "created": created,
                "model": payload.model,
                "system_fingerprint": "wtf",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": content,
                        },
                    }
                ],
            }
        )

    # generate() is a synchronous generator; StreamingResponse iterates sync
    # iterators in a thread pool, so this path does not block the event loop.
    return StreamingResponse(generate(*generation_args))
|
|