"""Sparse Autoencoder for Majestrino 1.00 voice embeddings.

Top-K SAE that decomposes 768-dimensional Majestrino embeddings into 12,288
interpretable features (16x expansion, k=5 active features).

Usage:
    from sae import SparseAutoencoder

    # Load from local directory
    sae = SparseAutoencoder.load_from_disk("model/")

    # Or download from HuggingFace
    sae = SparseAutoencoder.from_pretrained("laion/majestrino-1.00-16xk5-sae")

    # Encode embeddings (768-d) -> sparse features (12288-d)
    embedding = ...  # (batch, 768) tensor from Majestrino 1.00
    latents = sae.encode(embedding)  # (batch, 12288) — mostly zeros, k=5 active

    # Or get top-k indices and values directly
    recons, info = sae(embedding)
    top_k_indices = info["inds"]  # (batch, 5) — which features fired
    top_k_values = info["vals"]   # (batch, 5) — activation strengths
"""

import json
import os

import torch
import torch.nn as nn


class SparseAutoencoder(nn.Module):
    """Top-K Sparse Autoencoder with frequency penalty support.

    Architecture:
        latents = relu(topk(encoder(x - pre_bias) + latent_bias))
        recons  = decoder(latents) + pre_bias

    Only the inference path is implemented in this class; ``auxk``,
    ``dead_steps_threshold``, ``overact_coef`` and ``freq_decay`` are stored
    (and round-tripped through ``config.json``) but not read by any method
    here.

    Args:
        n_dirs: Number of dictionary directions (features). Default: 12288
        d_model: Input embedding dimension. Default: 768
        k: Number of active features per input. Default: 5
        auxk: Auxiliary top-k for dead neuron recovery (training only).
            Default: 256
        dead_steps_threshold: Steps before a neuron is considered dead.
            Default: 2000
        overact_coef: Overactivation penalty coefficient. Default: 3.0
        freq_decay: EMA decay for frequency tracking. Default: 0.999
    """

    def __init__(
        self,
        n_dirs: int = 12288,
        d_model: int = 768,
        k: int = 5,
        auxk: int = 256,
        dead_steps_threshold: int = 2000,
        overact_coef: float = 3.0,
        freq_decay: float = 0.999,
    ):
        super().__init__()
        self.n_dirs = n_dirs
        self.d_model = d_model
        self.k = k
        self.auxk = auxk
        self.dead_steps_threshold = dead_steps_threshold
        self.overact_coef = overact_coef
        self.freq_decay = freq_decay

        self.encoder = nn.Linear(d_model, n_dirs, bias=False)
        self.decoder = nn.Linear(n_dirs, d_model, bias=False)
        self.pre_bias = nn.Parameter(torch.zeros(d_model))
        self.latent_bias = nn.Parameter(torch.zeros(n_dirs))

        # Training buffers (needed for loading saved state_dict)
        self.register_buffer(
            "stats_last_nonzero", torch.zeros(n_dirs, dtype=torch.long)
        )
        self.register_buffer("freq_ema", torch.zeros(n_dirs))

    # ------------------------------------------------------------------
    # Shared building blocks
    # ------------------------------------------------------------------

    def _topk_activations(self, x: torch.Tensor):
        """Return ``(inds, vals)`` of the k largest pre-activations of *x*.

        ``vals`` has ReLU applied, so a selected feature whose
        pre-activation is negative is reported with value 0.
        """
        pre_act = self.encoder(x - self.pre_bias) + self.latent_bias
        vals, inds = torch.topk(pre_act, k=self.k, dim=-1)
        return inds, torch.relu(vals)

    def _decode_topk(self, inds: torch.Tensor, vals: torch.Tensor) -> torch.Tensor:
        """Reconstruct from per-row (index, value) pairs via a sparse matmul.

        Leading dimensions are flattened into a single row axis because the
        sparse matrix product is two-dimensional; they are restored on the
        result.
        """
        lead_shape = inds.shape[:-1]
        n_active = inds.shape[-1]
        flat_inds = inds.reshape(-1, n_active)
        n_rows = flat_inds.shape[0]

        # COO coordinates: each row index repeated once per active feature.
        row_ids = (
            torch.arange(n_rows, device=inds.device)
            .unsqueeze(1)
            .expand(-1, n_active)
            .reshape(-1)
        )
        coords = torch.stack([row_ids, flat_inds.reshape(-1)])
        sparse_latents = torch.sparse_coo_tensor(
            coords, vals.reshape(-1), torch.Size([n_rows, self.n_dirs])
        )
        recons = torch.sparse.mm(sparse_latents, self.decoder.weight.T) + self.pre_bias
        return recons.reshape(*lead_shape, self.d_model)

    # ------------------------------------------------------------------
    # Public inference API
    # ------------------------------------------------------------------

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode input to sparse latent representation.

        Args:
            x: Input tensor of shape (..., d_model), typically (batch, d_model)

        Returns:
            Dense latent tensor of shape (..., n_dirs) with at most k
            non-zero entries per row (a selected feature is zero when its
            pre-activation is not positive).
        """
        inds, vals = self._topk_activations(x)
        latents = vals.new_zeros((*inds.shape[:-1], self.n_dirs))
        latents.scatter_(-1, inds, vals)
        return latents

    def forward(self, x: torch.Tensor):
        """Forward pass: encode and reconstruct.

        Args:
            x: Input tensor of shape (..., d_model), typically (batch, d_model)

        Returns:
            Tuple of (reconstruction, info_dict) where the reconstruction has
            the same shape as ``x`` and info_dict contains:
                - inds: (..., k) top-k feature indices
                - vals: (..., k) top-k activation values (after ReLU)
        """
        inds, vals = self._topk_activations(x)
        recons = self._decode_topk(inds, vals)
        return recons, {"inds": inds, "vals": vals}

    def decode_sparse(self, inds: torch.Tensor, vals: torch.Tensor) -> torch.Tensor:
        """Reconstruct from sparse representation.

        Args:
            inds: (..., k) feature indices
            vals: (..., k) activation values

        Returns:
            Reconstructed tensor of shape (..., d_model)
        """
        return self._decode_topk(inds, vals)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_to_disk(self, path):
        """Save model config and weights to a directory.

        Writes ``config.json`` (constructor arguments) and ``state_dict.pth``
        (a dict with a single ``"state_dict"`` key) into *path*, creating the
        directory if needed.
        """
        os.makedirs(path, exist_ok=True)
        cfg = {
            "n_dirs": self.n_dirs,
            "d_model": self.d_model,
            "k": self.k,
            "auxk": self.auxk,
            "dead_steps_threshold": self.dead_steps_threshold,
            "overact_coef": self.overact_coef,
            "freq_decay": self.freq_decay,
        }
        with open(os.path.join(path, "config.json"), "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2)
        torch.save(
            {"state_dict": self.state_dict()},
            os.path.join(path, "state_dict.pth"),
        )

    @classmethod
    def _from_files(cls, config_path, weights_path, map_location):
        """Build an instance from a config JSON file and a weights checkpoint."""
        with open(config_path, encoding="utf-8") as f:
            cfg = json.load(f)
        ae = cls(**cfg)
        # weights_only=True restricts unpickling to tensors/primitive containers.
        state = torch.load(weights_path, map_location=map_location, weights_only=True)
        ae.load_state_dict(state["state_dict"])
        return ae

    @classmethod
    def load_from_disk(cls, path) -> "SparseAutoencoder":
        """Load model from a local directory containing config.json and state_dict.pth.

        Args:
            path: Directory containing config.json and state_dict.pth

        Returns:
            Loaded SparseAutoencoder instance (weights mapped to CPU)
        """
        return cls._from_files(
            os.path.join(path, "config.json"),
            os.path.join(path, "state_dict.pth"),
            "cpu",
        )

    @classmethod
    def from_pretrained(
        cls, repo_id="laion/majestrino-1.00-16xk5-sae", device="cpu"
    ) -> "SparseAutoencoder":
        """Download and load model from HuggingFace Hub.

        Args:
            repo_id: HuggingFace repository ID
            device: Device to load model on ("cpu", "cuda", etc.)

        Returns:
            Loaded SparseAutoencoder instance on the specified device,
            switched to eval mode
        """
        # Imported lazily so huggingface_hub is only required for this path.
        from huggingface_hub import hf_hub_download

        config_path = hf_hub_download(repo_id, "model/config.json")
        weights_path = hf_hub_download(repo_id, "model/state_dict.pth")
        ae = cls._from_files(config_path, weights_path, device)
        return ae.to(device).eval()