r"""
Example usage of LogitsProcessors for document parsing.

This example shows how to use:
- TableInsertionLogitsProcessor: Force \begin{tabular} at the start of every object
- RepetitionStopProcessor: Detect hallucination/repetition and force coordinate tokens
"""

import torch
from PIL import Image, ImageDraw
from transformers import AutoModel, AutoProcessor, AutoTokenizer, GenerationConfig

from postprocessing import extract_classes_bboxes, transform_bbox_to_original, postprocess_text
from hf_logits_processor import TableInsertionLogitsProcessor, RepetitionStopProcessor

# Load model and processor
model_path = "nvidia/NVIDIA-Nemotron-Parse-v1.2"  # or use a local path
device = "cuda:0"

model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

# Load image
image = Image.open('example.png').convert("RGB")

# Task prompt: the special tokens select what the model emits (bounding boxes,
# classes, markdown text). extract_classes_bboxes() below presumably relies on
# the bbox/class tokens being present in the output, so an empty prompt would
# leave it with nothing to parse.
task_prompt = "</s><s><predict_bbox><predict_classes><output_markdown>"

# Process image
inputs = processor(
    images=[image],
    text=task_prompt,
    return_tensors="pt",
    add_special_tokens=False,
).to(device)

generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)

# Create the table processor - inserts \begin{tabular} at the start of every new object
table_processor = TableInsertionLogitsProcessor(
    tokenizer=tokenizer,
    table_prefix="\\begin{tabular}",
)

# Create the repetition stop processor - detects hallucination and forces tokens
repetition_processor = RepetitionStopProcessor(
    tokenizer=tokenizer,
    max_repetitions=10,        # Force stop after any pattern repeats 10+ times
    ngram_sizes=[3, 4, 5, 6],  # Check these n-gram sizes for repetition
    window_size=500,           # Only check the last 500 tokens
)

# Generate with both logits processors
outputs = model.generate(
    **inputs,
    generation_config=generation_config,
    logits_processor=[table_processor, repetition_processor],
)

# Reset processor states for next generation (important for batch processing)
table_processor.reset()
repetition_processor.reset()

# Decode and process the generated text
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
print(outputs)
print('--------------------------------')
print("Generated text:", generated_text)
print('--------------------------------')

classes, bboxes, texts = extract_classes_bboxes(generated_text)
bboxes = [transform_bbox_to_original(bbox, image.width, image.height) for bbox in bboxes]

# Specify output formats for postprocessing
table_format = 'HTML'          # latex | HTML | markdown
text_format = 'markdown'       # markdown | plain
blank_text_in_figures = False  # remove text inside 'Picture' class

texts = [
    postprocess_text(
        text,
        cls=cls,
        table_format=table_format,
        text_format=text_format,
        blank_text_in_figures=blank_text_in_figures,
    )
    for text, cls in zip(texts, classes)
]

for cl, txt in zip(classes, texts):
    print(cl, ': ', txt)

# OPTIONAL - Draw bounding boxes.
# max() keeps the second corner from ending up left of / above the first one:
# PIL's rectangle() rejects boxes with x1 < x0 or y1 < y0.
draw = ImageDraw.Draw(image)
for bbox in bboxes:
    draw.rectangle(
        (bbox[0], bbox[1], max(bbox[0], bbox[2]), max(bbox[1], bbox[3])),
        outline="red",
        width=2,
    )

# Save or display the image
image.save("output_with_boxes.jpg")
# image.show()