Text Generation
Transformers
Safetensors
English
Arabic
quasar
silx-ai
foundation-model
3b
Mixture of Experts
long-context
bittensor
sn24
distillation
hybrid-transformer
conversational
custom_code
Instructions to use silx-ai/Quasar-3B-A1B-Preview with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use silx-ai/Quasar-3B-A1B-Preview with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="silx-ai/Quasar-3B-A1B-Preview", trust_remote_code=True)
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("silx-ai/Quasar-3B-A1B-Preview", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use silx-ai/Quasar-3B-A1B-Preview with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "silx-ai/Quasar-3B-A1B-Preview"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "silx-ai/Quasar-3B-A1B-Preview",
    "messages": [
      { "role": "user", "content": "What is the capital of France?" }
    ]
  }'
Use Docker
docker model run hf.co/silx-ai/Quasar-3B-A1B-Preview
- SGLang
How to use silx-ai/Quasar-3B-A1B-Preview with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "silx-ai/Quasar-3B-A1B-Preview" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "silx-ai/Quasar-3B-A1B-Preview",
    "messages": [
      { "role": "user", "content": "What is the capital of France?" }
    ]
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "silx-ai/Quasar-3B-A1B-Preview" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "silx-ai/Quasar-3B-A1B-Preview",
    "messages": [
      { "role": "user", "content": "What is the capital of France?" }
    ]
  }'
