---
language:
- en
base_model:
- albert/albert-large-v2
pipeline_tag: token-classification
tags:
- punctuation-prediction
- punctuation-restoration
- punctuation
- casing
- restoration
- case
- casing-restoration
- casing-prediction
- albert
license: apache-2.0
---

# ReCasePunct 1 Flash

We introduce **ReCasePunct 1 Flash**, our first model capable of punctuation and casing restoration! Given lowercase, unpunctuated English text of almost any length (though not unlimited, as far as I have tested), this model predicts punctuation and casing, and the results are impressive! It also runs very fast on CPU.

Use cases include post-processing ASR (automatic speech recognition) output: some speech-recognition models produce text without casing or punctuation, such as the auto-generated YouTube subtitles from 2023–2025.

## Limitations

This model was trained ONLY on English Tatoeba data (from 21 February 2026), so it does not perform well on other languages. It also sometimes makes mistakes, especially with proper nouns like "Minecraft". We might train a multilingual and more accurate ReCasePunct model next!
## How To Run It

Code by Gemini 2.5 Flash:

```python
import bisect
import re

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download  # Download files from the Hugging Face Hub
from safetensors.torch import load_file  # Load weights stored in .safetensors format
from transformers import AutoTokenizer, AlbertConfig, AlbertModel


# Redefine the model class (must be the same as during training)
class AlbertForPunctuationAndCasing(nn.Module):
    """ALBERT encoder with two token-level heads: punctuation and casing.

    Each token's final hidden state goes through dropout and then two
    independent linear classifiers: one over the punctuation labels and one
    over the casing labels.
    """

    def __init__(self, config):
        super().__init__()
        self.num_punctuation_labels = config.num_punctuation_labels
        self.num_casing_labels = config.num_casing_labels
        # Build AlbertModel straight from the config. The config must describe the
        # albert-large-v2 architecture so that the saved weights fit.
        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.punctuation_classifier = nn.Linear(config.hidden_size, self.num_punctuation_labels)
        self.casing_classifier = nn.Linear(config.hidden_size, self.num_casing_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        casing_labels=None,
        punctuation_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and both heads.

        When both label tensors are given, the loss is the sum of the two
        cross-entropy losses; positions labelled -100 are ignored.

        Returns a dict with "loss", "punctuation_logits", "casing_logits" and,
        if ALBERT produced them, "hidden_states" / "attentions". With
        return_dict=False it returns a tuple instead: (loss?, punctuation_logits,
        casing_logits, *extra ALBERT outputs).
        """
        return_dict = return_dict if return_dict is not None else True

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])
        punctuation_logits = self.punctuation_classifier(sequence_output)
        casing_logits = self.casing_classifier(sequence_output)

        loss = None
        if casing_labels is not None and punctuation_labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            punctuation_loss = loss_fct(
                punctuation_logits.view(-1, self.num_punctuation_labels),
                punctuation_labels.view(-1),
            )
            casing_loss = loss_fct(
                casing_logits.view(-1, self.num_casing_labels),
                casing_labels.view(-1),
            )
            loss = punctuation_loss + casing_loss

        if not return_dict:
            # outputs[1] is ALBERT's pooled output, which these heads do not use.
            output = (punctuation_logits, casing_logits) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        result = {
            "loss": loss,
            "punctuation_logits": punctuation_logits,
            "casing_logits": casing_logits,
        }
        if outputs.hidden_states is not None:
            result["hidden_states"] = outputs.hidden_states
        if outputs.attentions is not None:
            result["attentions"] = outputs.attentions
        return result


# --- Configuration and Mappings (must be the same as during training) ---
punctuation_labels = ['O', '.', ',', '?', '!', ';', ':', '-', '"', '(', ')', '/', '\\']
punctuation_to_id = {label: i for i, label in enumerate(punctuation_labels)}
id_to_punctuation = {i: label for i, label in enumerate(punctuation_labels)}

casing_labels = ['O', 'CAP', 'UPPER']
casing_to_id = {label: i for i, label in enumerate(casing_labels)}
id_to_casing = {i: label for i, label in enumerate(casing_labels)}

model_checkpoint = 'albert-large-v2'

# Hugging Face repository that hosts the fine-tuned weights and tokenizer
hf_repo_id = "MihaiPopa-1/ReCasePunct-1-Flash"

# Load the tokenizer from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained(hf_repo_id)

# 1. Load the base ALBERT Large v2 configuration to get the correct architecture (e.g. hidden_size)
config = AlbertConfig.from_pretrained(model_checkpoint)
# 2. Set the custom label counts on this correctly sized config
config.num_punctuation_labels = len(punctuation_labels)
config.num_casing_labels = len(casing_labels)

# Instantiate the custom model with that config
model = AlbertForPunctuationAndCasing(config)

# Download model.safetensors from the Hub and load the full state dict
safetensors_path = hf_hub_download(repo_id=hf_repo_id, filename="model.safetensors")
model.load_state_dict(load_file(safetensors_path, device='cpu'))
model.eval()


def clean_text(text):
    """Lowercase `text`, strip common punctuation, and collapse whitespace.

    This mirrors the unpunctuated, lowercase form the model expects as input.
    Apostrophes are kept, so words like "it's" survive intact.
    """
    text = text.lower()
    text = re.sub(r'[\.,\?!\-;:"\(\)\[\]\{\}\/\\]', '', text)  # Remove common punctuation
    text = re.sub(r'\s+', ' ', text).strip()  # Replace runs of whitespace with a single space
    return text


# Spacing clean-ups applied to the joined output, in this order.
_SPACING_FIXES = (
    (' .', '.'), (' ,', ','), (' ?', '?'), (' !', '!'), (' ;', ';'), (' :', ':'),
    (' -', '-'), (' "', '"'), ('( ', '('), (' )', ')'), (' /', '/'), (' \\', '\\'),
)


def predict_punctuation_and_casing(text, model, tokenizer, id_to_punctuation, id_to_casing, max_length=512):
    """Restore punctuation and casing in `text`.

    Args:
        text: Input text. It is normalized with `clean_text` first.
        model: An `AlbertForPunctuationAndCasing` in eval mode.
        tokenizer: A *fast* tokenizer (required for `return_offsets_mapping`).
        id_to_punctuation: Maps punctuation class ids to labels ('O' means none).
        id_to_casing: Maps casing class ids to 'O', 'CAP' or 'UPPER'.
        max_length: Maximum number of tokens per model call, including [CLS]
            and [SEP]. ALBERT's position embeddings cover 512 tokens.

    Returns:
        The reconstructed text. Each word takes its casing and punctuation
        predictions from its first sub-word token. Inputs longer than one model
        window are processed in consecutive windows, cut on word boundaries, so
        no words are dropped.
    """
    cleaned_text_input = clean_text(text)

    # (start, end) character span of every whitespace-separated word.
    word_spans = [m.span() for m in re.finditer(r'\S+', cleaned_text_input)]
    if not word_spans:
        return ''
    word_starts = [start for start, _ in word_spans]

    def word_index(char_pos):
        """Return the index of the word containing character `char_pos`, or None."""
        idx = bisect.bisect_right(word_starts, char_pos) - 1
        if idx >= 0 and char_pos < word_spans[idx][1]:
            return idx
        return None

    # Tokenize the whole text once, without special tokens or truncation; the
    # model windows are assembled below. (transformers may log that the sequence
    # exceeds the model maximum; that is expected here.)
    encoding = tokenizer(cleaned_text_input, add_special_tokens=False, return_offsets_mapping=True)
    token_ids = encoding['input_ids']
    # Attribute each token to the word holding its last character. This stays
    # correct even if a token's span also covers the preceding space.
    # Zero-width tokens belong to no word.
    token_words = [
        word_index(end - 1) if end > start else None
        for start, end in encoding['offset_mapping']
    ]

    def continues_word(i):
        """True if token i is a later sub-word of the same word as token i - 1."""
        return token_words[i] is not None and token_words[i] == token_words[i - 1]

    window = max(1, min(max_length, tokenizer.model_max_length) - 2)  # room for [CLS]/[SEP]
    predictions = {}  # word index -> (punctuation id, casing id) of its first token
    pos = 0
    while pos < len(token_ids):
        stop = min(pos + window, len(token_ids))
        if stop < len(token_ids):
            # Move the cut back so a word is not split across two windows,
            # unless a single word fills the entire window.
            cut = stop
            while cut > pos and continues_word(cut):
                cut -= 1
            if cut > pos:
                stop = cut

        chunk = [tokenizer.cls_token_id] + token_ids[pos:stop] + [tokenizer.sep_token_id]
        input_ids = torch.tensor([chunk])
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
        # Drop the [CLS]/[SEP] positions so that index k lines up with token pos + k.
        punctuation_ids = outputs['punctuation_logits'][0, 1:-1].argmax(dim=-1).tolist()
        casing_ids = outputs['casing_logits'][0, 1:-1].argmax(dim=-1).tolist()

        for k, word_idx in enumerate(token_words[pos:stop]):
            if word_idx is not None and word_idx not in predictions:
                predictions[word_idx] = (punctuation_ids[k], casing_ids[k])
        pos = stop

    reconstructed_text_parts = []
    for word_idx, (start, end) in enumerate(word_spans):
        word = cleaned_text_input[start:end]
        if word_idx in predictions:  # a word that produced no token is kept unchanged
            punctuation_id, casing_id = predictions[word_idx]

            casing_pred_label = id_to_casing[casing_id]
            if casing_pred_label == 'CAP':
                word = word.capitalize()
            elif casing_pred_label == 'UPPER':
                word = word.upper()

            punctuation_pred_label = id_to_punctuation[punctuation_id]
            if punctuation_pred_label != 'O':
                word += punctuation_pred_label
        reconstructed_text_parts.append(word)

    result = ' '.join(reconstructed_text_parts)
    for old, new in _SPACING_FIXES:
        result = result.replace(old, new)
    return result


# --- Test Case for a single sentence ---
single_sample_sentence = "replace me by whatever sentence you like"
print(f"Original: {single_sample_sentence}")
print(f"Predicted: {predict_punctuation_and_casing(single_sample_sentence, model, tokenizer, id_to_punctuation, id_to_casing)}\n")
```

