"""EDEN model for Hugging Face Transformers. EDEN (Encoder Decoder Enhancement Network) is a from-scratch encoder-decoder Transformer that rewrites rough text into polished text. This module wraps the original architecture in a ``PreTrainedModel`` so the model can be loaded with ``AutoModel.from_pretrained(..., trust_remote_code=True)`` and saved with ``save_pretrained`` as safetensors. The layer structure (embedding, positional encoding, ``nn.Transformer``, tied language-model head) matches the original training code exactly, so checkpoints trained with the standalone trainer load here without any key remapping beyond loading the model weights. """ from __future__ import annotations import math import re import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from transformers.modeling_outputs import Seq2SeqLMOutput # When this file is loaded from the Hugging Face Hub it lives inside a package, # so the sibling config is a relative import. When imported directly by a local # script (for example the conversion script) it is a top-level module instead. # The try/except supports both, and Transformers ignores imports inside a try # block when checking dependencies. 
# Relative import when this file lives inside a package (Hub loading); plain
# top-level import when a local script imports the module directly.
try:
    from .configuration_eden import EdenConfig
except ImportError:
    from configuration_eden import EdenConfig


# Curly single/double quotes mapped to their ASCII equivalents.
_QUOTE_TABLE = str.maketrans(
    {"\u2018": "'", "\u2019": "'", "\u201c": '"', "\u201d": '"'}
)
_WHITESPACE_RE = re.compile(r"\s+")
_SENTENCE_BOUNDARY_RE = re.compile(r"(?<=[.!?])\s+")


def _normalise_text(text: str) -> str:
    """Return ``text`` with ASCII quotes and single-spaced, stripped whitespace.

    Falsy input (``None``, ``""``) yields an empty string.
    """
    cleaned = str(text or "").translate(_QUOTE_TABLE)
    return _WHITESPACE_RE.sub(" ", cleaned).strip()


def _sentence_split(text: str) -> list[str]:
    """Split ``text`` after ``.``, ``!`` or ``?`` followed by whitespace.

    Empty fragments are dropped, so an empty input returns an empty list.
    """
    return [part for part in _SENTENCE_BOUNDARY_RE.split(text.strip()) if part]


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table added to the embeddings, then dropout."""

    def __init__(self, d_model: int, max_len: int, dropout: float):
        """Precompute a ``(1, max_len, d_model)`` sine/cosine table."""
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer cosine column than sine columns, so trim
        # the frequencies to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Persistent so the table is written to safetensors and restored on load.
        # Transformers initialises models on the meta device, which would leave a
        # non-persistent buffer uninitialised (NaN) after from_pretrained.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the first ``x.size(1)`` table rows to ``x`` and apply dropout."""
        x = x + self.pe[:, : x.size(1), :].to(dtype=x.dtype)
        return self.dropout(x)


class EdenPreTrainedModel(PreTrainedModel):
    """Base class binding ``EdenConfig`` and the weight initialisation scheme."""

    config_class = EdenConfig
    base_model_prefix = "eden"
    supports_gradient_checkpointing = False

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise ``nn.Linear`` and ``nn.Embedding`` modules; ignore others."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                # Keep the padding row at exactly zero after the random init.
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()


