---
license: apache-2.0
language:
- en
pipeline_tag: text-generation
tags:
- text-generation-inference
- maxtext
- base
- bexamask
- pile
---
# BexaMask-v2 (≈800M Parameters)

**BexaMask-v2** is a **pretrained base (foundation) decoder-only Transformer model** trained on large-scale **permissively licensed and uncopyrighted text data** using the MaxText framework on a TPU v4-16.

> ⚠️ This is a **base model**: it is **not instruction-tuned** and may not follow prompts the way ChatGPT does without further fine-tuning.

---

## 🧠 Model Overview

- **Type:** Pretrained Base Model (Foundation Model)
- **Architecture:** Decoder-only Transformer
- **Parameters:** ~800M
- **Layers:** 16
- **Embedding Dimension:** 2048
- **MLP Dimension:** 5120
- **Attention Heads:**
  - Query Heads: 16
  - KV Heads: 4 (Grouped Query Attention)
- **Head Dimension:** 128
- **Activation:** SiLU + Linear
- **Max Context Length:** 4096 tokens
- **Vocabulary Size:** 32,000 (SentencePiece)
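
As a rough sanity check on the ~800M figure, the sketch below counts weight-matrix parameters from the dimensions listed above. It assumes a gated MLP with three projection matrices (one reading of "SiLU + Linear"), untied input and output embeddings, no biases, and it ignores normalization weights. None of these details are stated in this card, and `estimate_params` is a hypothetical helper, so treat the result as an estimate rather than the exact checkpoint size.

```python
# Rough parameter count from the dimensions in the Model Overview list.
# Assumptions (not stated in this card): gated MLP with 3 matrices,
# untied input/output embeddings, no biases, norm weights ignored.

def estimate_params(layers=16, emb=2048, mlp=5120, q_heads=16,
                    kv_heads=4, head_dim=128, vocab=32_000):
    q_proj = emb * q_heads * head_dim          # query projection
    kv_proj = 2 * emb * kv_heads * head_dim    # key + value projections (GQA)
    out_proj = q_heads * head_dim * emb        # attention output projection
    attention = q_proj + kv_proj + out_proj
    gated_mlp = 3 * emb * mlp                  # gate, up, and down matrices
    embeddings = 2 * vocab * emb               # input table + output head
    return layers * (attention + gated_mlp) + embeddings


total = estimate_params()
print(f"{total:,} parameters (~{total / 1e6:.0f}M)")
# prints: 802,160,640 parameters (~802M)
```

Under these assumptions the estimate lands close to the stated ~800M. With tied embeddings the same arithmetic would give roughly 737M.
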

---

## ⚙️ Training Details

- **Framework:** MaxText
- **Hardware:** TPU v4-16 (8 chips, 256 GB HBM)

### 📦 Dataset
- Subset of **The Pile (uncopyrighted / permissive sources only)**
- Filtered to remove restricted or copyrighted data

### 🔧 Training Config
- **Steps:** 100,000
- **Epochs:** 2
- **Batch Size:** 16 per device
- **Learning Rate:** 3e-4
- **Warmup Steps:** 2,000
- **Scheduler:** Cosine decay
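
To make the schedule concrete, here is a minimal sketch of a warmup plus cosine-decay learning rate using the values listed above, together with a tokens-per-step estimate. The linear warmup shape, the final learning rate of 0, one device per TPU chip (8 devices), and fully packed 4096-token sequences are all assumptions for illustration; this is not MaxText's own schedule code, and `learning_rate` and `tokens_per_step` are hypothetical helpers.

```python
import math

PEAK_LR = 3e-4          # from the Training Config list
WARMUP_STEPS = 2_000    # from the Training Config list
TOTAL_STEPS = 100_000   # from the Training Config list


def learning_rate(step, final_lr=0.0):
    """Linear warmup to PEAK_LR, then cosine decay to final_lr.

    Assumption: the warmup shape and final_lr are not given in this
    card; linear warmup and final_lr=0.0 are illustrative defaults.
    """
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    progress = min(progress, 1.0)  # hold final_lr past the last step
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + (PEAK_LR - final_lr) * cosine


def tokens_per_step(per_device_batch=16, devices=8, seq_len=4096):
    """Upper bound on tokens per step, assuming one device per TPU chip
    and every sequence filled to the 4096-token context length."""
    return per_device_batch * devices * seq_len


for step in (0, 1_000, 2_000, 51_000, 100_000):
    print(step, f"{learning_rate(step):.2e}")
print(f"{tokens_per_step() * TOTAL_STEPS:,} tokens at most")
```

Under these assumptions each step covers at most 524,288 tokens, so 100,000 steps correspond to an upper bound of about 52.4B tokens processed.
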

---

## ⚡ Optimization Techniques

- Flash Attention
- Full Rematerialization (Remat)
- Asynchronous Checkpointing
- Distributed GCS checkpointing
- IOTA embeddings

---

## 🧪 Inference

Run inference using MaxText:

```bash
# Replace load_parameters_path and tokenizer_path with your own locations.
python3 -m MaxText.decode \
  maxtext/configs/pretrain.yml \
  run_name=inference \
  load_parameters_path=/home/pynatic079/bexamask_v2_inference_local/items \
  tokenizer_path=/path/to/llama/tokenizer.model \
  max_target_length=512 \
  'prompt="<Your prompt>"' \
  decode_sampling_strategy="topk" \
  decode_sampling_top_k=4 \
  decode_sampling_temperature=1.9 \
  attention=dot_product