---
license: apache-2.0
language:
- en
pipeline_tag: text-generation
tags:
- text-generation-inference
- maxtext
- base
- bexamask
- pile
---

# 🚀 BexaMask-v2 (≈800M Parameters)

**BexaMask-v2** is a **pretrained base (foundation) decoder-only Transformer model** trained on large-scale **permissively licensed and uncopyrighted text data** using the MaxText framework on TPU v4-16.

> ⚠️ This is a **base model**: it is **not instruction-tuned** and may not follow prompts like ChatGPT without further fine-tuning.

---

## 🧠 Model Overview

- **Type:** Pretrained Base Model (Foundation Model)
- **Architecture:** Decoder-only Transformer
- **Parameters:** ~800M
- **Layers:** 16
- **Embedding Dimension:** 2048
- **MLP Dimension:** 5120
- **Attention Heads:**
  - Query Heads: 16
  - KV Heads: 4 (Grouped Query Attention)
- **Head Dimension:** 128
- **Activation:** SiLU + Linear
- **Max Context Length:** 4096 tokens
- **Vocabulary Size:** 32,000 (SentencePiece)

---

## ⚙️ Training Details

- **Framework:** MaxText
- **Hardware:** TPU v4-16 (8 chips, 256GB HBM)

### 📦 Dataset

- Subset of **The Pile (uncopyrighted / permissive sources only)**
- Filtered to remove restricted or copyrighted data

### 🔧 Training Config

- **Steps:** 100,000
- **Epochs:** 2
- **Batch Size:** 16 per device
- **Learning Rate:** 3e-4
- **Warmup Steps:** 2,000
- **Scheduler:** Cosine decay

---

## ⚡ Optimization Techniques

- Flash Attention
- Full Rematerialization (Remat)
- Asynchronous Checkpointing
- Distributed GCS checkpointing
- IOTA embeddings

---

## 🧪 Inference

Run inference using MaxText:

```bash
python3 -m MaxText.decode \
  maxtext/configs/pretrain.yml \
  run_name=inference \
  load_parameters_path=/home/pynatic079/bexamask_v2_inference_local/items \
  tokenizer_path=/path/to/llama/tokenizer.model \
  max_target_length=512 \
  'prompt=""' \
  decode_sampling_strategy="topk" \
  decode_sampling_top_k=4 \
  decode_sampling_temperature=1.9 \
  attention=dot_product