---
license: cc-by-4.0
library_name: transformers
base_model: google/gemma-2b-it
tags:
  - aayuai
  - aayu-ai
  - aayubot
  - medical-ai
  - healthcare-ai
  - conversational-ai
  - chatbot
  - llm
  - generative-ai
  - gemma
  - gemma-2b
  - transformers
  - nlp
  - fine-tuned
  - medical-chatbot
  - clinical-ai
  - digital-health
  - ai-safety
  - safe-ai
  - responsible-ai
  - ai-alignment
  - content-moderation
  - severity-aware
  - virtual-assistant
  - open-source
  - research-project
  - startup
  - innovation
  - lora
  - peft
  - safety
---

AayuBot Gemma-2B (LoRA Adapter)

Severity-aware, safety-focused medical conversational AI built on top of google/gemma-2b-it using QLoRA fine-tuning.

| Field | Detail |
|---|---|
| Creator | Satyam Tiwari (thesatyam12) |
| Base model | google/gemma-2b-it |
| Method | QLoRA (4-bit NF4) / PEFT adapter |
| License | CC BY 4.0 (adapter weights only) |
| DOI | 10.57967/hf/7776 |
| Status | Research (manuscript submitted) |

Available Checkpoints

| Folder | Training steps | Recommended? |
|---|---|---|
| checkpoint1-20k | 20,000 | |
| checkpoint2-40k | 40,000 | |
| checkpoint3-60k | 60,000 | |
| checkpoint4-80k | 80,000 | |
| checkpoint5-100k | 100,000 | ✅ Best / Final |

By default, all code below uses checkpoint5-100k.
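If you want to compare checkpoints programmatically, a small parser for the folder naming scheme in the table can help. This is a minimal sketch with hypothetical helper names (`checkpoint_steps`, `sort_checkpoints`); it assumes every folder follows the `checkpoint<index>-<steps>k` pattern shown above.

```python
import re

# Assumed naming scheme: "checkpoint<index>-<thousands of steps>k", e.g. "checkpoint5-100k".
_CKPT_RE = re.compile(r"^checkpoint(\d+)-(\d+)k$")


def checkpoint_steps(folder: str) -> int:
    """Return the training step count encoded in a checkpoint folder name."""
    match = _CKPT_RE.match(folder)
    if match is None:
        raise ValueError(f"Unexpected checkpoint folder name: {folder!r}")
    return int(match.group(2)) * 1000


def sort_checkpoints(folders):
    """Sort checkpoint folders from fewest to most training steps."""
    return sorted(folders, key=checkpoint_steps)


folders = ["checkpoint3-60k", "checkpoint5-100k", "checkpoint1-20k"]
ordered = sort_checkpoints(folders)
print(ordered)
print(checkpoint_steps(ordered[-1]))
```

To feed it real folder names, one option is to take the first path component of each entry returned by `huggingface_hub.list_repo_files("thesatyam12/aayubot-gemma-2b")` that starts with `checkpoint`.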


Step 1: Install dependencies

```bash
# Create and activate a fresh virtual environment (recommended)
python -m venv .venv
source .venv/bin/activate   # on Windows: .venv\Scripts\activate
pip install --upgrade pip

# PyTorch (CUDA 11.8 build; use the index URL that matches your CUDA version)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Hugging Face + Gemma + LoRA stack
pip install transformers accelerate bitsandbytes peft

# Log in to Hugging Face (required for Gemma and adapter access)
pip install huggingface_hub
huggingface-cli login
```

Alternatively, log in from Python:

```python
from huggingface_hub import login

login(token="YOUR_HF_TOKEN")
```

Gemma access: You must first accept Google's Gemma license at https://huggingface.co/google/gemma-2b-it (click "Agree and access repository"). Otherwise the base model download will fail.
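Before moving on, a quick check that everything from Step 1 is importable can save a failed model load later. This is a minimal sketch using only the standard library; `missing_packages` and `REQUIRED` are hypothetical names, and the list simply mirrors the packages installed above (by their import names). Note that 4-bit loading through bitsandbytes has traditionally expected a CUDA GPU; `torch.cuda.is_available()` is a quick way to confirm one is visible.

```python
import importlib.util

# Import names of the packages installed in Step 1.
REQUIRED = ["torch", "transformers", "accelerate", "bitsandbytes", "peft", "huggingface_hub"]


def missing_packages(names):
    """Return the import names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All dependencies found.")
```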


Step 2: Download the adapter

Option A: Clone the full repo (all checkpoints)

```bash
git lfs install
git clone https://huggingface.co/thesatyam12/aayubot-gemma-2b
```

Option B: Download only checkpoint5-100k (recommended)

```python
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="thesatyam12/aayubot-gemma-2b",
    allow_patterns="checkpoint5-100k/*",
    local_dir="./aayubot-gemma-2b",
)
```
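It can be worth confirming that the download actually produced a usable adapter folder; a partial download or a mistyped `allow_patterns` value may only surface later as a less obvious loading error. This is a minimal sketch assuming the checkpoint follows PEFT's usual `save_pretrained()` layout (`adapter_config.json` plus `adapter_model.safetensors`, or `adapter_model.bin` from older PEFT versions); `check_adapter_dir` is a hypothetical helper, not part of PEFT.

```python
from pathlib import Path

# Files PEFT's save_pretrained() normally writes for a LoRA adapter (assumed layout).
CONFIG_FILE = "adapter_config.json"
WEIGHT_FILES = ("adapter_model.safetensors", "adapter_model.bin")


def check_adapter_dir(path) -> bool:
    """Return True if `path` contains an adapter config and adapter weights."""
    folder = Path(path)
    has_config = (folder / CONFIG_FILE).is_file()
    has_weights = any((folder / name).is_file() for name in WEIGHT_FILES)
    return has_config and has_weights


print(check_adapter_dir("./aayubot-gemma-2b/checkpoint5-100k"))
```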

Step 3: Load the model

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ── Config ──────────────────────────────────────────────
BASE_MODEL = "google/gemma-2b-it"
ADAPTER_REPO = "thesatyam12/aayubot-gemma-2b"
CHECKPOINT = "checkpoint5-100k"          # change if you want another checkpoint
# ────────────────────────────────────────────────────────

# 4-bit quantization (same settings used during training)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load tokenizer from the adapter checkpoint
tokenizer = AutoTokenizer.from_pretrained(
    ADAPTER_REPO,
    subfolder=CHECKPOINT,
)

# Load base Gemma model in 4-bit
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="eager",
)

# Attach the LoRA adapter
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER_REPO,
    subfolder=CHECKPOINT,
)
model.eval()

print("✅ AayuBot loaded successfully!")
```
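The 4-bit config above mainly reduces the memory needed for the base weights. As a rough, weight-only illustration (not a measurement of this repo): bitsandbytes swaps the `nn.Linear` layers for 4-bit versions but leaves the token embeddings in 16-bit, so the two parameter groups are counted separately. The parameter counts below are the approximate Gemma-2B figures from the Gemma technical report; `weight_memory_gb` is a hypothetical helper, and the estimate ignores quantization constants, the LoRA weights, activations, the KV cache and CUDA overhead.

```python
def weight_memory_gb(quantized_params: int, fp16_params: int, bits: int = 4) -> float:
    """Rough weight-only memory estimate in GB (1 GB = 1e9 bytes).

    quantized_params: parameters stored at `bits` precision (linear layers)
    fp16_params: parameters kept in 16-bit, i.e. 2 bytes each (e.g. embeddings)
    """
    quantized_bytes = quantized_params * bits / 8
    fp16_bytes = fp16_params * 2
    return (quantized_bytes + fp16_bytes) / 1e9


# Approximate Gemma-2B parameter split (Gemma technical report).
EMBEDDING_PARAMS = 524_550_144
NON_EMBEDDING_PARAMS = 1_981_884_416

quantized = weight_memory_gb(NON_EMBEDDING_PARAMS, EMBEDDING_PARAMS)
full_fp16 = weight_memory_gb(0, EMBEDDING_PARAMS + NON_EMBEDDING_PARAMS)
print(f"4-bit linear layers + fp16 embeddings: ~{quantized:.2f} GB")
print(f"Everything in fp16: ~{full_fp16:.2f} GB")
```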

Step 4: Run inference

```python
# Assumes `torch`, `model` and `tokenizer` from Step 3 are already in scope.
def ask_aayubot(question, severity="simple"):
    # Step 1: Generate initial response from the severity-tagged prompt
    prompt = (
        f"<start_of_turn>user\n"
        f"[Severity: {severity}]\n"
        f"{question}\n"
        f"<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.25,
            no_repeat_ngram_size=5,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (skip the echoed prompt)
    raw_response = tokenizer.decode(
        output_ids[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    ).strip()

    # Step 2: Post-processing pass (the model rewrites its own answer)
    post_prompt = (
        f"<start_of_turn>user\n"
        f"Rewrite the following medical answer cleanly. "
        f"Remove all repeated lines and duplicate disclaimers. "
        f"Keep only ONE clear disclaimer at the very end. "
        f"Make the response professional and concise.\n\n"
        f"Text:\n{raw_response}\n"
        f"<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )

    post_inputs = tokenizer(post_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        post_output = model.generate(
            **post_inputs,
            max_new_tokens=300,
            temperature=0.4,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
        )

    final_response = tokenizer.decode(
        post_output[0][post_inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    ).strip()

    # Step 3: Final cleaning. Drop blank and duplicate lines plus any disclaimer
    # the model wrote itself, then append exactly one fixed disclaimer.
    disclaimer = "This is general information only. Please consult a doctor for proper diagnosis and treatment."

    lines, seen = [], set()
    for line in final_response.split("\n"):
        line = line.strip()
        if not line or line in seen or "general information only" in line.lower():
            continue
        seen.add(line)
        lines.append(line)

    return ("\n".join(lines) + "\n\n" + disclaimer).strip()


# ── Try it ──
print(ask_aayubot("What should I do for a mild headache?", severity="simple"))
print(ask_aayubot("I have chest pain and shortness of breath", severity="high"))
```
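The final cleaning step in `ask_aayubot` does not need the model, so it can be factored out and unit-tested on its own. This is a minimal sketch; `clean_response` and `DISCLAIMER` are names introduced here, not part of the repo. It keeps the first occurrence of each non-blank line, drops any line containing "general information only", and appends one fixed disclaimer.

```python
DISCLAIMER = (
    "This is general information only. "
    "Please consult a doctor for proper diagnosis and treatment."
)


def clean_response(text: str, disclaimer: str = DISCLAIMER) -> str:
    """Deduplicate lines and append exactly one disclaimer (no model needed)."""
    seen = set()
    kept = []
    for line in text.split("\n"):
        line = line.strip()
        # Skip blanks, exact repeats, and any disclaimer the model wrote itself.
        if not line or line in seen or "general information only" in line.lower():
            continue
        seen.add(line)
        kept.append(line)
    # Blank line between the answer and the disclaimer; strip() handles an empty answer.
    return "\n".join(kept + ["", disclaimer]).strip()


raw = "Drink water.\nDrink water.\nThis is general information only, see a doctor.\n\nRest in a quiet room."
print(clean_response(raw))
```

Inside `ask_aayubot`, the Step 3 block could then be replaced by `return clean_response(final_response)`.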

Prompt format

AayuBot expects a severity tag before the user query:

```text
<start_of_turn>user
[Severity: simple|medium|high]
Your medical question here
<end_of_turn>
<start_of_turn>model
```

| Severity | When to use |
|---|---|
| simple | General wellness, mild symptoms |
| medium | Moderate symptoms, medication questions |
| high | Emergency-level, chest pain, severe symptoms |
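An unrecognised severity value would still be formatted into the prompt, producing a tag the model presumably never saw during training. A small validating prompt builder avoids that. This is a minimal sketch; `build_prompt` and `SEVERITY_LEVELS` are hypothetical names, and the string mirrors the prompt built inside `ask_aayubot` exactly.

```python
SEVERITY_LEVELS = ("simple", "medium", "high")


def build_prompt(question: str, severity: str = "simple") -> str:
    """Build an AayuBot prompt in the Gemma turn format shown above."""
    if severity not in SEVERITY_LEVELS:
        raise ValueError(f"severity must be one of {SEVERITY_LEVELS}, got {severity!r}")
    return (
        "<start_of_turn>user\n"
        f"[Severity: {severity}]\n"
        f"{question}\n"
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


print(build_prompt("What should I do for a mild headache?"))
```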

Loading from a local clone

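If you downloaded the adapter locally (Option A or Option B above), point to the local checkpoint folder instead of the Hub repo ID and subfolder:

```python
# Local path replaces ADAPTER_REPO + subfolder=CHECKPOINT from Step 3
tokenizer = AutoTokenizer.from_pretrained("./aayubot-gemma-2b/checkpoint5-100k")
model = PeftModel.from_pretrained(
    base_model,
    "./aayubot-gemma-2b/checkpoint5-100k",
)
```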
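To make one script work whether or not a local copy exists, the source can be resolved at runtime. This is a minimal sketch; `adapter_source` is a hypothetical helper that prefers the local checkpoint folder when it contains `adapter_config.json` and otherwise falls back to the Hub repo with `subfolder`, as in Step 3.

```python
from pathlib import Path

ADAPTER_REPO = "thesatyam12/aayubot-gemma-2b"
CHECKPOINT = "checkpoint5-100k"


def adapter_source(local_root: str = "./aayubot-gemma-2b",
                   repo_id: str = ADAPTER_REPO,
                   checkpoint: str = CHECKPOINT):
    """Return (path_or_repo_id, extra_kwargs) for the from_pretrained() calls."""
    local_dir = Path(local_root) / checkpoint
    if (local_dir / "adapter_config.json").is_file():
        return str(local_dir), {}
    return repo_id, {"subfolder": checkpoint}


source, extra = adapter_source()
print(source, extra)
# Then, for example:
# tokenizer = AutoTokenizer.from_pretrained(source, **extra)
# model = PeftModel.from_pretrained(base_model, source, **extra)
```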

Paper / Citation

AayuAI: Transform-Not-Delete Safety Curation and Severity-Aware Fine-Tuning for Medical Conversational AI

Authors: Jayraj S. Lakkad, Satyam R. Tiwari, Laksh J. Savaliya, Akshar V. Prajapati, Dipali Kasat, Dr. Premalkumar J. Patel

```bibtex
@misc{aayuai2025,
  title   = {AayuAI: Transform-Not-Delete Safety Curation and Severity-Aware Fine-Tuning for Medical Conversational AI},
  author  = {Lakkad, Jayraj S. and Tiwari, Satyam R. and Savaliya, Laksh J. and Prajapati, Akshar V. and Kasat, Dipali and Patel, Premalkumar J.},
  year    = {2025},
  doi     = {10.57967/hf/7776},
  url     = {https://huggingface.co/thesatyam12/aayubot-gemma-2b},
}
```

⚠️ Important disclaimers

  1. Not a doctor. AayuBot is a research prototype. Do not use it for real medical decisions.
  2. Base model license. This repo contains adapter weights only. Usage of google/gemma-2b-it is governed by Google's Gemma Terms of Use.
  3. Adapter license. The adapter weights in this repo are released under CC BY 4.0.