---
license: apache-2.0
language:
- en
pipeline_tag: text-generation
tags:
- text-generation-inference
- maxtext
- base
- bexamask
- pile
---
# 🚀 BexaMask-v2 (≈800M Parameters)
**BexaMask-v2** is a **pretrained base (foundation) decoder-only Transformer model** trained on large-scale **permissively licensed and uncopyrighted text data** using the MaxText framework on TPU v4-16.
> ⚠️ This is a **base model** β€” it is **not instruction-tuned** and may not follow prompts like ChatGPT without further fine-tuning.
---
## 🧠 Model Overview
- **Type:** Pretrained Base Model (Foundation Model)
- **Architecture:** Decoder-only Transformer
- **Parameters:** ~800M
- **Layers:** 16
- **Embedding Dimension:** 2048
- **MLP Dimension:** 5120
- **Attention Heads:**
  - Query Heads: 16
  - KV Heads: 4 (Grouped Query Attention)
- **Head Dimension:** 128
- **Activation:** SiLU + Linear
- **Max Context Length:** 4096 tokens
- **Vocabulary Size:** 32,000 (SentencePiece)
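
As a rough sanity check on the ~800M figure, the sketch below counts weights from the dimensions listed above. It rests on assumptions the card does not state: "SiLU + Linear" is read as a gated MLP with three weight matrices per layer, the input embedding and output projection are untied, there are no bias terms, and each layer has two norm scale vectors plus one final norm. `ModelShape` and `count_parameters` are hypothetical names used only for this illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelShape:
    # Values copied from the Model Overview list.
    vocab_size: int = 32_000
    d_model: int = 2048
    mlp_dim: int = 5120
    num_layers: int = 16
    num_query_heads: int = 16
    num_kv_heads: int = 4
    head_dim: int = 128


def count_parameters(shape: ModelShape, tied_embeddings: bool = False) -> int:
    """Estimate the weight count (assumes no biases and a gated MLP)."""
    # Attention: Q and output projections use all query heads,
    # K and V use the smaller number of KV heads (grouped query attention).
    q_proj = shape.d_model * shape.num_query_heads * shape.head_dim
    kv_proj = 2 * shape.d_model * shape.num_kv_heads * shape.head_dim
    out_proj = shape.num_query_heads * shape.head_dim * shape.d_model
    # Gated MLP (assumption): two input projections and one output projection.
    mlp = 3 * shape.d_model * shape.mlp_dim
    # Two norm scale vectors per layer (assumption).
    norms = 2 * shape.d_model
    per_layer = q_proj + kv_proj + out_proj + mlp + norms

    embedding = shape.vocab_size * shape.d_model
    output_head = 0 if tied_embeddings else embedding
    final_norm = shape.d_model
    return shape.num_layers * per_layer + embedding + output_head + final_norm


total = count_parameters(ModelShape())
print(f"{total:,} parameters (~{total / 1e6:.0f}M)")
```

Under these assumptions the estimate comes to about 802M, which is consistent with the stated ~800M. With tied embeddings it would drop to roughly 737M.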
---
## ⚙️ Training Details
- **Framework:** MaxText
- **Hardware:** TPU v4-16 (8 chips, 256GB HBM)
### 📦 Dataset
- Subset of **The Pile (uncopyrighted / permissive sources only)**
- Filtered to remove restricted or copyrighted data
### 🔧 Training Config
- **Steps:** 100,000
- **Epochs:** 2
- **Batch Size:** 16 per device
- **Learning Rate:** 3e-4
- **Warmup Steps:** 2,000
- **Scheduler:** Cosine decay
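
The learning-rate settings above can be pictured with a small standalone sketch. It is not MaxText's own schedule code. It assumes a linear warmup from zero to the peak rate, followed by cosine decay over the remaining steps. The `final_fraction` floor is a hypothetical parameter, since the card does not say what value the rate decays to.

```python
import math


def learning_rate(
    step: int,
    peak_lr: float = 3e-4,
    warmup_steps: int = 2_000,
    total_steps: int = 100_000,
    final_fraction: float = 0.0,  # assumption: not stated in the card
) -> float:
    """Linear warmup followed by cosine decay (illustrative only)."""
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 to peak_lr.
        return peak_lr * step / warmup_steps
    # Decay: progress runs from 0 at the end of warmup to 1 at the last step.
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    floor = peak_lr * final_fraction
    return floor + (peak_lr - floor) * cosine


for step in (0, 1_000, 2_000, 51_000, 100_000):
    print(f"step {step:>7,}: lr = {learning_rate(step):.2e}")
```

With these defaults the rate peaks at 3e-4 at step 2,000 and falls to half of that at step 51,000, the midpoint of the decay phase.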
---
## ⚑ Optimization Techniques
- Flash Attention
- Full Rematerialization (Remat)
- Asynchronous Checkpointing
- Distributed GCS checkpointing
- IOTA embeddings
---
## 🧪 Inference
Run inference using MaxText:
```bash
python3 -m MaxText.decode \
  maxtext/configs/pretrain.yml \
  run_name=inference \
  load_parameters_path=/home/pynatic079/bexamask_v2_inference_local/items \
  tokenizer_path=/path/to/llama/tokenizer.model \
  max_target_length=512 \
  'prompt="<Your prompt>"' \
  decode_sampling_strategy="topk" \
  decode_sampling_top_k=4 \
  decode_sampling_temperature=1.9 \
  attention=dot_product