Should give: `Replace me by whatever sentence you like.`

## Examples

| Original Sentence | Predicted Sentence |
| :----- | :--------: |
| this is a test of punctuation prediction for english how are you doing today | This is a test of punctuation prediction for English. How are you doing today? |
| i love running this on t4 gpu and so for this goal we might make a better and more accurate model in the future | I love running this on T4 GPU and so, for this goal, we might make a better and more accurate model in the future. |
| so imagine this we live in a world with complex models yet this model does punctuation and casing prediction for english and it's very small at just only 18 million parameters | So, imagine this, we live in a world with complex models. Yet this model does punctuation and casing prediction for English, and it's very small at just only 18 million parameters. |

## Evaluation Results

| Epoch | Training Loss | Validation Loss | Punctuation Accuracy | Casing Accuracy | Overall Accuracy |
| :----- | :--------: | :---------------------------: | :---------------------------: | :---------------------------: | :---------------------------: |
| 1 | 0.072175 | 0.070485 | 0.642053 (64.21%) | 0.638791 (63.88%) | 0.640422 (64.04%) |
| 2 | 0.052846 | 0.063811 | 0.642343 (64.23%) | 0.640475 (64.05%) | 0.641409 (64.14%) |
| 3 | 0.031407 | 0.062892 | 0.640457 (64.05%) | 0.640098 (64.01%) | 0.640278 (64.03%) |