helizac committed on
Commit
642ff82
·
verified ·
1 Parent(s): afbffaa

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +25 -32
README.md CHANGED
@@ -47,53 +47,46 @@ You can then use the 4-bit model with the following Python script. Note the incl
47
  import torch
48
  from transformers import AutoModelForCausalLM, AutoProcessor
49
  from PIL import Image
50
- import os
51
- import traceback
52
 
53
- # This assumes the utility script is available in your environment
54
  from qwen_vl_utils import process_vision_info
55
 
56
  MODEL_ID = "helizac/dots.ocr-4bit"
57
 
58
- print("Loading 4-bit quantized model from the Hub...")
59
- model = AutoModelForCausalLM.from_pretrained(
60
- MODEL_ID,
61
- device_map="auto",
62
- trust_remote_code=True,
63
- torch_dtype=torch.bfloat16,
64
- )
65
- processor = AutoProcessor.from_pretrained(
66
- MODEL_ID,
67
- trust_remote_code=True
68
- )
69
- print("✅ Model and processor loaded successfully!")
70
-
71
- # --- Inference ---
72
- image_path = "demo/demo_image1.jpg" # Make sure you have this image
73
  image = Image.open(image_path)
74
- prompt_text = "Parse all layout info, both detection and recognition"
75
 
76
- messages = [
77
- {"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": prompt_text}]}
78
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- # Prepare inputs using the official workflow
81
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
82
  image_inputs, _ = process_vision_info(messages)
83
- inputs = processor(
84
- text=[text], images=image_inputs, padding=True, return_tensors="pt"
85
- ).to(model.device)
86
 
87
- # Generate with parameters to prevent looping with the 4-bit model
88
- generated_ids = model.generate(
89
- **inputs, max_new_tokens=4096, do_sample=True, temperature=0.6, top_p=0.9, repetition_penalty=1.15
90
- )
91
 
92
- # Trim and decode output
93
  generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
94
  output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
95
 
96
- print("\n--- Inference Result ---")
97
  print(output_text)
98
  ```
99
 
 
47
  import torch
48
  from transformers import AutoModelForCausalLM, AutoProcessor
49
  from PIL import Image
50
+ from huggingface_hub import snapshot_download
 
51
 
 
52
  from qwen_vl_utils import process_vision_info
53
 
54
  MODEL_ID = "helizac/dots.ocr-4bit"
55
 
56
+ local_model_path = snapshot_download(repo_id=MODEL_ID)
57
+
58
+ model = AutoModelForCausalLM.from_pretrained(local_model_path, device_map="auto", trust_remote_code=True, torch_dtype=torch.bfloat16)
59
+ processor = AutoProcessor.from_pretrained(local_model_path, trust_remote_code=True)
60
+
61
+ image_path = "test.jpg"
 
 
 
 
 
 
 
 
 
62
  image = Image.open(image_path)
 
63
 
64
+ prompt_text = """\
65
+ Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
66
+ 1. Bbox format: [x1, y1, x2, y2]
67
+ 2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
68
+ 3. Text Extraction & Formatting Rules:
69
+ - Picture: For the 'Picture' category, the text field should be omitted.
70
+ - Formula: Format its text as LaTeX.
71
+ - Table: Format its text as HTML.
72
+ - All Others (Text, Title, etc.): Format their text as Markdown.
73
+ 4. Constraints:
74
+ - The output text must be the original text from the image, with no translation.
75
+ - All layout elements must be sorted according to human reading order.
76
+ 5. Final Output: The entire output must be a single JSON object.\
77
+ """
78
+
79
+ messages = [{"role": "user", "content": [{"type": "image", "image": image_path}, {"type": "text", "text": prompt_text}]}]
80
 
 
81
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
82
  image_inputs, _ = process_vision_info(messages)
83
+ inputs = processor(text=[text], images=image_inputs, padding=True, return_tensors="pt").to(model.device)
 
 
84
 
85
+ generated_ids = model.generate(**inputs, max_new_tokens=1048, do_sample=True, temperature=0.6, top_p=0.9, repetition_penalty=1.15)
 
 
 
86
 
 
87
  generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
88
  output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
89
 
 
90
  print(output_text)
91
  ```
92