---
base_model:
- fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes
language:
- multilingual
- ca
- cy
- da
- de
- en
- es
- et
- eu
- fa
- fr
- ga
- hr
- hu
- id
- is
- it
- ja
- ko
- nl
- 'no'
- pl
- pt
- qu
- ro
- sr
- sv
- tr
- zh
- yue
datasets:
- fdemelo/ipa-childes-split
license: apache-2.0
pipeline_tag: text-generation
---

ONNX version of [fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes](https://huggingface.co/fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes).

Inference example:

```python
from transformers import AutoTokenizer
import onnxruntime
import numpy as np


def infer_onnx(
    text: str,
    lang: str,
    onnx_model_path: str = "byt5_g2p_model.onnx",
    max_length: int = 512,
):
    """
    Converts text to phonemes with an already-exported ONNX ByT5 G2P model,
    running ONNX Runtime in a manual greedy decoding loop.

    Args:
        text (str): The input text to convert to phonemes.
        lang (str): The language tag (e.g., "en").
        onnx_model_path (str): Path of the exported ONNX model to load.
        max_length (int): Maximum number of tokens to generate.

    Returns:
        list[str] | None: A one-element list holding the decoded phoneme string,
        or None if the ONNX model could not be loaded.
    """
    model_name = 'fdemelo/g2p-multilingual-byt5-tiny-8l-ipa-childes'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("\n--- Performing inference with ONNX Runtime ---")

    # Create an ONNX Runtime session
    try:
        session = onnxruntime.InferenceSession(onnx_model_path, providers=['CPUExecutionProvider'])
    except Exception as e:
        print(f"Error loading ONNX model: {e}")
        return None

    # Input/output names declared by the ONNX graph
    onnx_input_names = {inp.name for inp in session.get_inputs()}
    onnx_output_names = [out.name for out in session.get_outputs()]

    # Tokenize straight to NumPy so PyTorch is not required for ONNX inference.
    # Cast to int64 explicitly: assumes the graph expects int64 ids/mask (as a
    # PyTorch export would); NumPy's default int is int32 on some platforms.
    input_text_for_onnx = f"<{lang}>: {text}"
    inputs_for_onnx = tokenizer([input_text_for_onnx], return_tensors="np", add_special_tokens=False)
    input_ids_np = inputs_for_onnx["input_ids"].astype(np.int64)
    attention_mask_np = inputs_for_onnx["attention_mask"].astype(np.int64)

    # For T5-style models the decoder start token is usually the pad token.
    start_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    decoder_input_ids_np = np.array([[start_token_id]], dtype=np.int64)  # batch of 1
    eos_token_id = tokenizer.eos_token_id

    generated_ids = []
    # Greedy decoding (simulates `generate`). No cached state is passed between
    # steps, so each session.run recomputes over the full decoder prefix.
    for _ in range(max_length):
        candidate_inputs = {
            "input_ids": input_ids_np,
            "attention_mask": attention_mask_np,
            "decoder_input_ids": decoder_input_ids_np,
        }
        # Feed only the inputs the graph actually declares
        onnx_inputs = {
            name: value for name, value in candidate_inputs.items() if name in onnx_input_names
        }

        logits = session.run(onnx_output_names, onnx_inputs)[0]
        # Batch 0, last decoder position, all vocab logits -> greedy pick
        next_token_id = int(np.argmax(logits[0, -1, :]))
        generated_ids.append(next_token_id)

        if next_token_id == eos_token_id:
            break

        # Append the new token to the decoder input sequence
        decoder_input_ids_np = np.concatenate(
            (decoder_input_ids_np, np.array([[next_token_id]], dtype=np.int64)), axis=1
        )

    # Decode the generated phoneme ids (special tokens such as EOS are dropped)
    onnx_phones = tokenizer.batch_decode([generated_ids], skip_special_tokens=True)
    print(f"ONNX Runtime Inference: {onnx_phones}")
    return onnx_phones
```