- Docker Model Runner
How to use silx-ai/Quasar-3B-A1B-Preview with Docker Model Runner:
docker model run hf.co/silx-ai/Quasar-3B-A1B-Preview
| """Quasar hybrid transformer — HuggingFace compatible. | |
| """ | |
| import math | |
| import os | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| from transformers import GenerationMixin | |
| from transformers.cache_utils import Cache | |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.utils import logging | |
| from .configuration_quasar import QuasarConfig | |
| logger = logging.get_logger(__name__) | |
| # FLA layer imports — required | |
| from fla.layers.quasar import QuasarAttention | |
| from fla.layers.gla import GatedLinearAttention | |
| from fla.models.utils import Cache as FlaCache, FLAGenerationMixin | |
| # =================================================================== | |
| # RMSNorm (standalone — weight name: .weight, no bias) | |
| # =================================================================== | |
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel gain.

    The only parameter is ``weight`` (initialized to ones); there is no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize over the last dimension and apply the gain.

        The statistics are computed in float32; the normalized tensor is cast
        back to the input dtype before it is multiplied by ``weight``.
        """
        orig_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)
| # =================================================================== | |
| # Rotary Embedding (persistent inv_freq to match checkpoint) | |
| # =================================================================== | |
class RotaryEmbedding(nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    ``inv_freq`` and both tables are registered as non-persistent buffers, so
    none of them appear in the module's ``state_dict``.
    """

    def __init__(self, dim, max_position_embeddings=4096, base=100000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        exponents = torch.arange(0, dim, 2).float().to(device) / dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Pre-compute the tables for positions 0..max_position_embeddings (inclusive).
        positions = torch.arange(max_position_embeddings + 1, device=device, dtype=inv_freq.dtype)
        cos_table, sin_table = self._build_tables(positions, inv_freq)
        self.register_buffer("_cos_cached", cos_table, persistent=False)
        self.register_buffer("_sin_cached", sin_table, persistent=False)

    @staticmethod
    def _build_tables(positions, inv_freq):
        """Return (cos, sin) tables of shape (1, 1, len(positions), 2 * len(inv_freq))."""
        angles = torch.einsum("i,j->ij", positions, inv_freq)
        # The frequency block is duplicated so the last dim spans the full head dim.
        angles = torch.cat((angles, angles), dim=-1)
        return angles.cos()[None, None, :, :], angles.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` sliced to ``seq_len`` positions, cast to ``x.dtype``.

        ``x`` is only consulted for its device and dtype. When ``seq_len`` is
        larger than the cached table, the cache is rebuilt on ``x.device`` with
        ``seq_len + 1024`` positions. With ``seq_len=None`` the whole cache is
        returned.
        """
        cache_too_short = seq_len is not None and seq_len > self._cos_cached.shape[2]
        if cache_too_short:
            positions = torch.arange(seq_len + 1024, device=x.device, dtype=self.inv_freq.dtype)
            cos_table, sin_table = self._build_tables(positions, self.inv_freq)
            self.register_buffer("_cos_cached", cos_table.to(self._cos_cached.dtype), persistent=False)
            self.register_buffer("_sin_cached", sin_table.to(self._sin_cached.dtype), persistent=False)

        cos = self._cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
        sin = self._sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
        return cos, sin
| # =================================================================== | |
| # Latent Memory Module (use_triton=False — PyTorch bmm is faster) | |
| # =================================================================== | |
class LatentMemoryModule(nn.Module):
    """Persistent Latent Parameter Memory — weight names match checkpoint.

    Holds the projections used to write segment summaries of a hidden-state
    sequence into a slot memory ``M`` and to read from that memory with a
    per-token query. ``use_triton`` is accepted for signature compatibility,
    but the attribute is always stored as ``False``.
    """

    def __init__(self, hidden_size, memory_slots=128, memory_dim=128, use_triton=False):
        super().__init__()
        self.K = memory_slots
        self.D = memory_dim

        # Per-token write-strength head: zero weight and bias -5, so
        # sigmoid(eta) starts out small and input-independent.
        self.W_eta = nn.Linear(hidden_size, 1, bias=True)
        nn.init.zeros_(self.W_eta.weight)
        nn.init.constant_(self.W_eta.bias, -5.0)

        # Tokens are pooled in fixed segments of this many positions.
        self.segment_len = 64
        self.summary_query = nn.Parameter(torch.randn(1, 1, memory_dim))
        self.summary_proj = nn.Linear(hidden_size, memory_dim, bias=True)
        self.eta_channels = nn.Parameter(torch.ones(1, 1, memory_dim))
        self.temperature = nn.Parameter(torch.ones(1))
        self.hidden_size = hidden_size
        self.use_triton = False
        self.input_norm = nn.LayerNorm(hidden_size)
        # Not referenced by the methods of this class.
        self.compress_z = nn.Sequential(
            nn.Linear(hidden_size, memory_dim * 2, bias=False),
            nn.SiLU(),
            nn.Linear(memory_dim * 2, memory_dim, bias=False),
        )
        self.W_qkv_mem = nn.Linear(hidden_size, memory_dim * 3, bias=False)
        self.scale = 1.0 / math.sqrt(memory_dim)

    def get_diversity_loss(self, M):
        """Mean squared off-diagonal cosine similarity between the slots of ``M``.

        ``M`` is indexed as (batch, slots, dim); the mean runs over the full
        slots x slots matrix, with the diagonal zeroed.
        """
        slot_count = M.shape[1]
        unit_slots = F.normalize(M, p=2, dim=-1)
        cosine = torch.bmm(unit_slots, unit_slots.transpose(1, 2))
        off_diagonal = 1 - torch.eye(slot_count, device=M.device).unsqueeze(0)
        return (cosine * off_diagonal).pow(2).mean()

    def write_memory(self, H, M, chunk_idx=0):
        """Write segment summaries of ``H`` into memory ``M``.

        Returns ``(M_new, aux_loss, eta_raw_sig)`` where ``aux_loss`` is
        ``0.01 * update_norm + 0.1 * diversity_loss`` and ``eta_raw_sig`` is the
        per-token sigmoid write strength (before padding). ``chunk_idx`` is
        accepted but unused.
        """
        H = self.input_norm(H)
        batch, seq_len, _ = H.shape
        projected = self.summary_proj(H)
        eta_logits = self.W_eta(H).squeeze(-1)

        # Right-pad the sequence to a whole number of segments. Padded tokens
        # are zero vectors with a write logit of -10.
        seg = self.segment_len
        pad_len = (-seq_len) % seg
        if pad_len:
            projected_padded = F.pad(projected, (0, 0, 0, pad_len))
            eta_logits_padded = F.pad(eta_logits, (0, pad_len), value=-10.0)
        else:
            projected_padded = projected
            eta_logits_padded = eta_logits
        num_segments = projected_padded.shape[1] // seg

        # Attention-pool every segment into one summary vector with the
        # learned summary query.
        segments = projected_padded.view(batch * num_segments, seg, self.D)
        query = self.summary_query.expand(batch * num_segments, -1, -1)
        pool_logits = torch.bmm(query, segments.transpose(1, 2))
        pool_weights = F.softmax(pool_logits * self.scale, dim=-1)
        Z_seg = torch.bmm(pool_weights, segments).view(batch, num_segments, self.D)

        # A segment's write strength is the largest token-level strength in it.
        eta_raw_sig = torch.sigmoid(eta_logits)
        eta_seg_sig = torch.sigmoid(
            eta_logits_padded.view(batch, num_segments, seg)
        ).max(dim=-1, keepdim=True).values

        # Softly assign each summary to the memory slots and accumulate the update.
        slot_logits = torch.bmm(Z_seg, M.transpose(-1, -2)) * self.scale * torch.exp(self.temperature)
        assignment = F.softmax(slot_logits, dim=-1)
        delta = torch.bmm(assignment.transpose(1, 2), Z_seg * eta_seg_sig)

        # Per-channel decay gate applied to the previous memory content.
        gate = eta_seg_sig.mean(dim=1, keepdim=True) * torch.sigmoid(self.eta_channels)
        M_new = (1.0 - gate) * M + delta / num_segments

        update_norm = torch.sum(delta ** 2) / num_segments
        diversity = self.get_diversity_loss(M_new)
        return M_new, update_norm * 0.01 + diversity * 0.1, eta_raw_sig

    def read_memory(self, H, M, memory_scale=1.0):
        """Read from ``M`` with one query per token of ``H``.

        Only the last ``D`` channels of the ``W_qkv_mem`` projection are used
        as the query. When ``M`` has more than 1024 slots, only the 64
        best-scoring slots per token take part in the softmax.
        """
        H = self.input_norm(H)
        query = self.W_qkv_mem(H).split([self.D, self.D, self.D], dim=-1)[2]
        scores = torch.bmm(query, M.transpose(-1, -2))
        if M.shape[1] > 1024:
            top_vals, top_idx = torch.topk(scores, 64, dim=-1)
            # Everything outside the top-k is -inf and gets zero weight.
            scores = torch.full_like(scores, float('-inf')).scatter(-1, top_idx, top_vals)
        attention = F.softmax(scores * 2.0, dim=-1)
        return torch.bmm(attention, M) * memory_scale
| # =================================================================== | |
| # FFN Components | |
| # =================================================================== | |
class SwiGLUBlock(nn.Module):
    """Dense FFN — weight names: gate.weight, up.weight, down.weight

    Computes ``down(silu(gate(x)) * up(x))``; none of the projections has a bias.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        """Apply the gated feed-forward transform to ``x``."""
        activated_gate = F.silu(self.gate(x))
        lifted = self.up(x)
        return self.down(activated_gate * lifted)
class SigmoidRouter(nn.Module):
    """Router with router_weights Parameter — weight name: router.router_weights

    Produces one logit per expert via a bias-free linear map and returns both
    the sigmoid scores and the raw logits.
    """

    def __init__(self, d_model, num_experts):
        super().__init__()
        weights = nn.Parameter(torch.zeros(num_experts, d_model))
        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
        self.router_weights = weights

    def forward(self, x):
        """Return ``(scores, logits)`` with ``scores = sigmoid(logits)``."""
        logits = F.linear(x, self.router_weights)
        return torch.sigmoid(logits), logits
class BigMacMoE(nn.Module):
    """BigMac MoE with DCCA bottleneck — matches checkpoint weight names exactly.

    Weights: w_down_proj, w_up_proj, experts_w12, experts_w3,
             router.router_weights, shared_experts.{i}.{gate,up,down}.weight,
             max_vio

    Tokens are projected from ``d_model`` to ``bottle_dim`` once, dispatched to
    their top-k routed experts through a fixed-capacity padded batch, combined
    with the softmax gates and projected back up. Shared experts run on the
    full-width hidden states and are added to the routed output.

    ``expert_bias`` and ``expert_momentum`` start out as ``None``;
    :meth:`update_bias` expects them to have been replaced by tensors from
    outside this class before it is called.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.d_model = config.d_model
        self.bigmac_r = getattr(config, 'bigmac_r', 0.25)
        self.bottle_dim = int(self.d_model * self.bigmac_r)
        self.num_shared_experts = getattr(config, 'num_shared_experts', 1)
        self.num_routed_experts = getattr(config, 'num_routed_experts', 64)
        self.top_k = getattr(config, 'top_k', 4)
        default_routed_size = int(getattr(config, 'routed_expert_size', 768) / self.bigmac_r)
        self.routed_expert_size = getattr(config, 'bigmac_expert_size', default_routed_size)
        self.shared_expert_size = getattr(config, 'shared_expert_size', config.d_ff)
        self.layer_idx = layer_idx

        self.shared_experts = nn.ModuleList([
            SwiGLUBlock(self.d_model, self.shared_expert_size)
            for _ in range(self.num_shared_experts)
        ])

        # BigMac DCCA Projections
        self.w_down_proj = nn.Linear(self.d_model, self.bottle_dim, bias=False)
        self.w_up_proj = nn.Linear(self.bottle_dim, self.d_model, bias=False)

        # BigMac Experts (fused gate+up W12, down W3)
        self.experts_w12 = nn.Parameter(torch.zeros(self.num_routed_experts, self.bottle_dim, 2 * self.routed_expert_size))
        self.experts_w3 = nn.Parameter(torch.zeros(self.num_routed_experts, self.routed_expert_size, self.bottle_dim))

        self.router = SigmoidRouter(self.d_model, self.num_routed_experts)
        self.expert_bias = None
        self.expert_momentum = None

        self.smebu_kappa = getattr(config, 'smebu_kappa', 2.0)
        self.smebu_lambda = getattr(config, 'smebu_lambda', 2e-3)
        self.smebu_beta = getattr(config, 'smebu_beta', 0.5)
        self.z_loss_weight = getattr(config, 'moe_z_loss_coeff', 1e-4)
        self.aux_loss_weight = getattr(config, 'moe_aux_loss_coeff', 1e-4)
        self.register_buffer("max_vio", torch.tensor(0.0))
        self.route_scale = math.sqrt(self.top_k)
        self.moe_scale = 1.0 / (1.0 + float(self.num_shared_experts > 0))

        # Buffers for padded BMM dispatch: one all-zero row that empty slots
        # (and tokens dropped by the capacity limit) point at.
        self.register_buffer("_dummy_token", torch.zeros(1, self.bottle_dim, dtype=torch.bfloat16), persistent=False)
        self.register_buffer("_dummy_out", torch.zeros(1, self.bottle_dim, dtype=torch.bfloat16), persistent=False)

        # Cache of the (N, K)-dependent index tensors used by forward().
        self._cached_N = -1
        self._cached_K = -1
        self._cached_indices = None

    def _init_weights(self, std=0.011):
        """Draw all projection and expert weights from N(0, std^2)."""
        nn.init.normal_(self.w_down_proj.weight, std=std)
        nn.init.normal_(self.w_up_proj.weight, std=std)
        nn.init.normal_(self.experts_w12, std=std)
        nn.init.normal_(self.experts_w3, std=std)
        for expert in self.shared_experts:
            nn.init.normal_(expert.gate.weight, std=std)
            nn.init.normal_(expert.up.weight, std=std)
            nn.init.normal_(expert.down.weight, std=std)

    def forward(self, x, expert_bias=None):
        """Run the MoE layer.

        Args:
            x: tensor unpacked as (batch, seq_len, d_model).
            expert_bias: optional per-expert bias added to the sigmoid scores
                for expert *selection* only; the gates are a softmax over the
                raw logits of the selected experts.

        Returns:
            ``(out, loss)`` where ``out`` has the shape and dtype of ``x`` and
            ``loss`` is ``z_loss + aux_loss`` (``aux_loss`` is zero in eval mode).
        """
        batch_size, seq_len, d_model = x.shape
        # reshape (not view) so non-contiguous inputs are accepted too.
        hidden_states = x.reshape(-1, d_model)
        N = hidden_states.shape[0]
        K = self.top_k
        num_tokens_total = N * K

        # 1. Routing & Gating
        # NOTE(review): float32 is not one of the usual autocast target dtypes;
        # confirm this context behaves as intended on the target torch version.
        with torch.autocast(device_type=x.device.type, dtype=torch.float32):
            scores, logits = self.router(hidden_states)
            z_loss = torch.mean(logits.nan_to_num() ** 2) * self.z_loss_weight
            bias = expert_bias if expert_bias is not None else torch.zeros(self.num_routed_experts, device=x.device)
            selection_scores = scores + bias
            _, topk_indices = torch.topk(selection_scores, K, dim=-1)
            topk_indices = topk_indices.clamp(0, logits.shape[1] - 1)
            topk_logits = torch.gather(logits, 1, topk_indices)
            gating_scores = F.softmax(topk_logits, dim=-1).to(torch.bfloat16)

        # 2. Aux loss
        if self.training:
            flat_topk_idx = topk_indices.view(-1)
            expert_counts = torch.bincount(flat_topk_idx, minlength=self.num_routed_experts)
            fi = expert_counts.float() / num_tokens_total
            Pi = scores.nan_to_num().mean(dim=0)
            aux_loss = torch.sum(fi * Pi) * self.aux_loss_weight
        else:
            aux_loss = torch.tensor(0.0, device=x.device)
            expert_counts = None

        # 3. Shared experts
        shared_out = 0
        if self.num_shared_experts > 0:
            for expert in self.shared_experts:
                shared_out = shared_out + expert(hidden_states)

        # 4. Bottleneck projection
        down_proj_hidden = self.w_down_proj(hidden_states)

        # 5. Routed experts (padded BMM dispatch)
        flat_topk_idx = topk_indices.view(-1).clamp(0, self.num_routed_experts - 1)
        sorted_experts, permutation = torch.sort(flat_topk_idx)

        # The index tensors only depend on (N, K) and the device. The device
        # check keeps a cache built before a `.to(device)` move from being
        # reused with tensors that live elsewhere.
        cached = self._cached_indices
        cache_hit = (
            cached is not None
            and self._cached_N == N
            and self._cached_K == K
            and cached[0].device == x.device
        )
        if cache_hit:
            token_indices, global_rel_idx = cached
        else:
            token_indices = torch.arange(N, device=x.device).repeat_interleave(K)
            global_rel_idx = torch.arange(num_tokens_total, device=x.device)
            self._cached_N, self._cached_K = N, K
            self._cached_indices = (token_indices, global_rel_idx)

        # Per-expert capacity: the mean load rounded down to a multiple of 8,
        # plus 48 slots of headroom. Assignments beyond it are dropped.
        max_load = ((num_tokens_total // self.num_routed_experts) // 8 + 6) * 8
        used_counts = expert_counts if expert_counts is not None else torch.bincount(flat_topk_idx, minlength=self.num_routed_experts)
        # Exclusive prefix sum: offset of each expert's run in the sorted order.
        expert_ptr = torch.cumsum(used_counts, dim=0) - used_counts
        local_idx = global_rel_idx - expert_ptr.index_select(0, sorted_experts)
        capacity_mask = local_idx < max_load
        valid_slots = sorted_experts[capacity_mask] * max_load + local_idx[capacity_mask]
        num_slots = self.num_routed_experts * max_load

        # Row N is the zero dummy token; unfilled slots keep pointing at it.
        # The dummy row is cast to the activation dtype so the concatenation
        # cannot promote to a dtype that differs from the expert weights.
        dummy_token = self._dummy_token.to(down_proj_hidden.dtype)
        hidden_with_dummy = torch.cat([down_proj_hidden, dummy_token], dim=0)
        reverse_map = torch.full((num_slots,), N, device=x.device, dtype=torch.long)
        reverse_map.scatter_(0, valid_slots.long(), token_indices[permutation][capacity_mask])
        padding = hidden_with_dummy.index_select(0, reverse_map).view(self.num_routed_experts, max_load, self.bottle_dim)

        h12 = torch.bmm(padding, self.experts_w12)
        h1, h2 = h12.chunk(2, dim=-1)
        padded_out = torch.bmm(F.silu(h1) * h2, self.experts_w3)

        # Row `num_slots` is the zero dummy output used for dropped assignments.
        padded_out_flat = padded_out.view(-1, self.bottle_dim)
        dummy_out = self._dummy_out.to(padded_out_flat.dtype)
        padded_out_with_dummy = torch.cat([padded_out_flat, dummy_out], dim=0)
        gather_map = torch.full((num_tokens_total,), num_slots, device=x.device, dtype=torch.long)
        gather_map.scatter_(0, permutation[capacity_mask], valid_slots)
        gathered_out = padded_out_with_dummy.index_select(0, gather_map).view(N, K, self.bottle_dim)

        # Gate-weighted sum over the K expert outputs of every token.
        routed_out_bottle = torch.bmm(gating_scores.to(gathered_out.dtype).unsqueeze(1), gathered_out).squeeze(1)
        routed_out = self.w_up_proj(routed_out_bottle)

        if self.training:
            mean_load = num_tokens_total / self.num_routed_experts
            self._pending_violation = (mean_load - used_counts.float()) / (mean_load + 1e-6)

        # NOTE(review): the routed output is scaled by sqrt(top_k) in training
        # but by 1.0 in eval — confirm this asymmetry is intentional.
        route_scale = math.sqrt(self.top_k) if self.training else 1.0
        out = (shared_out + routed_out * route_scale) * self.moe_scale
        out = out.view(batch_size, seq_len, d_model).to(x.dtype)
        return out, z_loss + aux_loss

    def update_bias(self, counts, num_tokens):
        """Update the selection bias from observed per-expert token counts.

        Requires ``self.expert_bias`` and ``self.expert_momentum`` to be
        tensors (they are ``None`` right after construction). Both are updated
        in place through ``.data``; ``max_vio`` tracks an exponential moving
        average of the largest overload.
        """
        expert_counts = counts.float()
        mean_load = num_tokens * self.top_k / self.num_routed_experts
        # Positive for under-loaded experts, negative for over-loaded ones.
        violation = (mean_load - expert_counts) / (mean_load + 1e-6)
        clamped_update = torch.tanh(self.smebu_kappa * violation)
        delta_bi = self.smebu_lambda * clamped_update
        delta_bi = delta_bi - delta_bi.mean()
        self.expert_momentum.data = self.smebu_beta * self.expert_momentum.data + (1 - self.smebu_beta) * delta_bi
        self.expert_bias.data = (self.expert_bias.data + self.expert_momentum.data).nan_to_num_().clamp(-10.0, 10.0)
        self.expert_bias.data -= self.expert_bias.data.mean()
        current_max_vio = -violation.min()
        self.max_vio.copy_(0.99 * self.max_vio + 0.01 * current_max_vio)
| class GroupedMoE(nn.Module): | |
| """Grouped MoE fallback — for non-BigMac configs.""" | |
| def __init__(self, config, layer_idx=None): | |
| super().__init__() | |
| self.d_model = config.d_model | |
| self.num_shared_experts = getattr(config, 'num_shared_experts', 1) | |
| self.num_routed_experts = getattr(config, 'num_routed_experts', 64) | |
| self.top_k = getattr(config, 'top_k', 6) | |
| self.shared_expert_size = getattr(config, 'shared_expert_size', config.d_ff) | |
| self.routed_expert_size = getattr(config, 'routed_expert_size', 1408) | |
| self.layer_idx = layer_idx | |
| self.shared_experts = nn.ModuleList([ | |
| SwiGLUBlock(self.d_model, self.shared_expert_size) | |
| for _ in range(self.num_shared_experts) | |
| ]) | |
| self.experts_w12 = nn.Parameter(torch.zeros(self.num_routed_experts, self.d_model, 2 * self.routed_expert_size)) | |
| self.experts_w3 = nn.Parameter(torch.zeros(self.num_routed_experts, self.routed_expert_size, self.d_model)) | |
| self.router = nn.Linear(config.d_model, config.num_routed_experts, bias=False) | |
| with torch.no_grad(): | |
| nn.init.normal_(self.router.weight, std=0.01) | |
| self.z_loss_weight = getattr(config, 'moe_z_loss_coeff', 1e-6) | |
| self.aux_loss_weight = getattr(config, 'moe_aux_loss_coeff', 1e-4) | |
| self.smebu_kappa = getattr(config, 'smebu_kappa', 2.0) | |
| self.smebu_lambda = getattr(config, 'smebu_lambda', 5e-4) | |
| self.smebu_beta = getattr(config, 'smebu_beta', 0.5) | |
| self.register_buffer("max_vio", torch.tensor(0.0)) | |
| self.moe_scale = 1.0 / (1.0 + float(self.num_shared_experts > 0)) | |
| def _init_weights(self, std=0.011): | |
| nn.init.normal_(self.experts_w12, std=std) | |
| nn.init.normal_(self.experts_w3, std=std) | |
| for expert in self.shared_experts: | |
| nn.init.normal_(expert.gate.weight, std=std) | |
| nn.init.normal_(expert.up.weight, std=std) | |
| nn.init.normal_(expert.down.weight, std=std) | |
| def forward(self, x, expert_bias=None): | |
| batch_size, seq_len, d_model = x.shape | |
| hidden_states = x.view(-1, d_model) | |
| N, D = hidden_states.shape | |
| K = self.top_k | |
| with torch.autocast(device_type=x.device.type, dtype=torch.float32): | |
| logits = self.router(hidden_states) | |
| scores = torch.sigmoid(logits) | |
| z_loss = torch.mean(logits.nan_to_num() ** 2) * self.z_loss_weight | |
| bias = expert_bias if expert_bias is not None else torch.zeros(self.num_routed_experts, device=x.device) | |
| selection_scores = scores + bias | |
| _, topk_indices = torch.topk(selection_scores, K, dim=-1) | |
| topk_indices = topk_indices.clamp(0, logits.shape[1] - 1) | |
| topk_logits = torch.gather(logits, 1, topk_indices) | |
| gating_scores = F.softmax(topk_logits, dim=-1).to(torch.bfloat16) | |
| if self.training: | |
| flat_topk_idx = topk_indices.view(-1) | |
| expert_counts = torch.bincount(flat_topk_idx, minlength=self.num_routed_experts) | |
| fi = expert_counts.float() / (N * K) | |
| Pi = scores.nan_to_num().mean(dim=0) | |
| aux_loss = torch.sum(fi * Pi) * self.aux_loss_weight | |
| self._pending_violation = fi.detach() - (1.0 / self.num_routed_experts) | |
| else: | |
| aux_loss = torch.tensor(0.0, device=x.device) | |
| expert_counts = None | |
| self._pending_violation = torch.zeros(self.num_routed_experts, device=x.device) | |
| shared_out = 0 | |
| if self.num_shared_experts > 0: | |
| for expert in self.shared_experts: | |
| shared_out = shared_out + expert(hidden_states) | |
| # Padded BMM dispatch | |
| num_experts = self.num_routed_experts | |
| flat_topk_idx = topk_indices.view(-1) | |
| tokens_per_expert = torch.bincount(flat_topk_idx, minlength=num_experts) | |
| max_tokens = tokens_per_expert.max().item() | |
| if max_tokens == 0: | |
| out = shared_out * self.moe_scale | |
| return out.view(batch_size, seq_len, d_model).to(x.dtype), aux_loss | |
| sorted_indices = torch.argsort(flat_topk_idx) | |
| token_indices = torch.arange(N, device=x.device).repeat_interleave(K)[sorted_indices] | |
| grouped_x = hidden_states[token_indices] | |
| padded_x = torch.zeros(num_experts, max_tokens, D, device=x.device, dtype=x.dtype) | |
| expert_starts = torch.cat([torch.tensor([0], device=x.device), tokens_per_expert[:-1].cumsum(0)]) | |
| intra_offsets = torch.arange(N * K, device=x.device) - expert_starts.repeat_interleave(tokens_per_expert) | |
| expert_idx = flat_topk_idx[sorted_indices] | |
| padded_x_flat = padded_x.view(-1, D) | |
| flat_dest_indices = expert_idx * max_tokens + intra_offsets | |
| padded_x_flat.index_put_((flat_dest_indices,), grouped_x) | |
| h12 = torch.bmm(padded_x, self.experts_w12) | |
| h1, h2 = h12.chunk(2, dim=-1) | |
| h = F.silu(h1) * h2 | |
| expert_out_padded = torch.bmm(h, self.experts_w3) | |
| full_expert_out = expert_out_padded.view(-1, D)[flat_dest_indices] | |
| gating_flat = gating_scores.view(-1) | |
| sorted_gating = gating_flat[sorted_indices].unsqueeze(1) | |
| weighted_out = full_expert_out * sorted_gating | |
| routed_out = torch.zeros_like(hidden_states) | |
| routed_out.index_add_(0, token_indices, weighted_out) | |
| route_scale = math.sqrt(self.top_k) if self.training else 1.0 | |
| out = (shared_out + routed_out * route_scale) * self.moe_scale | |
| out = out.view(batch_size, seq_len, d_model).to(x.dtype) | |
| return out, z_loss + aux_loss | |
| # =================================================================== | |
| # HybridBlock — one transformer layer | |
| # Weight names: ln1.weight, ln1_out.weight, ln2.weight, ln2_out.weight, | |
| # attn.*, memory.*, W_alpha.*, C_to_hidden.*, | |
| # ffn.*, injection_gate | |
| # =================================================================== | |
class HybridBlock(nn.Module):
    """One hybrid transformer layer: a token mixer followed by an FFN or MoE.

    The mixer is ``QuasarAttention`` or ``GatedLinearAttention`` depending on
    ``config.hybrid_layer_types[layer_idx]``. "gla" layers additionally own a
    ``LatentMemoryModule`` plus the ``W_alpha`` / ``C_to_hidden`` projections
    that mix the memory read into the attention output. Both sub-blocks are
    wrapped in sandwich norms (``ln*`` before, ``ln*_out`` after) and added to
    the residual stream scaled by ``config.residual_scale``.
    """

    def __init__(self, config: QuasarConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.d_model
        self.layer_idx = layer_idx
        self.n_layers = config.n_layers
        self.config = config
        self.gradient_checkpointing = False

        # Looped Transformer injection gate (checkpoint always has it)
        self.use_looped_injection = config.use_looped_injection
        self.injection_gate = nn.Parameter(torch.tensor([-2.197]))

        # Token mixer, chosen per layer (quasar vs. gla).
        self.layer_type = config.hybrid_layer_types[layer_idx]
        if self.layer_type == "quasar":
            self.attn = QuasarAttention(
                mode=config.attn_mode,
                hidden_size=config.d_model,
                expand_v=config.expand_v,
                head_dim=config.head_dim,
                num_heads=config.n_heads,
                num_v_heads=config.num_v_heads,
                use_short_conv=config.use_short_conv,
                allow_neg_eigval=config.allow_neg_eigval,
                conv_size=config.conv_size,
                norm_eps=config.rms_norm_eps,
                layer_idx=layer_idx,
            )
        elif self.layer_type == "gla":
            self.attn = GatedLinearAttention(
                mode=config.gla_mode,
                hidden_size=config.d_model,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.n_heads,
                layer_idx=layer_idx,
            )
            # Latent memory exists on gla layers only.
            self.memory = LatentMemoryModule(
                hidden_size=config.d_model,
                memory_slots=config.memory_slots,
                memory_dim=config.memory_dim,
                use_triton=False,
            )
            nn.init.constant_(self.memory.W_eta.bias, -1.0)
            self.W_alpha = nn.Linear(config.d_model, 1)
            self.C_to_hidden = nn.Linear(config.memory_dim, config.d_model, bias=False)
        else:
            raise ValueError(f"Unknown layer_type: {self.layer_type}")

        # Sandwich norms, registered in checkpoint order.
        for norm_name in ("ln1", "ln1_out", "ln2", "ln2_out"):
            setattr(self, norm_name, RMSNorm(config.d_model, eps=config.rms_norm_eps))

        # The first `dense_input_layers` layers (or every layer, when there
        # are no routed experts) use a dense FFN; the rest use an MoE.
        dense_layers = config.dense_input_layers
        num_routed = config.num_routed_experts
        self.is_moe = not (layer_idx < dense_layers or num_routed == 0)
        if self.is_moe:
            # "deepseek" currently maps to BigMacMoE as well; a dedicated
            # DeepSeekMoE could be added here if needed.
            moe_cls = BigMacMoE if config.moe_type in ("bigmac", "deepseek") else GroupedMoE
            self.ffn = moe_cls(config, layer_idx=layer_idx)
        else:
            self.ffn = SwiGLUBlock(config.d_model, config.d_ff)

        self.dropout = nn.Dropout(config.dropout)
        self.scale_factor = 1.0 / math.sqrt(2 * self.n_layers)
        self.residual_scale = config.residual_scale
        self._init_weights()

    def _init_weights(self):
        """Initialize this layer's weights.

        Input-side projections use ``0.5 / sqrt(hidden_size)`` as std; the
        output projections (``down``, ``o_proj``) use that std times
        ``scale_factor``.
        """
        base_std = 0.5 / math.sqrt(self.hidden_size)

        if self.layer_type == "gla":
            # sigmoid(-10) ~ 0 with zero weight: the memory path starts closed.
            nn.init.constant_(self.W_alpha.bias, -10.0)
            nn.init.zeros_(self.W_alpha.weight)
            nn.init.normal_(self.C_to_hidden.weight, std=base_std)

        def init_down_proj(module):
            down = getattr(module, 'down', None)
            if isinstance(down, nn.Linear):
                nn.init.normal_(down.weight, mean=0.0, std=base_std * self.scale_factor)

        if self.is_moe:
            self.ffn._init_weights(std=base_std)
            for shared_expert in self.ffn.shared_experts:
                init_down_proj(shared_expert)
            nn.init.normal_(self.ffn.experts_w3, mean=0.0, std=base_std)
        else:
            nn.init.normal_(self.ffn.gate.weight, mean=0.0, std=base_std)
            nn.init.normal_(self.ffn.up.weight, mean=0.0, std=base_std)
            init_down_proj(self.ffn)

        nn.init.constant_(self.ln1_out.weight, 1.0)
        nn.init.constant_(self.ln2_out.weight, 1.0)

        out_proj = getattr(self.attn, 'o_proj', None)
        if isinstance(out_proj, nn.Linear):
            nn.init.normal_(out_proj.weight, mean=0.0, std=base_std * self.scale_factor)

        for proj_name in ('q_proj', 'k_proj', 'v_proj', 'g_proj'):
            proj = getattr(self.attn, proj_name, None)
            if isinstance(proj, nn.Linear):
                linears = [proj]
            elif isinstance(proj, nn.Sequential):
                linears = [sub for sub in proj if isinstance(sub, nn.Linear)]
            else:
                continue
            for linear in linears:
                nn.init.normal_(linear.weight, mean=0.0, std=base_std)

    def forward(self, x, cos=None, sin=None, expert_bias=None,
                memory_state=None, lambda_reg=0.01, **kwargs):
        """Apply the optional looped injection, then run the layer.

        When ``use_looped_injection`` is set and ``kwargs`` carries a non-None
        ``P``, ``sigmoid(injection_gate) * P`` is added to ``x`` first. The
        layer body runs under activation checkpointing when
        ``gradient_checkpointing`` is enabled in training mode.

        Returns:
            ``(x, aux_loss, new_memory_state, mem_loss)`` from :meth:`_forward`.
        """
        injected = kwargs.get('P') if self.use_looped_injection else None
        if injected is not None:
            x = x + (torch.sigmoid(self.injection_gate) * injected)

        if not (self.gradient_checkpointing and self.training):
            return self._forward(x, cos, sin, expert_bias, memory_state, lambda_reg, **kwargs)

        return torch.utils.checkpoint.checkpoint(
            self._forward, x, cos, sin, expert_bias, memory_state, lambda_reg,
            use_reentrant=False, **kwargs,
        )

    def _forward(self, x, cos=None, sin=None, expert_bias=None,
                 memory_state=None, lambda_reg=0.01, **kwargs):
        """Layer body: mixer (+ latent memory on gla layers), then FFN/MoE.

        ``lambda_reg`` is accepted but not used here.
        """
        # 1. Attention block
        skip = x
        x = self.ln1(x)

        # Forward only the optional arguments that were actually provided.
        attn_kwargs = {}
        if cos is not None and sin is not None:
            attn_kwargs.update(cos=cos, sin=sin)
        # past_key_values / use_cache enable FLA cache support.
        if kwargs.get('past_key_values') is not None:
            attn_kwargs['past_key_values'] = kwargs['past_key_values']
        if 'use_cache' in kwargs:
            attn_kwargs['use_cache'] = kwargs['use_cache']

        attn_out = self.attn(x, **attn_kwargs)
        if isinstance(attn_out, tuple):
            attn_out = attn_out[0]

        new_memory_state = None
        mem_loss = torch.tensor(0.0, device=x.device)

        # GLA layers: write the normed input into memory, read it back and
        # add the gated read to the attention output.
        if self.layer_type == "gla" and memory_state is not None:
            new_memory_state, mem_loss, _ = self.memory.write_memory(x, memory_state)
            memory_read = self.memory.read_memory(x, new_memory_state)
            alpha = torch.sigmoid(self.W_alpha(x))
            attn_out = attn_out + (alpha * self.C_to_hidden(memory_read))

        # Sandwich norm + residual scaling
        x = skip + self.residual_scale * self.dropout(self.ln1_out(attn_out))

        # 2. FFN / MoE block
        skip = x
        x = self.ln2(x)
        if self.is_moe:
            block_out, aux_loss = self.ffn(x, expert_bias=expert_bias)
        else:
            block_out = self.ffn(x)
            aux_loss = torch.tensor(0.0, device=x.device)
        x = skip + self.residual_scale * self.dropout(self.ln2_out(block_out))

        return x, aux_loss, new_memory_state, mem_loss
| # =================================================================== | |
| # Output dataclasses | |
| # =================================================================== | |
# Without @dataclass the annotated names below would be plain class attributes:
# the inherited __init__ would reject them as keyword arguments and they would
# not be part of the output's fields.
@dataclass
class QuasarModelOutputWithPast(BaseModelOutputWithPast):
    """Base-model output extended with the latent-memory results."""

    # Latent memory states returned alongside the hidden states.
    memory_states: dict | None = None
    # Auxiliary loss reported by the latent memory.
    memory_loss: torch.Tensor | None = None
# Without @dataclass the annotated names below would be plain class attributes:
# the inherited __init__ would reject them as keyword arguments and they would
# not be part of the output's fields.
@dataclass
class QuasarCausalLMOutputWithPast(CausalLMOutputWithPast):
    """Causal-LM output extended with latent-memory and auxiliary-loss fields."""

    # Latent memory states returned alongside the logits.
    memory_states: dict | None = None
    # Auxiliary loss reported by the latent memory.
    memory_loss: torch.Tensor | None = None
    # Auxiliary (e.g. MoE routing) loss.
    aux_loss: torch.Tensor | None = None
| # =================================================================== | |
| # PreTrainedModel base | |
| # =================================================================== | |
class QuasarPreTrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the Quasar model classes.

    Declares the class-level settings shared by ``QuasarModel`` and
    ``QuasarForCausalLM`` and provides the weight initialiser.
    """

    config_class = QuasarConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["HybridBlock"]
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialise one sub-module in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with mean 0 and standard deviation
        ``config.initializer_range`` (0.02 when the config has no such
        attribute). Linear biases and the embedding row at ``padding_idx``
        are zeroed. Any other module type is left untouched.

        Args:
            module: The sub-module to initialise.
        """
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # Use the configured std here as well; the previous hard-coded 0.02
            # silently ignored ``initializer_range`` for embeddings only.
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
| # =================================================================== | |
| # QuasarModel — base transformer (no LM head) | |
| # Weight prefix: model.* (embed_tokens, embed_norm, layers, norm, rotary_emb, all_moe_*) | |
| # =================================================================== | |
class QuasarModel(QuasarPreTrainedModel):
    """Quasar base transformer without a language-model head.

    Token embeddings are normalised, run ``config.num_loops`` times through
    the stack of ``HybridBlock`` layers, and normalised again. The model also
    owns the global SMEBU buffers (``all_moe_bias``, ``all_moe_momentum``,
    ``all_moe_max_vio``) shared by all MoE layers.
    """

    config: QuasarConfig

    def __init__(self, config: QuasarConfig):
        """Build embeddings, layers, norms, rotary embedding and MoE buffers.

        Args:
            config: Model configuration. Reads ``d_model``, ``n_heads``,
                ``n_layers``, ``vocab_size``, ``max_seq_len``,
                ``rms_norm_eps``, ``rope_theta`` and ``num_routed_experts``.
        """
        super().__init__(config)
        self.config = config

        d_model = config.d_model
        n_heads = config.n_heads
        n_layers = config.n_layers
        vocab_size = config.vocab_size
        max_seq_len = config.max_seq_len

        self.embed_tokens = nn.Embedding(vocab_size, d_model)
        self.embed_norm = RMSNorm(d_model, eps=config.rms_norm_eps)
        self.layers = nn.ModuleList([
            HybridBlock(config, i) for i in range(n_layers)
        ])
        self.norm = RMSNorm(d_model, eps=config.rms_norm_eps)
        self.rotary_emb = RotaryEmbedding(
            d_model // n_heads, max_seq_len, base=config.rope_theta,
        )

        # SMEBU global buffers — sized [num_moe, num_experts] to match checkpoint.
        # ``moe_layer_ffns`` is a plain list of references to modules that are
        # already registered through ``self.layers``.
        self.moe_layer_ffns = [
            layer.ffn for layer in self.layers if getattr(layer, 'is_moe', False)
        ]
        self.num_moe = len(self.moe_layer_ffns)
        num_experts = config.num_routed_experts
        if self.num_moe > 0 and num_experts > 0:
            self.register_buffer("all_moe_bias", torch.zeros(self.num_moe, num_experts))
            self.register_buffer("all_moe_momentum", torch.zeros(self.num_moe, num_experts))
            self.register_buffer("all_moe_max_vio", torch.zeros(self.num_moe))

        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with ``value``."""
        self.embed_tokens = value

    def init_memory(self, batch_size, device, dtype=torch.float32):
        """Create zero-filled latent memory for every ``"gla"`` layer.

        Args:
            batch_size: Leading dimension of each memory tensor.
            device: Device on which the tensors are allocated.
            dtype: Dtype of the tensors.

        Returns:
            Dict mapping the position of each ``"gla"`` layer in
            ``self.layers`` to a zero tensor of shape
            ``(batch_size, layer.memory.K, layer.memory.D)``.
        """
        memory_states = {}
        for i, layer in enumerate(self.layers):
            if layer.layer_type == "gla":
                memory_states[i] = torch.zeros(
                    batch_size, layer.memory.K, layer.memory.D, device=device, dtype=dtype,
                )
        return memory_states

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        memory_states: dict | None = None,
        lambda_reg: float = 0.01,
        **kwargs,
    ):
        """Run the looped layer stack over the input.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            attention_mask: Accepted for API compatibility.
                NOTE(review): this method never reads it and does not forward
                it to the layers — confirm padding is handled elsewhere.
            position_ids: Positions used to gather rotary ``cos``/``sin``.
                When omitted they are ``arange`` starting at the cache length.
                NOTE(review): for a 2-D tensor only row 0 is used.
            past_key_values: Cache object handed unchanged to every layer and
                returned in the output.
            inputs_embeds: Pre-computed embeddings; mutually exclusive with
                ``input_ids``.
            use_cache: Forwarded to the layers; defaults to ``config.use_cache``.
            output_hidden_states: When true, collect the input of every layer
                call (on every loop) plus the final normalised output.
            memory_states: Dict ``layer_idx -> memory tensor``; zero-initialised
                through ``init_memory`` when omitted.
            lambda_reg: Forwarded to every layer.
            **kwargs: Forwarded to every layer.

        Returns:
            A 2-tuple ``(QuasarModelOutputWithPast, total_aux)`` where
            ``total_aux`` is the sum of the auxiliary losses of all layer
            calls. The output's ``memory_states`` holds the memory written
            during the last loop and ``memory_loss`` the sum of all reported
            memory losses.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are given.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Embed norm for stability
        hidden_states = self.embed_norm(inputs_embeds)
        batch_size, seq_len, _ = hidden_states.shape

        # Position ids
        if position_ids is None:
            past_seen_tokens = 0
            if past_key_values is not None:
                # Best effort, kept from the original: any failure to query the
                # cache length falls back to starting at position 0.
                try:
                    past_seen_tokens = past_key_values.get_seq_length()
                except Exception:
                    past_seen_tokens = 0
            position_ids = torch.arange(
                past_seen_tokens, past_seen_tokens + seq_len, device=hidden_states.device,
            )

        # RoPE: build the table up to the largest requested position, then gather.
        max_pos = int(position_ids.max().item() + 1) if position_ids.numel() > 0 else seq_len
        cos_full, sin_full = self.rotary_emb(hidden_states, seq_len=max_pos)
        if position_ids.dim() == 1:
            cos = cos_full[:, :, position_ids]
            sin = sin_full[:, :, position_ids]
        else:
            cos = cos_full[:, :, position_ids[0]]
            sin = sin_full[:, :, position_ids[0]]

        # Memory states
        if memory_states is None:
            memory_states = self.init_memory(batch_size, hidden_states.device, hidden_states.dtype)

        all_hidden_states = () if output_hidden_states else None
        aux_losses = []
        mem_losses = []
        new_memory_states = {}

        # Looped transformer anchor
        P = hidden_states
        num_loops = self.config.num_loops
        current_memory_states = memory_states

        # Snapshot expert bias for gradient checkpointing consistency.
        # The buffers are registered only when there are MoE layers AND routed
        # experts, so gate on the buffer itself instead of ``num_moe`` alone;
        # otherwise a config with MoE layers but zero routed experts would
        # raise AttributeError here.
        moe_bias = getattr(self, "all_moe_bias", None)
        bias_snapshot = moe_bias.detach().clone() if moe_bias is not None else None

        for loop_idx in range(num_loops):
            moe_idx = 0
            iteration_new_memory_states = {}
            for layer in self.layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                if getattr(layer, 'is_moe', False) and bias_snapshot is not None:
                    bias = bias_snapshot[moe_idx]
                else:
                    bias = None

                layer_out = layer(
                    hidden_states,
                    cos=cos, sin=sin,
                    expert_bias=bias,
                    memory_state=current_memory_states.get(layer.layer_idx),
                    lambda_reg=lambda_reg,
                    P=P if self.config.use_looped_injection else None,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    **kwargs,
                )
                hidden_states, aux_loss, new_m, m_loss = layer_out

                if new_m is not None:
                    iteration_new_memory_states[layer.layer_idx] = new_m
                    mem_losses.append(m_loss)
                if bias is not None:
                    moe_idx += 1
                aux_losses.append(aux_loss)

            # Memory written in this loop is the input of the next loop and,
            # after the final loop, the memory that is returned.
            current_memory_states = iteration_new_memory_states
            new_memory_states = iteration_new_memory_states

        # SMEBU bias update (no_grad to avoid checkpointing issues)
        if self.training and bias_snapshot is not None:
            with torch.no_grad():
                self._update_all_moe_biases()

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        zero = torch.tensor(0.0, device=hidden_states.device)
        total_aux = torch.stack(aux_losses).sum() if aux_losses else zero
        total_mem = torch.stack(mem_losses).sum() if mem_losses else zero

        return QuasarModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            memory_states=new_memory_states,
            memory_loss=total_mem,
        ), total_aux

    def _update_all_moe_biases(self):
        """Apply one SMEBU update to the shared expert-bias buffers in place.

        Reads ``_pending_violation`` from every MoE feed-forward module
        (expected to have been set during the forward pass) and the
        ``smebu_kappa`` / ``smebu_lambda`` / ``smebu_beta`` hyper-parameters
        from the first MoE module, which are applied to all layers.

        Steps: squash the violations with ``tanh``, scale, centre over the
        expert dimension, fold into the momentum buffer, add the momentum to
        the bias, sanitise/clamp to ``[-10, 10]`` and re-centre. The running
        maximum violation is updated with a 0.99/0.01 moving average and
        copied back to each module, whose pending violation is then deleted.
        """
        violations = torch.stack([m._pending_violation for m in self.moe_layer_ffns])
        m0 = self.moe_layer_ffns[0]
        kappa, lamb, beta = m0.smebu_kappa, m0.smebu_lambda, m0.smebu_beta

        clamped_update = torch.tanh(kappa * violations)
        delta_bi = lamb * clamped_update
        delta_bi = delta_bi - delta_bi.mean(dim=-1, keepdim=True)

        self.all_moe_momentum.mul_(beta).add_(delta_bi, alpha=1 - beta)
        self.all_moe_bias.add_(self.all_moe_momentum).nan_to_num_().clamp_(-10.0, 10.0)
        self.all_moe_bias.sub_(self.all_moe_bias.mean(dim=-1, keepdim=True))

        current_max_vios = -violations.min(dim=-1).values
        self.all_moe_max_vio.mul_(0.99).add_(current_max_vios, alpha=0.01)

        for i, moe in enumerate(self.moe_layer_ffns):
            moe.max_vio.copy_(self.all_moe_max_vio[i])
            del moe._pending_violation
| # =================================================================== | |
| # QuasarForCausalLM — with LM head + generation support | |
| # Weight prefix: lm_head.* (top-level), model.* (from QuasarModel) | |
| # =================================================================== | |
class QuasarForCausalLM(QuasarPreTrainedModel, FLAGenerationMixin):
    """Quasar base model plus an untied language-model head and generation hooks.

    The head is registered both as ``lm_head`` and as ``model.lm_head`` (the
    same module object); a state-dict pre-hook accepts checkpoints that only
    carry the ``model.lm_head.weight`` key.
    """

    config: QuasarConfig
    _tied_weights_keys = {}

    def __init__(self, config: QuasarConfig):
        """Build the base model, the LM head and the checkpoint-key remap hook.

        Args:
            config: Model configuration; reads ``d_model`` and ``vocab_size``.
        """
        super().__init__(config)
        self.model = QuasarModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Alias: the same head module is also reachable as ``model.lm_head``.
        self.model.lm_head = self.lm_head

        def _remap_lm_head_state_dict(
            module, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs,
        ):
            # Copy ``model.lm_head.weight`` to ``lm_head.weight`` when only the
            # former is present in the checkpoint being loaded.
            checkpoint_key = prefix + "model.lm_head.weight"
            module_key = prefix + "lm_head.weight"
            if checkpoint_key in state_dict and module_key not in state_dict:
                state_dict[module_key] = state_dict[checkpoint_key]

        self.register_load_state_dict_pre_hook(_remap_lm_head_state_dict)
        self.post_init()

    def get_input_embeddings(self):
        """Return the base model's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the base model's token embedding module with ``value``."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head, keeping the ``model.lm_head`` alias in sync."""
        self.lm_head = new_embeddings
        # ``__init__`` registers the head under both names; updating only one
        # would leave two different modules (and two weights) in the state dict.
        self.model.lm_head = new_embeddings

    def tie_weights(self, missing_keys=None, recompute_mapping=False):
        """Intentionally a no-op: input and output embeddings are never tied."""
        pass  # Don't tie — crashes FSDP

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        memory_states: dict | None = None,
        lambda_reg: float = 0.01,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Run the base model and compute either logits or the training loss.

        Without ``labels`` the LM head is applied to every position and
        ``logits`` is returned with ``loss=None``.

        With ``labels`` the full logits are never materialised and ``logits``
        is ``None``. Hidden states and labels are shifted by one position,
        positions labelled ``-100`` are dropped, and the summed cross-entropy
        is accumulated over chunks of ``QUASAR_LM_HEAD_CHUNK_SIZE`` tokens
        (environment variable, default 2048, clamped to at least 1). The loss
        is the mean cross-entropy plus the auxiliary and memory losses. If
        every label is ``-100`` the loss is a fresh zero tensor and the
        auxiliary and memory losses are not added.

        Args:
            input_ids, attention_mask, position_ids, past_key_values,
            inputs_embeds, use_cache, output_hidden_states, memory_states,
            lambda_reg, **kwargs: Forwarded unchanged to the base model.
            labels: Target token ids; ``-100`` marks ignored positions.
            return_dict: Return a ``QuasarCausalLMOutputWithPast`` when true
                (default ``config.use_return_dict``), else a tuple.

        Returns:
            ``QuasarCausalLMOutputWithPast``, or the tuple
            ``(loss?, logits, *base_outputs[1:])`` when ``return_dict`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_outputs, total_aux = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            memory_states=memory_states,
            lambda_reg=lambda_reg,
            **kwargs,
        )
        hidden_states = model_outputs.last_hidden_state

        loss = None
        if labels is not None:
            shift_hidden = hidden_states[..., :-1, :].contiguous()
            # Keep labels on the hidden-state device so boolean indexing and the
            # loss work when inputs and outputs live on different devices.
            shift_labels = labels[..., 1:].contiguous().to(shift_hidden.device)
            flat_hidden = shift_hidden.view(-1, self.config.d_model)
            flat_labels = shift_labels.view(-1)
            mask = flat_labels != -100

            if mask.any():
                active_hidden = flat_hidden[mask]
                active_labels = flat_labels[mask]
                # Clamp to >= 1: a value of 0 would make ``range`` raise and a
                # negative value would silently skip the loss computation.
                chunk_size = max(1, int(os.environ.get("QUASAR_LM_HEAD_CHUNK_SIZE", "2048")))
                total_loss = 0.0
                total_tokens = active_labels.numel()
                for start in range(0, total_tokens, chunk_size):
                    end = min(start + chunk_size, total_tokens)
                    chunk_logits = self.lm_head(active_hidden[start:end])
                    chunk_loss = F.cross_entropy(
                        chunk_logits.float(), active_labels[start:end], reduction='sum',
                    )
                    total_loss += chunk_loss
                loss = total_loss / total_tokens
                loss = loss + total_aux + model_outputs.memory_loss
            else:
                loss = torch.tensor(0.0, device=hidden_states.device, requires_grad=True)
            logits = None
        else:
            logits = self.lm_head(hidden_states)

        if not return_dict:
            output = (logits,) + model_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return QuasarCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            memory_states=model_outputs.memory_states,
            memory_loss=model_outputs.memory_loss,
            aux_loss=total_aux,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        memory_states=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        When a cache object is present only the last position of
        ``input_ids`` / ``inputs_embeds`` is kept. ``inputs_embeds`` is used
        only when there is no cache; otherwise ``input_ids`` is used. If
        ``memory_states`` is not given it is read from a ``memory_states``
        attribute on the cache, when one exists.

        NOTE(review): the slice is applied whenever ``past_key_values`` is not
        ``None``, including an empty cache created before the first step —
        confirm the generation mixin never passes one, or the prompt would be
        truncated to its last token.

        Returns:
            Dict with ``input_ids`` or ``inputs_embeds`` plus
            ``past_key_values``, ``use_cache``, ``attention_mask``,
            ``cache_position`` and ``memory_states``.
        """
        if past_key_values is not None:
            if input_ids is not None:
                input_ids = input_ids[:, -1:]
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1:]

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        if memory_states is None and past_key_values is not None:
            memory_states = getattr(past_key_values, "memory_states", None)

        model_inputs.update({
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
            "cache_position": cache_position,
            "memory_states": memory_states,
        })
        return model_inputs

    def update_model_kwargs_for_generation(self, outputs, model_kwargs, is_seq2seq=False, num_new_tokens=1):
        """Carry the latent memory returned by a step over to the next step.

        Delegates to the parent implementation, then stores
        ``outputs.memory_states`` in ``model_kwargs`` when it is not ``None``.
        """
        model_kwargs = super().update_model_kwargs_for_generation(
            outputs=outputs, model_kwargs=model_kwargs,
            is_seq2seq=is_seq2seq, num_new_tokens=num_new_tokens,
        )
        if getattr(outputs, "memory_states", None) is not None:
            model_kwargs["memory_states"] = outputs.memory_states
        return model_kwargs

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder the cache for beam search and return it.

        Cache implementations whose ``reorder_cache`` works in place return
        ``None``; in that case the (now reordered) cache object itself is
        returned so the caller does not lose it. A cache returned by
        ``reorder_cache`` is passed through unchanged.

        NOTE(review): ``memory_states`` kept in the generation kwargs are not
        reordered here.
        """
        if past_key_values is None:
            return None
        reordered = past_key_values.reorder_cache(beam_idx)
        return past_key_values if reordered is None else reordered
# Public API of this module: the names exported by a star-import.
__all__ = [
    "QuasarConfig",
    "QuasarPreTrainedModel",
    "QuasarModel",
    "QuasarForCausalLM",
    "QuasarModelOutputWithPast",
    "QuasarCausalLMOutputWithPast",
]