| |
| |
|
|
| use anyhow::Result; |
| use ort::{ |
| execution_providers::ExecutionProvider, |
| session::{Session, builder::GraphOptimizationLevel}, |
| inputs, value::TensorRef, |
| }; |
| use ndarray::{Array1, Array2, Array4}; |
|
|
| |
| pub struct GraniteDoclingONNX { |
| session: Session, |
| } |
|
|
| impl GraniteDoclingONNX { |
| |
| pub fn new(model_path: &str) -> Result<Self> { |
| println!("Loading granite-docling ONNX model from: {}", model_path); |
|
|
| let session = Session::builder()? |
| .with_optimization_level(GraphOptimizationLevel::Level3)? |
| .with_execution_providers([ |
| ExecutionProvider::DirectML, |
| ExecutionProvider::CUDA, |
| ExecutionProvider::CPU, |
| ])? |
| .commit_from_file(model_path)?; |
|
|
| |
| println!("Model loaded successfully:"); |
| for (i, input) in session.inputs()?.iter().enumerate() { |
| println!(" Input {}: {} {:?}", i, input.name(), input.input_type()); |
| } |
| for (i, output) in session.outputs()?.iter().enumerate() { |
| println!(" Output {}: {} {:?}", i, output.name(), output.output_type()); |
| } |
|
|
| Ok(Self { session }) |
| } |
|
|
| |
| pub async fn process_document( |
| &self, |
| document_image: Array4<f32>, |
| prompt: &str, |
| ) -> Result<String> { |
|
|
| println!("Processing document with granite-docling..."); |
|
|
| |
| let input_ids = self.tokenize_prompt(prompt)?; |
| let attention_mask = Array2::ones((1, input_ids.len())); |
|
|
| |
| let input_ids_2d = Array2::from_shape_vec( |
| (1, input_ids.len()), |
| input_ids.iter().map(|&x| x as i64).collect(), |
| )?; |
|
|
| |
| let outputs = self.session.run(inputs![ |
| "pixel_values" => TensorRef::from_array_view(&document_image.view())?, |
| "input_ids" => TensorRef::from_array_view(&input_ids_2d.view())?, |
| "attention_mask" => TensorRef::from_array_view(&attention_mask.view())?, |
| ])?; |
|
|
| |
| let logits = outputs["logits"].try_extract_tensor::<f32>()?; |
| let tokens = self.decode_logits_to_tokens(&logits)?; |
| let doctags = self.detokenize_to_doctags(&tokens)?; |
|
|
| println!("✅ Document processing complete"); |
| Ok(doctags) |
| } |
|
|
| |
| fn tokenize_prompt(&self, prompt: &str) -> Result<Vec<u32>> { |
| |
| |
| let tokens: Vec<u32> = prompt |
| .split_whitespace() |
| .enumerate() |
| .map(|(i, _)| (i + 1) as u32) |
| .collect(); |
|
|
| Ok(tokens) |
| } |
|
|
| |
| fn decode_logits_to_tokens(&self, logits: &ndarray::ArrayViewD<f32>) -> Result<Vec<u32>> { |
| |
| let tokens: Vec<u32> = logits |
| .axis_iter(ndarray::Axis(2)) |
| .map(|logit_slice| { |
| logit_slice |
| .iter() |
| .enumerate() |
| .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) |
| .map(|(idx, _)| idx as u32) |
| .unwrap_or(0) |
| }) |
| .collect(); |
|
|
| Ok(tokens) |
| } |
|
|
| |
| fn detokenize_to_doctags(&self, tokens: &[u32]) -> Result<String> { |
| |
| |
|
|
| |
| let mock_doctags = format!( |
| "<doctag>\n <text>Document processed with {} tokens</text>\n</doctag>", |
| tokens.len() |
| ); |
|
|
| Ok(mock_doctags) |
| } |
| } |
|
|
| |
| pub fn preprocess_document_image(image_path: &str) -> Result<Array4<f32>> { |
| |
| |
| |
|
|
| |
| let document_image = Array4::zeros((1, 3, 512, 512)); |
|
|
| Ok(document_image) |
| } |
|
|
| #[tokio::main] |
| async fn main() -> Result<()> { |
| println!("granite-docling ONNX Rust Example"); |
|
|
| |
| let model_path = "granite-docling-258M-onnx/model.onnx"; |
| let granite_docling = GraniteDoclingONNX::new(model_path)?; |
|
|
| |
| let document_image = preprocess_document_image("example_document.png")?; |
|
|
| |
| let prompt = "Convert this document to DocTags:"; |
| let doctags = granite_docling.process_document(document_image, prompt).await?; |
|
|
| println!("Generated DocTags:"); |
| println!("{}", doctags); |
|
|
| Ok(()) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |