
Gated model: this repository is publicly accessible, but you must log in to Hugging Face and accept the access conditions (which include sharing your contact information) before you can download its files.

AayuBot Gemma-2B (LoRA Adapter)

Severity-aware, safety-focused medical conversational AI built on top of google/gemma-2b-it using QLoRA fine-tuning.

| Field | Detail |
|---|---|
| Creator | Satyam Tiwari (thesatyam12) |
| Base model | google/gemma-2b-it |
| Method | QLoRA (4-bit NF4) / PEFT adapter |
| License | CC BY 4.0 (adapter weights only) |
| DOI | 10.57967/hf/7776 |
| Status | Research (manuscript submitted) |

Available Checkpoints

| Folder | Training steps | Recommended? |
|---|---|---|
| checkpoint1-20k | 20,000 | |
| checkpoint2-40k | 40,000 | |
| checkpoint3-60k | 60,000 | |
| checkpoint4-80k | 80,000 | |
| checkpoint5-100k | 100,000 | ✅ Best / final |

By default, all code below uses checkpoint5-100k.
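If you want to compare checkpoints, the folder names follow a regular pattern: the checkpoint index, then the training steps in thousands. The sketch below rebuilds the five names from the table so you can loop over them; `checkpoint_folders` is a hypothetical helper (not part of the repo) and assumes the 20,000-step spacing shown above.

```python
# Hypothetical helper: rebuild the checkpoint folder names listed in the table.
# Assumes checkpoints are exactly 20,000 training steps apart.
STEP_INTERVAL = 20_000
NUM_CHECKPOINTS = 5


def checkpoint_folders(num=NUM_CHECKPOINTS, interval=STEP_INTERVAL):
    """Return names like 'checkpoint1-20k', ..., 'checkpoint5-100k'."""
    return [f"checkpoint{i}-{i * interval // 1000}k" for i in range(1, num + 1)]


if __name__ == "__main__":
    for name in checkpoint_folders():
        print(name)  # each name can be used as CHECKPOINT in Step 3
```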


Step 1: Install dependencies

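# Create a fresh environment (recommended)
python -m venv aayubot-env
source aayubot-env/bin/activate      # on Windows: aayubot-env\Scripts\activate
pip install --upgrade pip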

# Core packages: PyTorch built for CUDA 11.8 (use the index URL that matches your CUDA version)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Hugging Face + Gemma + LoRA stack
pip install transformers accelerate bitsandbytes peft

# Log in to Hugging Face (required to download Gemma and this gated adapter)
pip install huggingface_hub
huggingface-cli login        # newer huggingface_hub releases also provide: hf auth login

# ...or log in from Python instead:
from huggingface_hub import login
login("YOUR_HF_TOKEN")

Gemma access: You must first accept Google's Gemma license at https://huggingface.co/google/gemma-2b-it (click "Agree and access repository"). Otherwise the base model download will fail.
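A quick way to confirm that Step 1 worked is to check that each package can be found in the active environment. This is a minimal sketch using only the standard library; the package list mirrors the `pip install` commands above, and `missing_packages` is a hypothetical helper name.

```python
# Sanity check after Step 1: report which required packages are not installed.
# find_spec only locates a top-level module; it does not import it.
import importlib.util

REQUIRED = ["torch", "transformers", "accelerate", "bitsandbytes", "peft", "huggingface_hub"]


def missing_packages(names=REQUIRED):
    """Return the subset of `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```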


Step 2: Download the adapter

Option A: Clone the full repo (all checkpoints)

git lfs install
git clone https://huggingface.co/thesatyam12/aayubot-gemma-2b

Option B: Download only checkpoint5-100k (recommended)

from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="thesatyam12/aayubot-gemma-2b",
    allow_patterns="checkpoint5-100k/*",
    local_dir="./aayubot-gemma-2b",
)
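After either download option, you can check that the checkpoint folder looks like a PEFT LoRA adapter before loading anything onto the GPU. This sketch assumes the standard PEFT file names (`adapter_config.json` plus `adapter_model.safetensors`, or `adapter_model.bin` from older PEFT versions); the repo's exact file list is not documented here, and `looks_like_adapter` is a hypothetical helper.

```python
# Check that a downloaded checkpoint folder contains a PEFT adapter.
# Assumes standard PEFT file names; adjust if the repo uses different ones.
from pathlib import Path

WEIGHT_FILES = ("adapter_model.safetensors", "adapter_model.bin")


def looks_like_adapter(folder):
    """True if `folder` holds adapter_config.json and at least one weight file."""
    folder = Path(folder)
    has_config = (folder / "adapter_config.json").is_file()
    has_weights = any((folder / name).is_file() for name in WEIGHT_FILES)
    return has_config and has_weights


if __name__ == "__main__":
    print(looks_like_adapter("./aayubot-gemma-2b/checkpoint5-100k"))
```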

Step 3: Load the model

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ── Config ──────────────────────────────────────────────
BASE_MODEL = "google/gemma-2b-it"
ADAPTER_REPO = "thesatyam12/aayubot-gemma-2b"
CHECKPOINT = "checkpoint5-100k"          # change if you want another checkpoint
# ────────────────────────────────────────────────────────

# 4-bit quantization (same settings used during training)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load tokenizer from the adapter checkpoint
tokenizer = AutoTokenizer.from_pretrained(
    ADAPTER_REPO,
    subfolder=CHECKPOINT,
)

# Load base Gemma model in 4-bit
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="eager",
)

# Attach the LoRA adapter
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER_REPO,
    subfolder=CHECKPOINT,
)
model.eval()

print("✅ AayuBot loaded successfully!")
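PEFT records the base model an adapter was trained against under `base_model_name_or_path` in `adapter_config.json`. If you downloaded the files locally (Step 2), the sketch below reads that field so you can compare it with `BASE_MODEL`; `recorded_base_model` is a hypothetical helper, and the stored value may be a local path if the authors trained from a local copy of Gemma.

```python
# Read the base model recorded in a local adapter's adapter_config.json.
import json
from pathlib import Path


def recorded_base_model(adapter_dir):
    """Return the base model name stored in adapter_config.json, or None."""
    config = json.loads((Path(adapter_dir) / "adapter_config.json").read_text())
    return config.get("base_model_name_or_path")


if __name__ == "__main__":
    expected = "google/gemma-2b-it"
    found = recorded_base_model("./aayubot-gemma-2b/checkpoint5-100k")
    if found != expected:
        print(f"Note: adapter config records {found!r}, not {expected!r}")
```

With the model already loaded, `model.peft_config["default"].base_model_name_or_path` should expose the same value without touching the filesystem.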

Step 4: Run inference

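def ask_aayubot(question, severity="simple"):
    # AayuBot expects one of three severity tags; reject anything else early
    if severity not in ("simple", "medium", "high"):
        raise ValueError(f"severity must be 'simple', 'medium' or 'high', got {severity!r}")

    # Step 1: Generate initial response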
    prompt = (
        f"<start_of_turn>user\n"
        f"[Severity: {severity}]\n"
        f"{question}\n"
        f"<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.25,
            no_repeat_ngram_size=5,
            pad_token_id=tokenizer.eos_token_id,
        )

    raw_response = tokenizer.decode(
        output_ids[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    ).strip()

    # Step 2: Post-processing (Clean + One Disclaimer)
    post_prompt = (
        f"<start_of_turn>user\n"
        f"Rewrite the following medical answer cleanly. "
        f"Remove all repeated lines and duplicate disclaimers. "
        f"Keep only ONE clear disclaimer at the very end. "
        f"Make the response professional and concise.\n\n"
        f"Text:\n{raw_response}\n"
        f"<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )

    post_inputs = tokenizer(post_prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        post_output = model.generate(
            **post_inputs,
            max_new_tokens=300,
            temperature=0.4,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
        )

    final_response = tokenizer.decode(
        post_output[0][post_inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    ).strip()

    # Step 3: Final cleaning (drop blank lines and exact duplicates, keep order)
    lines = []
    for line in final_response.split("\n"):
        line = line.strip()
        if line and line not in lines:
            lines.append(line)

    # Remove any disclaimer lines the model wrote, then append exactly one at the end
    disclaimer = "This is general information only. Please consult a doctor for proper diagnosis and treatment."
    lines = [l for l in lines if "general information only" not in l.lower()]

    return ("\n".join(lines) + "\n\n" + disclaimer).strip()


# ── Try it ──
print(ask_aayubot("What should I do for a mild headache?", severity="simple"))
print(ask_aayubot("I have chest pain and shortness of breath", severity="high"))
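The final cleaning step inside `ask_aayubot` does not need the model, so it can be pulled out as a pure function and tested on its own. This is a sketch of that refactor; `clean_response` and `DISCLAIMER` are hypothetical names, and the logic reproduces the card's rules (drop blank and exact-duplicate lines, drop lines containing "general information only", append one disclaimer).

```python
# Standalone version of the final cleaning step from ask_aayubot.
DISCLAIMER = (
    "This is general information only. "
    "Please consult a doctor for proper diagnosis and treatment."
)


def clean_response(text, disclaimer=DISCLAIMER):
    """Drop blank/duplicate lines and model-written disclaimers, then append one disclaimer."""
    seen = set()
    kept = []
    for line in text.split("\n"):
        line = line.strip()
        # Skip empty lines, exact repeats, and lines already carrying the disclaimer phrase
        if not line or line in seen or "general information only" in line.lower():
            continue
        seen.add(line)
        kept.append(line)
    return ("\n".join(kept) + "\n\n" + disclaimer).strip()
```

With this helper, the tail of `ask_aayubot` reduces to `return clean_response(final_response)`, and the cleaning rules can be unit-tested without a GPU.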

Prompt format

AayuBot expects a severity tag before the user query:

<start_of_turn>user
[Severity: simple|medium|high]
Your medical question here
<end_of_turn>
<start_of_turn>model

| Severity | When to use |
|---|---|
| simple | General wellness, mild symptoms |
| medium | Moderate symptoms, medication questions |
| high | Emergency-level, chest pain, severe symptoms |
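The template above can be wrapped in a small builder that also rejects unknown severity tags. This sketch reproduces the exact string used in Step 4, including the newline before `<end_of_turn>` (Gemma's stock chat template has none; whether the adapter is sensitive to this is not stated). `build_prompt` is a hypothetical helper, not part of the repo.

```python
# Build the severity-tagged prompt in the format AayuBot expects.
SEVERITIES = ("simple", "medium", "high")


def build_prompt(question, severity="simple"):
    """Return a Gemma-style user turn with the [Severity: ...] tag, ready for generation."""
    if severity not in SEVERITIES:
        raise ValueError(f"severity must be one of {SEVERITIES}, got {severity!r}")
    return (
        "<start_of_turn>user\n"
        f"[Severity: {severity}]\n"
        f"{question}\n"
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
```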

Loading from a local clone

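If you downloaded the files locally (git clone in Option A, or snapshot_download in Option B), both place the checkpoint at ./aayubot-gemma-2b/checkpoint5-100k:

# Point to the local path instead of the HF repo ID (no subfolder argument needed)
LOCAL_CHECKPOINT = "./aayubot-gemma-2b/checkpoint5-100k"
tokenizer = AutoTokenizer.from_pretrained(LOCAL_CHECKPOINT)
model = PeftModel.from_pretrained(
    base_model,
    LOCAL_CHECKPOINT,
)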
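To make the same script work whether or not you downloaded the files, you can prefer the local folder when it exists and otherwise fall back to the Hub repo with `subfolder`. This is a minimal sketch; `adapter_source` is a hypothetical helper, and its default paths simply mirror Steps 2 and 3.

```python
# Prefer a local checkpoint folder if present, else fall back to the Hub repo.
from pathlib import Path


def adapter_source(local_root="./aayubot-gemma-2b",
                   repo_id="thesatyam12/aayubot-gemma-2b",
                   checkpoint="checkpoint5-100k"):
    """Return (model_id, extra_kwargs) suitable for PeftModel.from_pretrained."""
    local_dir = Path(local_root) / checkpoint
    if local_dir.is_dir():
        return str(local_dir), {}
    return repo_id, {"subfolder": checkpoint}


# Usage (requires base_model from Step 3):
# model_id, kwargs = adapter_source()
# tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
# model = PeftModel.from_pretrained(base_model, model_id, **kwargs)
```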

Paper / Citation

AayuAI: Transform-Not-Delete Safety Curation and Severity-Aware Fine-Tuning for Medical Conversational AI

Authors: Jayraj S. Lakkad, Satyam R. Tiwari, Laksh J. Savaliya, Akshar V. Prajapati, Dipali Kasat, Dr. Premalkumar J. Patel

@misc{aayuai2025,
  title   = {AayuAI: Transform-Not-Delete Safety Curation and Severity-Aware Fine-Tuning for Medical Conversational AI},
  author  = {Lakkad, Jayraj S. and Tiwari, Satyam R. and Savaliya, Laksh J. and Prajapati, Akshar V. and Kasat, Dipali and Patel, Premalkumar J.},
  year    = {2025},
  doi     = {10.57967/hf/7776},
  url     = {https://huggingface.co/thesatyam12/aayubot-gemma-2b},
}

⚠️ Important disclaimers

  1. Not a doctor. AayuBot is a research prototype. Do not use it for real medical decisions.
  2. Base model license. This repo contains adapter weights only. Usage of google/gemma-2b-it is governed by Google's Gemma Terms of Use.
  3. Adapter license. The adapter weights in this repo are released under CC BY 4.0.