class EdenForTextEnhancement(EdenPreTrainedModel):
    """Encoder-decoder Transformer with a tied language-model head."""

    _tied_weights_keys = {"lm_head.weight": "embedding.weight"}

    def __init__(self, config: EdenConfig):
        """Build embedding, positional table, ``nn.Transformer`` and LM head."""
        super().__init__(config)
        self.pad_id = config.pad_token_id
        self.bos_id = config.bos_token_id
        self.eos_id = config.eos_token_id
        self.unk_id = config.unk_token_id
        # Embeddings are multiplied by sqrt(d_model) before positions are added.
        self.scale = math.sqrt(config.d_model)

        self.embedding = nn.Embedding(
            config.vocab_size, config.d_model, padding_idx=config.pad_token_id
        )
        self.pos = PositionalEncoding(config.d_model, config.max_len + 4, config.dropout)
        self.transformer = nn.Transformer(
            d_model=config.d_model,
            nhead=config.n_heads,
            num_encoder_layers=config.n_layers,
            num_decoder_layers=config.n_layers,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    # ------------------------------------------------------------------ #
    # Hugging Face plumbing
    # ------------------------------------------------------------------ #
    def get_input_embeddings(self) -> nn.Module:
        """Return the shared token embedding."""
        return self.embedding

    def set_input_embeddings(self, value: nn.Module) -> None:
        """Replace the token embedding."""
        self.embedding = value

    def get_output_embeddings(self) -> nn.Module:
        """Return the language-model head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        """Replace the language-model head."""
        self.lm_head = new_embeddings

    def _tie_weights(self) -> None:
        """Make the LM head share the embedding's weight parameter."""
        self.lm_head.weight = self.embedding.weight

    # ------------------------------------------------------------------ #
    # Core encoder-decoder
    # ------------------------------------------------------------------ #
    def encode(self, src: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder over ``src``.

        Returns:
            ``(memory, src_padding)`` where ``src_padding`` is the boolean mask
            ``src == pad_id`` that was passed as the key-padding mask.
        """
        src_padding = src.eq(self.pad_id)
        src_emb = self.pos(self.embedding(src) * self.scale)
        memory = self.transformer.encoder(src_emb, src_key_padding_mask=src_padding)
        return memory, src_padding

    def decode(
        self, tgt: torch.Tensor, memory: torch.Tensor, src_padding: torch.Tensor
    ) -> torch.Tensor:
        """Run the decoder over ``tgt`` with a causal mask; return hidden states."""
        tgt_padding = tgt.eq(self.pad_id)
        tgt_emb = self.pos(self.embedding(tgt) * self.scale)
        tgt_len = tgt.size(1)
        # True above the diagonal: position i may not attend to positions > i.
        causal_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )
        return self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_padding,
            memory_key_padding_mask=src_padding,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        decoder_input_ids: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        """Compute logits and, when ``labels`` are given, the cross-entropy loss.

        Args:
            input_ids: Source token ids.
            decoder_input_ids: Decoder inputs; derived from ``labels`` by a
                right shift when omitted.
            labels: Target ids; positions equal to ``-100`` are ignored by the
                loss.
            return_dict: Return a ``Seq2SeqLMOutput`` instead of a tuple;
                defaults to ``config.use_return_dict``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If neither ``decoder_input_ids`` nor ``labels`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self._shift_right(labels)
        if decoder_input_ids is None:
            raise ValueError("Provide decoder_input_ids or labels to EdenForTextEnhancement.forward.")

        memory, src_padding = self.encode(input_ids)
        hidden = self.decode(decoder_input_ids, memory, src_padding)
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.float().reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(loss=loss, logits=logits)

    def _shift_right(self, labels: torch.Tensor) -> torch.Tensor:
        """Build decoder inputs: prepend BOS, drop the last label, -100 -> pad."""
        shifted = labels.new_zeros(labels.shape)
        shifted[:, 1:] = labels[:, :-1].clone()
        shifted[:, 0] = self.bos_id
        shifted.masked_fill_(shifted == -100, self.pad_id)
        return shifted

    # ------------------------------------------------------------------ #
    # Generation (ported from the original trainer, no KV cache needed for
    # the short sequences this model handles)
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def _next_token_logits(self, tokens, memory, src_padding, repetition_penalty):
        """Return float logits for the token that follows ``tokens``.

        The decoder sees only the last ``config.max_len`` tokens. The
        repetition penalty is applied to every id already in ``tokens``, then
        the unk, pad and bos ids are set to ``-inf`` so they are never chosen.
        """
        window = tokens[-self.config.max_len:]
        tgt = torch.tensor([window], dtype=torch.long, device=memory.device)
        hidden = self.decode(tgt, memory, src_padding)
        logits = self.lm_head(hidden[:, -1, :]).float().squeeze(0)

        if repetition_penalty != 1.0:
            seen = sorted({t for t in tokens if 0 <= t < logits.numel()})
            seen_idx = torch.tensor(seen, dtype=torch.long, device=logits.device)
            picked = logits[seen_idx]
            # Sign-aware penalty: dividing a negative logit by a penalty > 1
            # would raise it, rewarding repetition instead of discouraging it.
            logits[seen_idx] = torch.where(
                picked < 0, picked * repetition_penalty, picked / repetition_penalty
            )

        for banned in (self.unk_id, self.pad_id, self.bos_id):
            # Indexing with None would broadcast and mask the whole vocabulary.
            if banned is not None:
                logits[banned] = -float("inf")
        return logits

    @torch.no_grad()
    def _beam_generate(self, src, beam_size, max_new_tokens, length_penalty, repetition_penalty):
        """Beam-search decode ``src`` and return the best hypothesis as ids.

        Hypotheses are ranked by summed log-probability divided by
        ``length ** length_penalty``. The result is cut at the first EOS and
        contains no pad/bos/eos/unk ids. Puts the model in eval mode.
        """
        self.eval()
        memory, src_padding = self.encode(src)

        def rank(item):
            toks, sc, _ = item
            # Length excludes the leading BOS; never below 1.
            return sc / (max(1, len(toks) - 1) ** length_penalty)

        beams = [([self.bos_id], 0.0, False)]
        for _ in range(max_new_tokens):
            if all(done for _, _, done in beams):
                break
            candidates = []
            for tokens, score, done in beams:
                if done:
                    # Finished hypotheses compete unchanged with new ones.
                    candidates.append((tokens, score, True))
                    continue
                logits = self._next_token_logits(
                    tokens, memory, src_padding, repetition_penalty
                )
                log_probs = F.log_softmax(logits, dim=-1)
                values, indices = torch.topk(log_probs, k=min(beam_size, log_probs.numel()))
                for value, index in zip(values.tolist(), indices.tolist()):
                    candidates.append(
                        (tokens + [int(index)], score + float(value), int(index) == self.eos_id)
                    )
            candidates.sort(key=rank, reverse=True)
            beams = candidates[:beam_size]

        best = max(beams, key=rank)
        out = best[0][1:]
        if self.eos_id in out:
            out = out[: out.index(self.eos_id)]
        skip = {self.pad_id, self.bos_id, self.eos_id, self.unk_id}
        return [t for t in out if t not in skip]

    @torch.no_grad()
    def _sample_generate(self, src, strategy, max_new_tokens, temperature, top_k, top_p, repetition_penalty):
        """Decode ``src`` one token at a time and return the generated ids.

        ``strategy == "sample"`` draws from the temperature-scaled,
        top-k/top-p filtered distribution; any other value picks the argmax.
        Stops at EOS or after ``max_new_tokens`` steps. Puts the model in eval
        mode.
        """
        self.eval()
        memory, src_padding = self.encode(src)
        tokens = [self.bos_id]
        skip = {self.pad_id, self.bos_id, self.eos_id, self.unk_id}

        for _ in range(max_new_tokens):
            logits = self._next_token_logits(
                tokens, memory, src_padding, repetition_penalty
            )

            if strategy == "sample":
                # Temperature is floored at 0.05 to avoid dividing by ~0.
                logits = logits / max(0.05, temperature)
                logits = self._filter_top_k_top_p(logits, top_k, top_p)
                probs = F.softmax(logits, dim=-1)
                if not torch.isfinite(probs).all() or float(probs.sum().item()) <= 0:
                    # Degenerate distribution: fall back to the argmax.
                    next_id = int(torch.argmax(logits).item())
                else:
                    next_id = int(torch.multinomial(probs.detach().cpu(), 1).item())
            else:
                next_id = int(torch.argmax(logits).item())

            if next_id == self.eos_id:
                break
            if next_id not in skip:
                tokens.append(next_id)

        return tokens[1:]

    @staticmethod
    def _filter_top_k_top_p(logits, top_k, top_p):
        """Return a copy of 1-D ``logits`` with filtered entries set to ``-inf``.

        Top-k keeps entries at or above the k-th largest value. Top-p then
        keeps the smallest prefix of the sorted distribution whose cumulative
        probability exceeds ``top_p``; the top entry is always kept.
        """
        filtered = logits.clone()
        if 0 < top_k < filtered.numel():
            threshold = torch.topk(filtered, top_k).values[-1]
            filtered[filtered < threshold] = -float("inf")
        if 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(filtered, descending=True)
            probs = F.softmax(sorted_logits, dim=-1)
            cumulative = torch.cumsum(probs, dim=-1)
            remove = cumulative > top_p
            # Shift right so the entry that crosses the threshold is kept.
            remove[1:] = remove[:-1].clone()
            remove[0] = False
            filtered[sorted_indices[remove]] = -float("inf")
        return filtered

    def _chunk_text(self, text, tokenizer):
        """Split normalised ``text`` into chunks of at most ``max_len - 2`` tokens.

        Text that already fits is returned as a single chunk. Otherwise whole
        sentences are packed greedily; a sentence that is itself too long is
        cut into fixed-size token windows and decoded back to text.
        """
        text = _normalise_text(text)
        ids = tokenizer.encode(text, add_special_tokens=False)
        # Two positions are reserved for the BOS/EOS ids added around a chunk.
        max_src = self.config.max_len - 2
        if len(ids) <= max_src:
            return [text]

        chunks, current, current_ids = [], [], []
        for sent in _sentence_split(text) or [text]:
            sent_ids = tokenizer.encode(sent, add_special_tokens=False)
            if current and len(current_ids) + len(sent_ids) > max_src:
                chunks.append(" ".join(current))
                current, current_ids = [], []
            if len(sent_ids) > max_src:
                for i in range(0, len(sent_ids), max_src):
                    chunks.append(tokenizer.decode(sent_ids[i : i + max_src], skip_special_tokens=True))
            else:
                current.append(sent)
                current_ids.extend(sent_ids)
        if current:
            chunks.append(" ".join(current))
        return chunks

    @torch.no_grad()
    def enhance(
        self,
        tokenizer,
        text: str,
        strategy: str = "beam",
        beam_size: int | None = None,
        max_new_tokens: int | None = None,
        temperature: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.9,
        length_penalty: float | None = None,
        repetition_penalty: float | None = None,
    ) -> str:
        """Rewrite ``text`` into polished text.

        This mirrors the original trainer's enhancement pipeline: long inputs
        are split into sentence-aware chunks, each chunk is rewritten, and the
        results are joined back together.

        Args:
            tokenizer: Object providing ``encode(text, add_special_tokens=...)``
                and ``decode(ids, skip_special_tokens=...)``.
            text: Input text; empty or whitespace-only input returns ``""``.
            strategy: ``"beam"``, ``"greedy"`` or ``"sample"``; any other value
                falls back to ``"beam"``.
            beam_size: Beam width; defaults to ``config.beam_size``.
            max_new_tokens: Generation cap per chunk, clamped to
                ``[8, max(8, max_len - 1)]``; defaults to ``min(256, cap)``.
            temperature: Sampling temperature (``"sample"`` only).
            top_k: Top-k cutoff (``"sample"`` only).
            top_p: Nucleus cutoff (``"sample"`` only).
            length_penalty: Defaults to ``config.length_penalty``.
            repetition_penalty: Defaults to ``config.repetition_penalty``.

        Returns:
            The normalised, space-joined rewritten chunks. A chunk whose
            output decodes to nothing is passed through unchanged.
        """
        # Nothing to rewrite: do not ask the decoder to invent text.
        if not _normalise_text(text):
            return ""

        strategy = strategy if strategy in {"beam", "greedy", "sample"} else "beam"
        beam = max(1, int(beam_size or self.config.beam_size))
        cap = max(8, self.config.max_len - 1)
        max_tokens = int(max_new_tokens) if max_new_tokens else min(256, cap)
        max_tokens = max(8, min(cap, max_tokens))
        len_penalty = self.config.length_penalty if length_penalty is None else float(length_penalty)
        rep_penalty = self.config.repetition_penalty if repetition_penalty is None else float(repetition_penalty)

        device = self.embedding.weight.device
        outputs = []
        for chunk in self._chunk_text(text, tokenizer):
            # Re-joined sentences can tokenise slightly longer; truncate again.
            src_ids = tokenizer.encode(chunk, add_special_tokens=False)[: self.config.max_len - 2]
            src = torch.tensor([[self.bos_id] + src_ids + [self.eos_id]], dtype=torch.long, device=device)
            if strategy == "beam":
                out_ids = self._beam_generate(src, beam, max_tokens, len_penalty, rep_penalty)
            else:
                out_ids = self._sample_generate(
                    src, strategy, max_tokens, temperature, top_k, top_p, rep_penalty
                )
            decoded = _normalise_text(tokenizer.decode(out_ids, skip_special_tokens=True))
            outputs.append(decoded or chunk)

        return _normalise_text(" ".join(outputs))