metadata
library_name: pytorch
license: apache-2.0
datasets:
  - huuuyeah/meetingbank
tags:
  - pytorch
  - transformer
  - meeting-summarization
  - custom-code
  - attention-variant

Run Sliding GQA

Custom PyTorch Transformer checkpoint trained on MeetingBank for meeting summarization research. This repository is part of the transformer-lab collection.

Model Details

Field          Value
Repository     Pradheep1647/run_sliding_gqa-meetingbank-bs8-e20-fp32-19
Attention      sliding_gqa
Dataset        meetingbank
Layers         6
Hidden size    512
Heads          8
Batch size     8
Epochs         20
Precision      fp32
Checkpoint     meeting_model19.pt
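The name sliding_gqa suggests grouped-query attention (several query heads sharing one key/value head) combined with a sliding local attention window. The card does not state the window size or the number of key/value heads, so check config.json for the real values. The pure-Python sketch below illustrates the idea under assumed values: it derives the per-head width from the card's hidden size (512) and head count (8), maps query heads to shared key/value heads, and builds a causal sliding-window mask. WINDOW and NUM_KV_HEADS are hypothetical, and this is a conceptual sketch, not the codebase's attention implementation.

```python
# Conceptual sketch of sliding-window grouped-query attention bookkeeping.
# HIDDEN_SIZE and NUM_HEADS come from the model card; WINDOW and NUM_KV_HEADS
# are hypothetical values chosen for illustration, not read from config.json.
HIDDEN_SIZE = 512
NUM_HEADS = 8
NUM_KV_HEADS = 2  # assumption: 4 query heads share each key/value head
WINDOW = 4        # assumption: a token sees itself and the 3 previous tokens


def head_dim(hidden_size: int, num_heads: int) -> int:
    """Per-head width; the hidden size must split evenly across heads."""
    if hidden_size % num_heads != 0:
        raise ValueError("hidden_size must be divisible by num_heads")
    return hidden_size // num_heads


def kv_head_for(query_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Return the key/value head that a given query head shares under GQA."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads
    return query_head // group_size


def sliding_window_causal_mask(seq_len: int, window: int) -> list[list[bool]]:
    """mask[i][j] is True when query position i may attend to key position j."""
    return [
        [j <= i and i - j < window for j in range(seq_len)]
        for i in range(seq_len)
    ]


if __name__ == "__main__":
    print("head_dim:", head_dim(HIDDEN_SIZE, NUM_HEADS))  # 512 // 8 = 64
    print("kv head per query head:",
          [kv_head_for(h, NUM_HEADS, NUM_KV_HEADS) for h in range(NUM_HEADS)])
    # Print the mask as a grid: "x" = allowed, "." = masked out.
    for row in sliding_window_causal_mask(6, WINDOW):
        print("".join("x" if allowed else "." for allowed in row))
```

A causal mask like this fits a decoder; an encoder with a local window would usually allow positions on both sides of the query instead.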

Training Loss

Figure: training loss (static plot in loss_curve.svg).

Raw curve data is available in loss_curve.csv.
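A minimal sketch for reading loss_curve.csv offline with the standard library. It assumes the export has one column for the step and one for the loss value; the default names step and value are assumptions (TensorBoard's dashboard CSV download, for example, uses Wall time, Step, Value), so pass the real header names if they differ.

```python
import csv
import io
from typing import Iterable


def read_loss_curve(lines: Iterable[str], step_col: str = "step",
                    value_col: str = "value") -> list[tuple[int, float]]:
    """Parse (step, loss) pairs from CSV text; the column names are assumptions."""
    reader = csv.DictReader(lines)
    # int(float(...)) tolerates steps written as "100.0".
    return [(int(float(row[step_col])), float(row[value_col])) for row in reader]


def ema(values: list[float], alpha: float = 0.1) -> list[float]:
    """Exponential moving average; a larger alpha follows the raw curve more closely."""
    smoothed: list[float] = []
    for v in values:
        smoothed.append(v if not smoothed else alpha * v + (1 - alpha) * smoothed[-1])
    return smoothed


if __name__ == "__main__":
    # Toy data for demonstration only, not values from this run.
    sample = "step,value\n0,9.5\n100,6.2\n200,5.1\n"
    curve = read_loss_curve(io.StringIO(sample))
    best_step, best_loss = min(curve, key=lambda p: p[1])
    print("lowest loss:", best_loss, "at step", best_step)
    print("smoothed:", ema([loss for _, loss in curve], alpha=0.5))
```

For the real file, open it with open("loss_curve.csv", newline="") and pass the file object to read_loss_curve. The ema helper is a plain moving average for eyeballing trends and is not guaranteed to match TensorBoard's smoothing slider.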


Files

File                        Purpose
meeting_model19.pt          PyTorch checkpoint containing model_state_dict, optimizer states, epoch, and global step.
config.json                 Training and architecture config converted from the Hydra run config.
tokenizer.json              Alias of the MeetingBank transcript tokenizer, used for source inputs.
transcript_tokenizer.json   Explicit MeetingBank transcript tokenizer.
summary_tokenizer.json      MeetingBank summary tokenizer for target text.
loss_curve.csv              TensorBoard train/loss scalar export.
loss_curve.svg              Static training-loss plot generated from loss_curve.csv.
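Before loading weights, it can help to confirm what the checkpoint actually contains. The helper below is a sketch that works on the dict returned by torch.load. "model_state_dict" matches the Usage example, while "epoch" and "global_step" are hypothetical key names for the epoch and global-step fields listed in the table, so verify them against the real file.

```python
from typing import Any, Mapping

# "model_state_dict" is the key used in the Usage example; "epoch" and
# "global_step" are guessed names (hypothetical) and may differ in the file.
ASSUMED_KEYS = ("model_state_dict", "epoch", "global_step")


def summarize_checkpoint(state: Mapping[str, Any]) -> dict[str, Any]:
    """Report top-level keys and scalar bookkeeping values of a checkpoint dict."""
    summary: dict[str, Any] = {"keys": sorted(state)}
    # Copy bookkeeping values only if the guessed keys are actually present.
    for key in ASSUMED_KEYS[1:]:
        if key in state:
            summary[key] = state[key]
    weights = state.get("model_state_dict")
    summary["num_weight_entries"] = len(weights) if weights is not None else 0
    # Remaining top-level entries (e.g. optimizer states) are listed, not expanded.
    summary["other_keys"] = [k for k in sorted(state) if k not in ASSUMED_KEYS]
    return summary
```

For example, summarize_checkpoint(torch.load(checkpoint_path, map_location="cpu")) lists the real top-level keys, so you can check the names before relying on them.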

Usage

This checkpoint comes from a custom PyTorch codebase and is not a transformers.AutoModel checkpoint. Instantiate the architecture with the repo-native builder, then load the checkpoint's state dict.

from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

import src  # registers components
from src.model.builder import build_transformer

repo_id = "Pradheep1647/run_sliding_gqa-meetingbank-bs8-e20-fp32-19"

config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
checkpoint_path = hf_hub_download(repo_id=repo_id, filename="meeting_model19.pt")

cfg = OmegaConf.load(config_path)
model = build_transformer(cfg)

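# The checkpoint also stores optimizer states, epoch, and global step;
# only the model weights are needed for inference.
state = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state["model_state_dict"])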
model.eval()

print(f"Loaded {repo_id} from {Path(checkpoint_path).name}")

Notes

  • This is a research checkpoint for comparing attention variants under the same MeetingBank setup.
  • The config and tokenizers are included so future runs can reproduce the architecture and preprocessing assumptions.
  • Use config.json as the source of truth for architecture parameters.
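Because config.json is the source of truth, one quick way to review it is to flatten the nested config into dotted keys and search them. This is a hypothetical standard-library helper; any key names in examples are illustrative and not the actual schema of this config.

```python
import json
from typing import Any


def flatten_config(node: Any, prefix: str = "") -> dict[str, Any]:
    """Flatten nested dicts/lists into dotted keys, e.g. {"a": {"b": 1}} -> {"a.b": 1}."""
    if isinstance(node, dict):
        items = node.items()
    elif isinstance(node, list):
        items = enumerate(node)  # list positions become numeric path segments
    else:
        return {prefix: node}
    flat: dict[str, Any] = {}
    for key, value in items:
        path = f"{prefix}.{key}" if prefix else str(key)
        flat.update(flatten_config(value, path))
    return flat


def load_flat_config(path: str) -> dict[str, Any]:
    """Read a JSON config file and return its flattened key/value pairs."""
    with open(path, encoding="utf-8") as f:
        return flatten_config(json.load(f))
```

For example, filtering load_flat_config(config_path) for keys containing "attention" shows every attention-related setting. If you already have the OmegaConf object from the Usage example, OmegaConf.to_container(cfg, resolve=True) returns plain dicts and lists that flatten_config accepts.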