3v324v23 commited on
Commit
ce5b66d
·
1 Parent(s): 7b9c753

Add Table of Contents (TOC) RAG system with new modules, database enhancements, and UI integration.

Browse files
main.py CHANGED
@@ -16,6 +16,11 @@ from src.vector_rag.retriever import retrieve as vector_retrieve
16
  from src.vector_rag.llm_engine import chat_with_rag as vector_chat_with_rag, chat_no_context as vector_chat_no_context
17
  from src.vector_rag.indexer import index_all_documents as vector_index_all
18
 
 
 
 
 
 
19
  from src.cost_tracker import add_chat_tokens, get_token_summary, clear_stats
20
 
21
  app = FastAPI(title="Local RAG Comparator")
@@ -118,10 +123,31 @@ def get_document_chunks(filename: str):
118
  except Exception as e:
119
  print(f"Error fetching vector chunks: {e}")
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  return {
122
  "filename": filename,
123
  "page_chunks": page_chunks,
124
- "vector_chunks": vector_chunks
 
125
  }
126
 
127
  @app.get("/api/index/status")
@@ -193,6 +219,36 @@ def chat_vector(req: ChatRequest):
193
 
194
  return StreamingResponse(generate(), media_type="text/event-stream")
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  @app.post("/api/upload")
197
  async def upload_files(background_tasks: BackgroundTasks, files: list[UploadFile] = File(...)):
198
  global is_indexing
@@ -217,6 +273,8 @@ async def upload_files(background_tasks: BackgroundTasks, files: list[UploadFile
217
  page_index_all(progress_callback=progress_cb)
218
  progress_cb("Auto-indexing: Starting Vector RAG...")
219
  vector_index_all(progress_callback=progress_cb)
 
 
220
  progress_cb("Successfully indexed all documents.")
221
  except Exception as e:
222
  progress_cb(f"Error during indexing: {str(e)}")
@@ -247,6 +305,8 @@ def index_documents(background_tasks: BackgroundTasks):
247
  page_index_all(progress_callback=progress_cb)
248
  progress_cb("Starting Vector Indexing...")
249
  vector_index_all(progress_callback=progress_cb)
 
 
250
  progress_cb("Indexing completed.")
251
  except Exception as e:
252
  progress_cb(f"Error: {str(e)}")
@@ -259,16 +319,9 @@ def index_documents(background_tasks: BackgroundTasks):
259
 
260
  @app.post("/api/index/clear")
261
  def clear_index():
262
- # remove both collections and underlying storage directory
263
- from src.shared_db import get_chroma_client, clear_chroma_storage
264
-
265
- # Clear ChromaDB collections via client API
266
- client = get_chroma_client()
267
- collections = client.list_collections()
268
- for coll in collections:
269
- client.delete_collection(name=coll.name)
270
 
271
- # Purge ondisk database files so we start fresh next time
272
  clear_chroma_storage()
273
 
274
  # Clear any app-specific stats
 
16
  from src.vector_rag.llm_engine import chat_with_rag as vector_chat_with_rag, chat_no_context as vector_chat_no_context
17
  from src.vector_rag.indexer import index_all_documents as vector_index_all
18
 
19
+ # TOC RAG imports
20
+ from src.toc_rag.retriever import retrieve as toc_retrieve
21
+ from src.toc_rag.llm_engine import chat_with_rag as toc_chat_with_rag, chat_no_context as toc_chat_no_context
22
+ from src.toc_rag.indexer import index_all_documents as toc_index_all
23
+
24
  from src.cost_tracker import add_chat_tokens, get_token_summary, clear_stats
25
 
26
  app = FastAPI(title="Local RAG Comparator")
 
123
  except Exception as e:
124
  print(f"Error fetching vector chunks: {e}")
125
 
126
+ # ── TOC Index chunks ──
127
+ toc_chunks = []
128
+ try:
129
+ toc_collection = get_collection("toc_index")
130
+ toc_results = toc_collection.get(where={"source": filename})
131
+
132
+ if toc_results and toc_results["ids"]:
133
+ for i in range(len(toc_results["ids"])):
134
+ toc_chunks.append({
135
+ "id": toc_results["ids"][i],
136
+ "text": toc_results["documents"][i] if toc_results["documents"] else "",
137
+ "node_id": toc_results["metadatas"][i].get("node_id", "") if toc_results["metadatas"] else "",
138
+ "title": toc_results["metadatas"][i].get("title", "") if toc_results["metadatas"] else "",
139
+ "tokens": toc_results["metadatas"][i].get("tokens", 0) if toc_results["metadatas"] else 0
140
+ })
141
+ toc_chunks.sort(key=lambda x: x["node_id"])
142
+
143
+ except Exception as e:
144
+ print(f"Error fetching toc chunks: {e}")
145
+
146
  return {
147
  "filename": filename,
148
  "page_chunks": page_chunks,
149
+ "vector_chunks": vector_chunks,
150
+ "toc_chunks": toc_chunks
151
  }
152
 
153
  @app.get("/api/index/status")
 
219
 
220
  return StreamingResponse(generate(), media_type="text/event-stream")
221
 
222
+ @app.post("/api/chat/toc")
223
+ def chat_toc(req: ChatRequest):
224
+ def generate():
225
+ try:
226
+ nodes = toc_retrieve(req.query, top_k=req.top_k)
227
+ source_data = [
228
+ {
229
+ "source": n.source,
230
+ "node_id": n.node_id,
231
+ "title": n.title,
232
+ "score": n.score,
233
+ "text": n.text,
234
+ "tokens": n.tokens,
235
+ }
236
+ for n in nodes
237
+ ]
238
+ yield f"data: {json.dumps({'type': 'sources', 'sources': source_data})}\n\n"
239
+
240
+ generator = toc_chat_with_rag(req.query, nodes) if nodes else toc_chat_no_context(req.query)
241
+
242
+ for chunk in generator:
243
+ if chunk.get("type") == "stats":
244
+ add_chat_tokens(chunk.get("prompt_eval_count", 0), chunk.get("eval_count", 0))
245
+ yield f"data: {json.dumps(chunk)}\n\n"
246
+
247
+ except Exception as e:
248
+ yield f"data: {json.dumps({'type': 'error', 'content': str(e)})}\n\n"
249
+
250
+ return StreamingResponse(generate(), media_type="text/event-stream")
251
+
252
  @app.post("/api/upload")
253
  async def upload_files(background_tasks: BackgroundTasks, files: list[UploadFile] = File(...)):
254
  global is_indexing
 
273
  page_index_all(progress_callback=progress_cb)
274
  progress_cb("Auto-indexing: Starting Vector RAG...")
275
  vector_index_all(progress_callback=progress_cb)
276
+ progress_cb("Auto-indexing: Starting TOC RAG...")
277
+ toc_index_all(progress_callback=progress_cb)
278
  progress_cb("Successfully indexed all documents.")
279
  except Exception as e:
280
  progress_cb(f"Error during indexing: {str(e)}")
 
305
  page_index_all(progress_callback=progress_cb)
306
  progress_cb("Starting Vector Indexing...")
307
  vector_index_all(progress_callback=progress_cb)
308
+ progress_cb("Starting TOC Indexing...")
309
+ toc_index_all(progress_callback=progress_cb)
310
  progress_cb("Indexing completed.")
311
  except Exception as e:
312
  progress_cb(f"Error: {str(e)}")
 
319
 
320
  @app.post("/api/index/clear")
321
  def clear_index():
322
+ from src.shared_db import clear_chroma_storage
 
 
 
 
 
 
 
323
 
324
+ # Purge on-disk database files and re-initialize a fresh client
325
  clear_chroma_storage()
326
 
327
  # Clear any app-specific stats
src/cost_tracker.py CHANGED
@@ -17,6 +17,7 @@ def _default_stats():
17
  "embedding_tokens": 0,
18
  "prompt_tokens": 0,
19
  "completion_tokens": 0,
 
20
  }
21
 
22
 
@@ -55,6 +56,15 @@ def add_chat_tokens(prompt_count: int, completion_count: int):
55
  save_stats(stats)
56
 
57
 
 
 
 
 
 
 
 
 
 
58
  def get_token_summary():
59
  """Return raw token counts only – no monetary costs."""
60
  with _lock:
@@ -63,6 +73,7 @@ def get_token_summary():
63
  stats.get("embedding_tokens", 0)
64
  + stats.get("prompt_tokens", 0)
65
  + stats.get("completion_tokens", 0)
 
66
  )
67
  return {
68
  "tokens": stats,
 
17
  "embedding_tokens": 0,
18
  "prompt_tokens": 0,
19
  "completion_tokens": 0,
20
+ "toc_analysis_tokens": 0,
21
  }
22
 
23
 
 
56
  save_stats(stats)
57
 
58
 
59
+ def add_toc_analysis_tokens(count: int):
60
+ if count <= 0:
61
+ return
62
+ with _lock:
63
+ stats = load_stats()
64
+ stats["toc_analysis_tokens"] = stats.get("toc_analysis_tokens", 0) + count
65
+ save_stats(stats)
66
+
67
+
68
  def get_token_summary():
69
  """Return raw token counts only – no monetary costs."""
70
  with _lock:
 
73
  stats.get("embedding_tokens", 0)
74
  + stats.get("prompt_tokens", 0)
75
  + stats.get("completion_tokens", 0)
76
+ + stats.get("toc_analysis_tokens", 0)
77
  )
78
  return {
79
  "tokens": stats,
src/shared_db.py CHANGED
@@ -42,12 +42,21 @@ def get_collection(name: str):
42
  def clear_chroma_storage():
43
  """Remove all persistent chroma data on disk and reset the client."""
44
  global _client
45
- # If client exists, try closing it cleanly (PersistentClient doesn't provide close) by deleting reference
46
- _client = None
47
- try:
48
- shutil.rmtree(CHROMA_PATH)
49
- Path(CHROMA_PATH).mkdir(exist_ok=True, parents=True)
50
- print(f"Cleared chroma storage directory: {CHROMA_PATH}")
51
- except Exception as e:
52
- print(f"Error clearing chroma storage: {e}")
53
- raise
 
 
 
 
 
 
 
 
 
 
42
  def clear_chroma_storage():
43
  """Remove all persistent chroma data on disk and reset the client."""
44
  global _client
45
+ with _lock:
46
+ # Drop the old client reference so SQLite file handles are released
47
+ _client = None
48
+
49
+ try:
50
+ shutil.rmtree(CHROMA_PATH)
51
+ Path(CHROMA_PATH).mkdir(exist_ok=True, parents=True)
52
+ print(f"Cleared chroma storage directory: {CHROMA_PATH}")
53
+ except Exception as e:
54
+ print(f"Error clearing chroma storage: {e}")
55
+ raise
56
+
57
+ # Eagerly create a fresh client so subsequent calls don't hit a stale db
58
+ try:
59
+ _client = chromadb.PersistentClient(path=CHROMA_PATH)
60
+ print("Fresh ChromaDB Client re-initialized after clear.")
61
+ except Exception as e:
62
+ print(f"Warning: could not re-init ChromaDB Client: {e}")
src/toc_rag/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # TOC-based Index RAG module
src/toc_rag/indexer.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ indexer.py
3
+ ----------
4
+ Main indexer for the TOC-based Index RAG.
5
+ Parses documents into hierarchical tree structures,
6
+ embeds each node's text, and stores in ChromaDB.
7
+
8
+ Supports PDF (via tree_builder), MD/TXT (via md_parser).
9
+ """
10
+
11
+ import os
12
+ import asyncio
13
+ from pathlib import Path
14
+ from typing import Generator
15
+
16
+ import fitz # PyMuPDF
17
+ import ollama
18
+ from docx import Document
19
+ from rich.console import Console
20
+ from rich.progress import track
21
+ from src.shared_db import get_collection as get_shared_collection
22
+ from src.cost_tracker import add_embedding_tokens
23
+
24
+ import tiktoken
25
+ _tokenizer = tiktoken.get_encoding("cl100k_base")
26
+
27
+ from .toc_utils import (
28
+ get_page_tokens, JsonLogger, ConfigLoader,
29
+ write_node_id, add_node_text, remove_structure_text,
30
+ structure_to_list, format_structure,
31
+ generate_summaries_for_structure,
32
+ create_clean_structure_for_description, generate_doc_description,
33
+ get_pdf_name, count_tokens, LLM_MODEL,
34
+ )
35
+ from .tree_builder import tree_parser
36
+ from .md_parser import md_to_tree
37
+
38
+ console = Console()
39
+
40
+ # ─── Config ────────────────────────────────────────────────────────────────────
41
+ EMBED_MODEL = "nomic-embed-text"
42
+ COLLECTION = "toc_index"
43
+ DOCS_FOLDER = "./documents"
44
+ # ───────────────────────────────────────────────────────────────────────────────
45
+
46
+
47
+
48
+ def get_chroma_collection():
49
+ return get_shared_collection(COLLECTION)
50
+
51
+
52
+ CHUNK_CHAR_LIMIT = 6000 # chars per sub-chunk (~1500 tokens, well within 8192 model limit)
53
+ CHUNK_OVERLAP_CHARS = 1000 # overlap between consecutive sub-chunks
54
+
55
+
56
+ def _chunk_text_for_embedding(text: str, header: str = "") -> list[str]:
57
+ """Split long text into overlapping sub-chunks for embedding.
58
+
59
+ Each chunk is prefixed with *header* (e.g. section title + summary) so
60
+ the embedding captures the node context even for chunks deep in the text.
61
+
62
+ Short texts (≤ CHUNK_CHAR_LIMIT) are returned as a single-item list.
63
+ """
64
+ # Header is always prepended; account for its length
65
+ available = CHUNK_CHAR_LIMIT - len(header)
66
+ if available < 500:
67
+ available = 500 # safety floor
68
+
69
+ # If text fits in one chunk, return as-is
70
+ if len(text) <= available:
71
+ return [f"{header}{text}".strip()] if header else [text]
72
+
73
+ chunks: list[str] = []
74
+ start = 0
75
+ while start < len(text):
76
+ end = start + available
77
+ chunk_body = text[start:end]
78
+ full_chunk = f"{header}{chunk_body}".strip() if header else chunk_body
79
+ chunks.append(full_chunk)
80
+ # Advance with overlap
81
+ start += available - CHUNK_OVERLAP_CHARS
82
+ if start >= len(text):
83
+ break
84
+
85
+ return chunks
86
+
87
+
88
+ def embed_text(text: str) -> tuple[list[float], int]:
89
+ """Generate embedding locally using Ollama.
90
+ Text MUST already be chunked to fit the model context window."""
91
+ try:
92
+ response = ollama.embeddings(model=EMBED_MODEL, prompt=text)
93
+ except Exception as e:
94
+ print(f"[TOC-RAG] Embed error (len={len(text)} chars): {e}")
95
+ raise
96
+ exact_tokens = len(_tokenizer.encode(text))
97
+ add_embedding_tokens(exact_tokens)
98
+ return response["embedding"], exact_tokens
99
+
100
+
101
+ def flatten_tree(structure, parent_title=""):
102
+ """Flatten tree structure into a list of indexable nodes."""
103
+ nodes = []
104
+ for node in structure:
105
+ title = node.get('title', '')
106
+ full_title = f"{parent_title} > {title}" if parent_title else title
107
+ text = node.get('text', '')
108
+ summary = node.get('summary', node.get('prefix_summary', ''))
109
+
110
+ # Build the indexable text: title + summary + text content
111
+ indexable_text = f"Section: {full_title}\n"
112
+ if summary:
113
+ indexable_text += f"Summary: {summary}\n"
114
+ if text:
115
+ indexable_text += f"Content:\n{text}"
116
+
117
+ nodes.append({
118
+ 'node_id': node.get('node_id', ''),
119
+ 'title': title,
120
+ 'full_title': full_title,
121
+ 'text': indexable_text.strip(),
122
+ 'summary': summary,
123
+ 'start_index': node.get('start_index'),
124
+ 'end_index': node.get('end_index'),
125
+ 'line_num': node.get('line_num'),
126
+ })
127
+
128
+ if 'nodes' in node and node['nodes']:
129
+ nodes.extend(flatten_tree(node['nodes'], full_title))
130
+
131
+ return nodes
132
+
133
+
134
+ def _build_tree_for_pdf(filepath: str, progress_callback=None):
135
+ """Build tree structure for a PDF using the full TOC pipeline."""
136
+ logger = JsonLogger(filepath)
137
+ opt = ConfigLoader().load({
138
+ 'model': LLM_MODEL,
139
+ 'if_add_node_id': 'yes',
140
+ 'if_add_node_summary': 'yes',
141
+ 'if_add_doc_description': 'no',
142
+ 'if_add_node_text': 'yes',
143
+ })
144
+
145
+ if progress_callback:
146
+ progress_callback("TOC RAG: Parsing PDF pages...")
147
+
148
+ page_list = get_page_tokens(filepath)
149
+ logger.info({'total_pages': len(page_list), 'total_tokens': sum(p[1] for p in page_list)})
150
+
151
+ async def build():
152
+ if progress_callback:
153
+ progress_callback("TOC RAG: Building document tree structure...")
154
+ structure = await tree_parser(page_list, opt, doc=filepath, logger=logger)
155
+
156
+ if progress_callback:
157
+ progress_callback("TOC RAG: Assigning node IDs...")
158
+ write_node_id(structure)
159
+
160
+ if progress_callback:
161
+ progress_callback("TOC RAG: Extracting node text...")
162
+ add_node_text(structure, page_list)
163
+
164
+ if progress_callback:
165
+ progress_callback("TOC RAG: Generating summaries...")
166
+ await generate_summaries_for_structure(structure, model=opt.model)
167
+
168
+ return structure
169
+
170
+ # Run async pipeline
171
+ try:
172
+ loop = asyncio.get_event_loop()
173
+ if loop.is_running():
174
+ # If we're already in an async context, create a new loop in a thread
175
+ import concurrent.futures
176
+ with concurrent.futures.ThreadPoolExecutor() as pool:
177
+ future = pool.submit(asyncio.run, build())
178
+ structure = future.result()
179
+ else:
180
+ structure = asyncio.run(build())
181
+ except RuntimeError:
182
+ structure = asyncio.run(build())
183
+
184
+ return structure
185
+
186
+
187
+ def _build_tree_for_text(filepath: str, progress_callback=None):
188
+ """Build tree structure for text/markdown files."""
189
+ if progress_callback:
190
+ progress_callback(f"TOC RAG: Parsing {Path(filepath).suffix} file...")
191
+
192
+ async def build():
193
+ return await md_to_tree(
194
+ filepath,
195
+ if_thinning=True,
196
+ min_token_threshold=5000,
197
+ if_add_node_summary='yes',
198
+ summary_token_threshold=200,
199
+ model=LLM_MODEL,
200
+ if_add_node_text='yes',
201
+ if_add_node_id='yes',
202
+ )
203
+
204
+ try:
205
+ loop = asyncio.get_event_loop()
206
+ if loop.is_running():
207
+ import concurrent.futures
208
+ with concurrent.futures.ThreadPoolExecutor() as pool:
209
+ future = pool.submit(asyncio.run, build())
210
+ result = future.result()
211
+ else:
212
+ result = asyncio.run(build())
213
+ except RuntimeError:
214
+ result = asyncio.run(build())
215
+
216
+ return result.get('structure', [])
217
+
218
+
219
+ def index_document(filepath: str, collection, progress_callback=None) -> int:
220
+ """Index all nodes of a single document. Returns node count indexed."""
221
+ ext = Path(filepath).suffix.lower()
222
+
223
+ if ext == '.pdf':
224
+ structure = _build_tree_for_pdf(filepath, progress_callback)
225
+ elif ext in ('.md', '.txt'):
226
+ structure = _build_tree_for_text(filepath, progress_callback)
227
+ elif ext == '.docx':
228
+ # Convert DOCX to text, then treat as text
229
+ doc = Document(filepath)
230
+ temp_txt = Path(filepath).with_suffix('.tmp.txt')
231
+ with open(temp_txt, 'w', encoding='utf-8') as f:
232
+ f.write('\n'.join(p.text for p in doc.paragraphs))
233
+ structure = _build_tree_for_text(str(temp_txt), progress_callback)
234
+ temp_txt.unlink(missing_ok=True)
235
+ else:
236
+ console.print(f"[yellow]Skipping unsupported file for TOC indexing: {filepath}[/yellow]")
237
+ return 0
238
+
239
+ if not structure:
240
+ console.print(f"[red]No structure extracted from {filepath}[/red]")
241
+ return 0
242
+
243
+ # Flatten tree to indexable nodes
244
+ flat_nodes = flatten_tree(structure)
245
+ if not flat_nodes:
246
+ console.print(f"[red]No nodes to index from {filepath}[/red]")
247
+ return 0
248
+
249
+ filename = Path(filepath).name
250
+ ids, embeddings, documents, metadatas = [], [], [], []
251
+ total = len(flat_nodes)
252
+ chunks_total = 0
253
+
254
+ for i, node in enumerate(track(flat_nodes, description=f" Embedding TOC nodes for {filename}"), 1):
255
+ if progress_callback:
256
+ progress_callback(f"TOC RAG: Embedding {filename} node {i}/{total}")
257
+
258
+ text = node['text']
259
+ if not text.strip():
260
+ continue
261
+
262
+ # Build a short header that gets prepended to every sub-chunk
263
+ header = f"Section: {node['full_title']}\n"
264
+ if node.get('summary'):
265
+ header += f"Summary: {node['summary']}\n"
266
+ header += "Content:\n"
267
+
268
+ # Split into sub-chunks (single chunk if short enough)
269
+ raw_content = node.get('text', '')
270
+ # Remove the header we already built from raw_content if flatten_tree already added it
271
+ # (flatten_tree builds the full indexable text with header)
272
+ content_only = raw_content
273
+ for prefix in [f"Section: {node['full_title']}\n", f"Summary: {node.get('summary','')}\n", "Content:\n"]:
274
+ if content_only.startswith(prefix):
275
+ content_only = content_only[len(prefix):]
276
+
277
+ sub_chunks = _chunk_text_for_embedding(content_only, header=header)
278
+
279
+ for chunk_idx, chunk_text in enumerate(sub_chunks):
280
+ if len(sub_chunks) == 1:
281
+ doc_id = f"{Path(filepath).stem}__toc_node_{node['node_id']}"
282
+ else:
283
+ doc_id = f"{Path(filepath).stem}__toc_node_{node['node_id']}_chunk_{chunk_idx}"
284
+
285
+ # Skip if already indexed
286
+ existing = collection.get(ids=[doc_id])
287
+ if existing["ids"]:
288
+ continue
289
+
290
+ try:
291
+ emb, tokens = embed_text(chunk_text)
292
+ except Exception as e:
293
+ console.print(f" [red]Skipping node {node['node_id']} chunk {chunk_idx}: {e}[/red]")
294
+ continue
295
+
296
+ ids.append(doc_id)
297
+ embeddings.append(emb)
298
+ documents.append(chunk_text)
299
+ metadatas.append({
300
+ "source": filename,
301
+ "node_id": node['node_id'],
302
+ "title": node.get('title', ''),
303
+ "full_title": node.get('full_title', ''),
304
+ "doc_type": Path(filepath).suffix.lower().replace('.', ''),
305
+ "tokens": tokens,
306
+ "summary": node.get('summary', ''),
307
+ "chunk_index": chunk_idx,
308
+ "total_chunks": len(sub_chunks),
309
+ })
310
+ chunks_total += 1
311
+
312
+ if ids:
313
+ collection.add(
314
+ ids=ids,
315
+ embeddings=embeddings,
316
+ documents=documents,
317
+ metadatas=metadatas,
318
+ )
319
+
320
+ console.print(f" [dim]{chunks_total} embedding(s) from {total} node(s)[/dim]")
321
+ return chunks_total
322
+
323
+
324
+ def index_all_documents(docs_folder: str = DOCS_FOLDER, progress_callback=None) -> None:
325
+ """Scan the documents folder and index everything with TOC structure."""
326
+ collection = get_chroma_collection()
327
+ folder = Path(docs_folder)
328
+ supported = [".pdf", ".txt", ".docx", ".md"]
329
+ files = [f for f in folder.iterdir() if f.suffix.lower() in supported]
330
+
331
+ if not files:
332
+ if progress_callback:
333
+ progress_callback("No documents to index for TOC RAG.")
334
+ console.print(f"[yellow]No supported files found in {docs_folder}[/yellow]")
335
+ return
336
+
337
+ console.print(f"\n[bold cyan]🌳 TOC RAG: Found {len(files)} document(s) to index[/bold cyan]\n")
338
+
339
+ total_nodes = 0
340
+ for filepath in files:
341
+ console.print(f"[green]→ TOC Processing:[/green] {filepath.name}")
342
+ try:
343
+ count = index_document(str(filepath), collection, progress_callback)
344
+ total_nodes += count
345
+ console.print(f" [dim]Indexed {count} node(s)[/dim]\n")
346
+ except Exception as e:
347
+ console.print(f" [red]Error indexing {filepath.name}: {e}[/red]\n")
348
+
349
+ if progress_callback:
350
+ progress_callback("TOC Indexing Complete.")
351
+
352
+ console.print(f"[bold green]✅ TOC RAG Done! Total nodes indexed: {total_nodes}[/bold green]")
353
+ console.print(f"[dim]ChromaDB collection size: {collection.count()} nodes[/dim]")
354
+
355
+
356
+ if __name__ == "__main__":
357
+ console.print("[bold]🌳 TOC Indexer — Local RAG[/bold]")
358
+ console.print(f"Embed model : {EMBED_MODEL}")
359
+ console.print(f"LLM model : {LLM_MODEL}")
360
+ console.print(f"Docs folder : {DOCS_FOLDER}\n")
361
+ index_all_documents()
src/toc_rag/llm_engine.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ llm_engine.py
3
+ -------------
4
+ Wraps Ollama to call your local LLM model.
5
+ Builds a RAG prompt using retrieved TOC nodes as context.
6
+ Streams response with token tracking.
7
+ """
8
+
9
+ from typing import Generator
10
+ import ollama
11
+ from .retriever import RetrievedNode, build_context
12
+
13
+ # ─── Config ────────────────────────────────────────────────────────────────────
14
+ LLM_MODEL = "llama3.2:1b"
15
+ # ───────────────────────────────────────────────────────────────────────────────
16
+
17
+ SYSTEM_PROMPT = """You are a helpful document assistant using tree-structured document indexing.
18
+ Answer the user's question using ONLY the provided context sections.
19
+ Always cite your sources by mentioning the filename and section title.
20
+ If the answer is not found in the context, say so honestly — do not make up information.
21
+ """
22
+
23
+
24
+ def build_rag_prompt(query: str, nodes: list[RetrievedNode]) -> str:
25
+ context = build_context(nodes)
26
+ return f"""Use the following document sections to answer the question.
27
+
28
+ === CONTEXT (Tree-Indexed Sections) ===
29
+ {context}
30
+
31
+ === QUESTION ===
32
+ {query}
33
+
34
+ === ANSWER ==="""
35
+
36
+
37
+ def chat_with_rag(
38
+ query: str,
39
+ nodes: list[RetrievedNode],
40
+ chat_history: list[dict] | None = None,
41
+ ) -> Generator[dict, None, None]:
42
+ """
43
+ Stream a response from the local LLM using retrieved tree nodes as context.
44
+ Yields text chunks as they arrive (streaming).
45
+ """
46
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
47
+
48
+ if chat_history:
49
+ messages.extend(chat_history)
50
+
51
+ rag_prompt = build_rag_prompt(query, nodes)
52
+ messages.append({"role": "user", "content": rag_prompt})
53
+
54
+ stream = ollama.chat(
55
+ model=LLM_MODEL,
56
+ messages=messages,
57
+ stream=True,
58
+ )
59
+
60
+ for chunk in stream:
61
+ token = chunk.get("message", {}).get("content", "")
62
+ if token:
63
+ yield {"type": "token", "content": token}
64
+
65
+ if chunk.get("done"):
66
+ yield {
67
+ "type": "stats",
68
+ "prompt_eval_count": chunk.get("prompt_eval_count", 0),
69
+ "eval_count": chunk.get("eval_count", 0)
70
+ }
71
+
72
+
73
+ def chat_no_context(query: str, chat_history: list[dict] | None = None) -> Generator[dict, None, None]:
74
+ """Fallback: LLM response when no relevant nodes are found."""
75
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
76
+ if chat_history:
77
+ messages.extend(chat_history)
78
+ messages.append({
79
+ "role": "user",
80
+ "content": f"{query}\n\n(No relevant document sections were found for this query.)"
81
+ })
82
+ stream = ollama.chat(model=LLM_MODEL, messages=messages, stream=True)
83
+ for chunk in stream:
84
+ token = chunk.get("message", {}).get("content", "")
85
+ if token:
86
+ yield {"type": "token", "content": token}
87
+
88
+ if chunk.get("done"):
89
+ yield {
90
+ "type": "stats",
91
+ "prompt_eval_count": chunk.get("prompt_eval_count", 0),
92
+ "eval_count": chunk.get("eval_count", 0)
93
+ }
src/toc_rag/md_parser.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ md_parser.py
3
+ ------------
4
+ Markdown-to-tree parsing for .md and .txt files.
5
+ Extracts headers, builds hierarchical tree, supports thinning
6
+ and summary generation.
7
+
8
+ Replica of the reference page-indexing repository.
9
+ """
10
+
11
+ import os
12
+ import re
13
+ import json
14
+ import asyncio
15
+
16
+ from .toc_utils import (
17
+ count_tokens, write_node_id, format_structure,
18
+ structure_to_list, create_clean_structure_for_description,
19
+ generate_doc_description, generate_node_summary,
20
+ ChatGPT_API_async,
21
+ )
22
+
23
+
24
+ # ─── Summary Helpers ──────────────────────────────────────────────────────────
25
+
26
+ async def get_node_summary(node, summary_token_threshold=200, model=None):
27
+ node_text = node.get('text')
28
+ num_tokens = count_tokens(node_text, model=model)
29
+ if num_tokens < summary_token_threshold:
30
+ return node_text
31
+ else:
32
+ return await generate_node_summary(node, model=model)
33
+
34
+
35
+ async def generate_summaries_for_structure_md(structure, summary_token_threshold, model=None):
36
+ nodes = structure_to_list(structure)
37
+ tasks = [get_node_summary(node, summary_token_threshold=summary_token_threshold, model=model) for node in nodes]
38
+ summaries = await asyncio.gather(*tasks)
39
+
40
+ for node, summary in zip(nodes, summaries):
41
+ if not node.get('nodes'):
42
+ node['summary'] = summary
43
+ else:
44
+ node['prefix_summary'] = summary
45
+ return structure
46
+
47
+
48
+ # ─── Markdown Extraction ──────────────────────────────────────────────────────
49
+
50
+ def extract_nodes_from_markdown(markdown_content):
51
+ """Find headers in markdown, skipping code blocks."""
52
+ header_pattern = r'^(#{1,6})\s+(.+)$'
53
+ code_block_pattern = r'^```'
54
+ node_list = []
55
+
56
+ lines = markdown_content.split('\n')
57
+ in_code_block = False
58
+
59
+ for line_num, line in enumerate(lines, 1):
60
+ stripped_line = line.strip()
61
+
62
+ if re.match(code_block_pattern, stripped_line):
63
+ in_code_block = not in_code_block
64
+ continue
65
+
66
+ if not stripped_line:
67
+ continue
68
+
69
+ if not in_code_block:
70
+ match = re.match(header_pattern, stripped_line)
71
+ if match:
72
+ title = match.group(2).strip()
73
+ node_list.append({'node_title': title, 'line_num': line_num})
74
+
75
+ return node_list, lines
76
+
77
+
78
+ def extract_node_text_content(node_list, markdown_lines):
79
+ """Extract text content between headers."""
80
+ all_nodes = []
81
+ for node in node_list:
82
+ line_content = markdown_lines[node['line_num'] - 1]
83
+ header_match = re.match(r'^(#{1,6})', line_content)
84
+
85
+ if header_match is None:
86
+ print(f"Warning: Line {node['line_num']} does not contain a valid header: '{line_content}'")
87
+ continue
88
+
89
+ processed_node = {
90
+ 'title': node['node_title'],
91
+ 'line_num': node['line_num'],
92
+ 'level': len(header_match.group(1))
93
+ }
94
+ all_nodes.append(processed_node)
95
+
96
+ for i, node in enumerate(all_nodes):
97
+ start_line = node['line_num'] - 1
98
+ if i + 1 < len(all_nodes):
99
+ end_line = all_nodes[i + 1]['line_num'] - 1
100
+ else:
101
+ end_line = len(markdown_lines)
102
+
103
+ node['text'] = '\n'.join(markdown_lines[start_line:end_line]).strip()
104
+ return all_nodes
105
+
106
+
107
+ # ─── Token Counting / Thinning ─────────────────────────────────────────────────
108
+
109
+ def update_node_list_with_text_token_count(node_list, model=None):
110
+ """Calculate cumulative token counts for nodes including children."""
111
+ def find_all_children(parent_index, parent_level, node_list):
112
+ children_indices = []
113
+ for i in range(parent_index + 1, len(node_list)):
114
+ current_level = node_list[i]['level']
115
+ if current_level <= parent_level:
116
+ break
117
+ children_indices.append(i)
118
+ return children_indices
119
+
120
+ result_list = node_list.copy()
121
+
122
+ for i in range(len(result_list) - 1, -1, -1):
123
+ current_node = result_list[i]
124
+ current_level = current_node['level']
125
+ children_indices = find_all_children(i, current_level, result_list)
126
+
127
+ node_text = current_node.get('text', '')
128
+ total_text = node_text
129
+
130
+ for child_index in children_indices:
131
+ child_text = result_list[child_index].get('text', '')
132
+ if child_text:
133
+ total_text += '\n' + child_text
134
+
135
+ result_list[i]['text_token_count'] = count_tokens(total_text, model=model)
136
+
137
+ return result_list
138
+
139
+
140
+ def tree_thinning_for_index(node_list, min_node_token=None, model=None):
141
+ """Merge small nodes into their parents."""
142
+ def find_all_children(parent_index, parent_level, node_list):
143
+ children_indices = []
144
+ for i in range(parent_index + 1, len(node_list)):
145
+ current_level = node_list[i]['level']
146
+ if current_level <= parent_level:
147
+ break
148
+ children_indices.append(i)
149
+ return children_indices
150
+
151
+ result_list = node_list.copy()
152
+ nodes_to_remove = set()
153
+
154
+ for i in range(len(result_list) - 1, -1, -1):
155
+ if i in nodes_to_remove:
156
+ continue
157
+
158
+ current_node = result_list[i]
159
+ current_level = current_node['level']
160
+ total_tokens = current_node.get('text_token_count', 0)
161
+
162
+ if total_tokens < min_node_token:
163
+ children_indices = find_all_children(i, current_level, result_list)
164
+
165
+ children_texts = []
166
+ for child_index in sorted(children_indices):
167
+ if child_index not in nodes_to_remove:
168
+ child_text = result_list[child_index].get('text', '')
169
+ if child_text.strip():
170
+ children_texts.append(child_text)
171
+ nodes_to_remove.add(child_index)
172
+
173
+ if children_texts:
174
+ parent_text = current_node.get('text', '')
175
+ merged_text = parent_text
176
+ for child_text in children_texts:
177
+ if merged_text and not merged_text.endswith('\n'):
178
+ merged_text += '\n\n'
179
+ merged_text += child_text
180
+
181
+ result_list[i]['text'] = merged_text
182
+ result_list[i]['text_token_count'] = count_tokens(merged_text, model=model)
183
+
184
+ for index in sorted(nodes_to_remove, reverse=True):
185
+ result_list.pop(index)
186
+
187
+ return result_list
188
+
189
+
190
+ # ─── Tree Building ─────────────────────────────────────────────────────────────
191
+
192
+ def build_tree_from_nodes(node_list):
193
+ """Convert flat node list to nested tree structure."""
194
+ if not node_list:
195
+ return []
196
+
197
+ stack = []
198
+ root_nodes = []
199
+ node_counter = 1
200
+
201
+ for node in node_list:
202
+ current_level = node['level']
203
+
204
+ tree_node = {
205
+ 'title': node['title'],
206
+ 'node_id': str(node_counter).zfill(4),
207
+ 'text': node['text'],
208
+ 'line_num': node['line_num'],
209
+ 'nodes': []
210
+ }
211
+ node_counter += 1
212
+
213
+ while stack and stack[-1][1] >= current_level:
214
+ stack.pop()
215
+
216
+ if not stack:
217
+ root_nodes.append(tree_node)
218
+ else:
219
+ parent_node, _ = stack[-1]
220
+ parent_node['nodes'].append(tree_node)
221
+
222
+ stack.append((tree_node, current_level))
223
+
224
+ return root_nodes
225
+
226
+
227
+ # ─── Main Entry ────────────────────────────────────────────────────────────────
228
+
229
+ async def md_to_tree(md_path, if_thinning=False, min_token_threshold=None,
230
+ if_add_node_summary='no', summary_token_threshold=None,
231
+ model=None, if_add_doc_description='no',
232
+ if_add_node_text='no', if_add_node_id='yes'):
233
+ """Top-level async entry for markdown-to-tree conversion."""
234
+ with open(md_path, 'r', encoding='utf-8') as f:
235
+ markdown_content = f.read()
236
+
237
+ print("Extracting nodes from markdown...")
238
+ node_list, markdown_lines = extract_nodes_from_markdown(markdown_content)
239
+
240
+ print("Extracting text content from nodes...")
241
+ nodes_with_content = extract_node_text_content(node_list, markdown_lines)
242
+
243
+ if if_thinning:
244
+ nodes_with_content = update_node_list_with_text_token_count(nodes_with_content, model=model)
245
+ print("Thinning nodes...")
246
+ nodes_with_content = tree_thinning_for_index(nodes_with_content, min_token_threshold, model=model)
247
+
248
+ print("Building tree from nodes...")
249
+ tree_structure = build_tree_from_nodes(nodes_with_content)
250
+
251
+ if if_add_node_id == 'yes':
252
+ write_node_id(tree_structure)
253
+
254
+ print("Formatting tree structure...")
255
+
256
+ if if_add_node_summary == 'yes':
257
+ tree_structure = format_structure(
258
+ tree_structure,
259
+ order=['title', 'node_id', 'summary', 'prefix_summary', 'text', 'line_num', 'nodes']
260
+ )
261
+
262
+ print("Generating summaries for each node...")
263
+ tree_structure = await generate_summaries_for_structure_md(
264
+ tree_structure, summary_token_threshold=summary_token_threshold, model=model
265
+ )
266
+
267
+ if if_add_node_text == 'no':
268
+ tree_structure = format_structure(
269
+ tree_structure,
270
+ order=['title', 'node_id', 'summary', 'prefix_summary', 'line_num', 'nodes']
271
+ )
272
+
273
+ if if_add_doc_description == 'yes':
274
+ print("Generating document description...")
275
+ clean_structure = create_clean_structure_for_description(tree_structure)
276
+ doc_description = generate_doc_description(clean_structure, model=model)
277
+ return {
278
+ 'doc_name': os.path.splitext(os.path.basename(md_path))[0],
279
+ 'doc_description': doc_description,
280
+ 'structure': tree_structure,
281
+ }
282
+ else:
283
+ if if_add_node_text == 'yes':
284
+ tree_structure = format_structure(
285
+ tree_structure,
286
+ order=['title', 'node_id', 'summary', 'prefix_summary', 'text', 'line_num', 'nodes']
287
+ )
288
+ else:
289
+ tree_structure = format_structure(
290
+ tree_structure,
291
+ order=['title', 'node_id', 'summary', 'prefix_summary', 'line_num', 'nodes']
292
+ )
293
+
294
+ return {
295
+ 'doc_name': os.path.splitext(os.path.basename(md_path))[0],
296
+ 'structure': tree_structure,
297
+ }
src/toc_rag/retriever.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ retriever.py
3
+ ------------
4
+ Node-level retrieval for the TOC-based Index RAG.
5
+ Embeds query with Ollama, searches ChromaDB toc_index collection,
6
+ returns top-K nodes with metadata.
7
+ """
8
+
9
+ import ollama
10
+ from dataclasses import dataclass
11
+ from src.shared_db import get_collection as get_shared_collection
12
+ from src.cost_tracker import add_embedding_tokens
13
+
14
+ import tiktoken
15
+ _tokenizer = tiktoken.get_encoding("cl100k_base")
16
+
17
+ # ─── Config ────────────────────────────────────────────────────────────────────
18
+ EMBED_MODEL = "nomic-embed-text"
19
+ COLLECTION = "toc_index"
20
+ TOP_K = 5
21
+ # ───────────────────────────────────────────────────────────────────────────────
22
+
23
+
24
+ @dataclass
25
+ class RetrievedNode:
26
+ text: str
27
+ source: str
28
+ node_id: str
29
+ title: str
30
+ score: float # cosine distance (lower = more similar)
31
+ tokens: int = 0
32
+ summary: str = ""
33
+ full_title: str = ""
34
+
35
+
36
+ def get_collection():
37
+ return get_shared_collection(COLLECTION)
38
+
39
+
40
+ def embed_query(query: str) -> list[float]:
41
+ response = ollama.embeddings(model=EMBED_MODEL, prompt=query)
42
+ exact_tokens = len(_tokenizer.encode(query))
43
+ add_embedding_tokens(exact_tokens)
44
+ return response["embedding"]
45
+
46
+
47
+ def retrieve(query: str, top_k: int = TOP_K) -> list[RetrievedNode]:
48
+ """
49
+ Embed the user query and return the top-k most relevant tree nodes
50
+ from the ChromaDB toc_index collection.
51
+ """
52
+ collection = get_collection()
53
+
54
+ if collection.count() == 0:
55
+ return []
56
+
57
+ query_emb = embed_query(query)
58
+
59
+ results = collection.query(
60
+ query_embeddings=[query_emb],
61
+ n_results=min(top_k, collection.count()),
62
+ include=["documents", "metadatas", "distances"],
63
+ )
64
+
65
+ nodes = []
66
+ for doc, meta, dist in zip(
67
+ results["documents"][0],
68
+ results["metadatas"][0],
69
+ results["distances"][0],
70
+ ):
71
+ nodes.append(RetrievedNode(
72
+ text = doc,
73
+ source = meta.get("source", "unknown"),
74
+ node_id = meta.get("node_id", ""),
75
+ title = meta.get("title", ""),
76
+ score = round(float(dist), 4),
77
+ tokens = int(meta.get("tokens", 0)),
78
+ summary = meta.get("summary", ""),
79
+ full_title = meta.get("full_title", ""),
80
+ ))
81
+
82
+ return nodes
83
+
84
+
85
+ def build_context(nodes: list[RetrievedNode]) -> str:
86
+ """Format retrieved nodes into a context block for the LLM prompt."""
87
+ parts = []
88
+ for i, node in enumerate(nodes, start=1):
89
+ label = node.full_title or node.title
90
+ parts.append(
91
+ f"[Source {i}: {node.source} — Section: {label}]\n{node.text}"
92
+ )
93
+ return "\n\n---\n\n".join(parts)
src/toc_rag/toc_detector.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ toc_detector.py
3
+ ---------------
4
+ TOC detection and extraction functions.
5
+ Scans document pages to find Table of Contents,
6
+ extracts TOC content, and checks for page numbers.
7
+
8
+ Replica of the reference page-indexing repository,
9
+ adapted for local Ollama LLM.
10
+ """
11
+
12
+ import re
13
+ from .toc_utils import (
14
+ ChatGPT_API, ChatGPT_API_with_finish_reason,
15
+ extract_json, get_json_content,
16
+ )
17
+
18
+
19
+ def toc_detector_single_page(content, model=None):
20
+ """Detect if a single page contains a table of contents."""
21
+ prompt = f"""
22
+ Your job is to detect if there is a table of content provided in the given text.
23
+
24
+ Given text: {content}
25
+
26
+ return the following JSON format:
27
+ {{
28
+ "thinking": <why do you think there is a table of content in the given text>
29
+ "toc_detected": "<yes or no>",
30
+ }}
31
+
32
+ Directly return the final JSON structure. Do not output anything else.
33
+ Please note: abstract,summary, notation list, figure list, table list, etc. are not table of contents."""
34
+
35
+ response = ChatGPT_API(model=model, prompt=prompt)
36
+ json_content = extract_json(response)
37
+ return json_content.get('toc_detected', 'no')
38
+
39
+
40
+ def detect_page_index(toc_content, model=None):
41
+ """Detect if TOC has page numbers/indices."""
42
+ print('start detect_page_index')
43
+ prompt = f"""
44
+ You will be given a table of contents.
45
+
46
+ Your job is to detect if there are page numbers/indices given within the table of contents.
47
+
48
+ Given text: {toc_content}
49
+
50
+ Reply format:
51
+ {{
52
+ "thinking": <why do you think there are page numbers/indices given within the table of contents>
53
+ "page_index_given_in_toc": "<yes or no>"
54
+ }}
55
+ Directly return the final JSON structure. Do not output anything else."""
56
+
57
+ response = ChatGPT_API(model=model, prompt=prompt)
58
+ json_content = extract_json(response)
59
+ return json_content.get('page_index_given_in_toc', 'no')
60
+
61
+
62
+ def extract_toc_content(content, model=None):
63
+ """Extract and clean raw TOC text from page content."""
64
+ prompt = f"""
65
+ Your job is to extract the full table of contents from the given text, replace ... with :
66
+
67
+ Given text: {content}
68
+
69
+ Directly return the full table of contents content. Do not output anything else."""
70
+
71
+ response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
72
+
73
+ if_complete = check_if_toc_transformation_is_complete(content, response, model)
74
+ if if_complete == "yes" and finish_reason == "finished":
75
+ return response
76
+
77
+ chat_history = [
78
+ {"role": "user", "content": prompt},
79
+ {"role": "assistant", "content": response},
80
+ ]
81
+ prompt = "please continue the generation of table of contents, directly output the remaining part of the structure"
82
+ new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history)
83
+ response = response + new_response
84
+ if_complete = check_if_toc_transformation_is_complete(content, response, model)
85
+
86
+ retry_count = 0
87
+ while not (if_complete == "yes" and finish_reason == "finished"):
88
+ chat_history = [
89
+ {"role": "user", "content": prompt},
90
+ {"role": "assistant", "content": response},
91
+ ]
92
+ prompt = "please continue the generation of table of contents, directly output the remaining part of the structure"
93
+ new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history)
94
+ response = response + new_response
95
+ if_complete = check_if_toc_transformation_is_complete(content, response, model)
96
+
97
+ retry_count += 1
98
+ if retry_count > 5:
99
+ raise Exception('Failed to complete table of contents after maximum retries')
100
+
101
+ return response
102
+
103
+
104
+ def toc_extractor(page_list, toc_page_list, model):
105
+ """Concatenate TOC pages, clean up, and detect if page indices exist."""
106
+ def transform_dots_to_colon(text):
107
+ text = re.sub(r'\.{5,}', ': ', text)
108
+ text = re.sub(r'(?:\. ){5,}\.?', ': ', text)
109
+ return text
110
+
111
+ toc_content = ""
112
+ for page_index in toc_page_list:
113
+ toc_content += page_list[page_index][0]
114
+ toc_content = transform_dots_to_colon(toc_content)
115
+ has_page_index = detect_page_index(toc_content, model=model)
116
+
117
+ return {
118
+ "toc_content": toc_content,
119
+ "page_index_given_in_toc": has_page_index
120
+ }
121
+
122
+
123
+ def find_toc_pages(start_page_index, page_list, opt, logger=None):
124
+ """Scan pages to find TOC pages."""
125
+ print('start find_toc_pages')
126
+ last_page_is_yes = False
127
+ toc_page_list = []
128
+ i = start_page_index
129
+
130
+ while i < len(page_list):
131
+ if i >= opt.toc_check_page_num and not last_page_is_yes:
132
+ break
133
+ detected_result = toc_detector_single_page(page_list[i][0], model=opt.model)
134
+ if detected_result == 'yes':
135
+ if logger:
136
+ logger.info(f'Page {i} has toc')
137
+ toc_page_list.append(i)
138
+ last_page_is_yes = True
139
+ elif detected_result == 'no' and last_page_is_yes:
140
+ if logger:
141
+ logger.info(f'Found the last page with toc: {i-1}')
142
+ break
143
+ i += 1
144
+
145
+ if not toc_page_list and logger:
146
+ logger.info('No toc found')
147
+
148
+ return toc_page_list
149
+
150
+
151
+ def check_toc(page_list, opt=None):
152
+ """Main entry: find TOC pages, extract content, check for page numbers."""
153
+ toc_page_list = find_toc_pages(start_page_index=0, page_list=page_list, opt=opt)
154
+ if len(toc_page_list) == 0:
155
+ print('no toc found')
156
+ return {'toc_content': None, 'toc_page_list': [], 'page_index_given_in_toc': 'no'}
157
+ else:
158
+ print('toc found')
159
+ toc_json = toc_extractor(page_list, toc_page_list, opt.model)
160
+
161
+ if toc_json['page_index_given_in_toc'] == 'yes':
162
+ print('index found')
163
+ return {
164
+ 'toc_content': toc_json['toc_content'],
165
+ 'toc_page_list': toc_page_list,
166
+ 'page_index_given_in_toc': 'yes'
167
+ }
168
+ else:
169
+ current_start_index = toc_page_list[-1] + 1
170
+
171
+ while (toc_json['page_index_given_in_toc'] == 'no' and
172
+ current_start_index < len(page_list) and
173
+ current_start_index < opt.toc_check_page_num):
174
+
175
+ additional_toc_pages = find_toc_pages(
176
+ start_page_index=current_start_index,
177
+ page_list=page_list,
178
+ opt=opt
179
+ )
180
+
181
+ if len(additional_toc_pages) == 0:
182
+ break
183
+
184
+ additional_toc_json = toc_extractor(page_list, additional_toc_pages, opt.model)
185
+ if additional_toc_json['page_index_given_in_toc'] == 'yes':
186
+ print('index found')
187
+ return {
188
+ 'toc_content': additional_toc_json['toc_content'],
189
+ 'toc_page_list': additional_toc_pages,
190
+ 'page_index_given_in_toc': 'yes'
191
+ }
192
+ else:
193
+ current_start_index = additional_toc_pages[-1] + 1
194
+
195
+ print('index not found')
196
+ return {
197
+ 'toc_content': toc_json['toc_content'],
198
+ 'toc_page_list': toc_page_list,
199
+ 'page_index_given_in_toc': 'no'
200
+ }
201
+
202
+
203
+ def check_if_toc_transformation_is_complete(content, toc, model=None):
204
+ """Check if the TOC transformation is complete."""
205
+ prompt = f"""
206
+ You are given a raw table of contents and a table of contents.
207
+ Your job is to check if the table of contents is complete.
208
+
209
+ Reply format:
210
+ {{
211
+ "thinking": <why do you think the cleaned table of contents is complete or not>
212
+ "completed": "yes" or "no"
213
+ }}
214
+ Directly return the final JSON structure. Do not output anything else."""
215
+
216
+ prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
217
+ response = ChatGPT_API(model=model, prompt=prompt)
218
+ json_content = extract_json(response)
219
+ return json_content.get('completed', 'no')
src/toc_rag/toc_transformer.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ toc_transformer.py
3
+ ------------------
4
+ TOC-to-JSON transformation and page offset calculation.
5
+ Converts raw TOC text to structured JSON, extracts physical indices,
6
+ and calculates page offsets.
7
+
8
+ Replica of the reference page-indexing repository.
9
+ """
10
+
11
+ import json
12
+ import copy
13
+ import re
14
+ from .toc_utils import (
15
+ ChatGPT_API, ChatGPT_API_with_finish_reason,
16
+ extract_json, get_json_content,
17
+ convert_page_to_int, convert_physical_index_to_int,
18
+ count_tokens,
19
+ )
20
+
21
+
22
+ def check_if_toc_transformation_is_complete(content, toc, model=None):
23
+ """Check if the TOC transformation is complete."""
24
+ prompt = f"""
25
+ You are given a raw table of contents and a table of contents.
26
+ Your job is to check if the table of contents is complete.
27
+
28
+ Reply format:
29
+ {{
30
+ "thinking": <why do you think the cleaned table of contents is complete or not>
31
+ "completed": "yes" or "no"
32
+ }}
33
+ Directly return the final JSON structure. Do not output anything else."""
34
+
35
+ prompt = prompt + '\n Raw Table of contents:\n' + str(content) + '\n Cleaned Table of contents:\n' + str(toc)
36
+ response = ChatGPT_API(model=model, prompt=prompt)
37
+ json_content = extract_json(response)
38
+ return json_content.get('completed', 'no')
39
+
40
+
41
+ def toc_transformer(toc_content, model=None):
42
+ """Convert raw TOC text to structured JSON."""
43
+ print('start toc_transformer')
44
+ init_prompt = """
45
+ You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents.
46
+
47
+ structure is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
48
+
49
+ The response should be in the following JSON format:
50
+ {
51
+ table_of_contents: [
52
+ {
53
+ "structure": <structure index, "x.x.x" or None> (string),
54
+ "title": <title of the section>,
55
+ "page": <page number or None>,
56
+ },
57
+ ...
58
+ ],
59
+ }
60
+ You should transform the full table of contents in one go.
61
+ Directly return the final JSON structure, do not output anything else. """
62
+
63
+ prompt = init_prompt + '\n Given table of contents\n:' + toc_content
64
+ last_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
65
+ if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model)
66
+ if if_complete == "yes" and finish_reason == "finished":
67
+ last_complete = extract_json(last_complete)
68
+ if isinstance(last_complete, dict) and 'table_of_contents' in last_complete:
69
+ cleaned_response = convert_page_to_int(last_complete['table_of_contents'])
70
+ elif isinstance(last_complete, list):
71
+ cleaned_response = convert_page_to_int(last_complete)
72
+ else:
73
+ cleaned_response = []
74
+ return cleaned_response
75
+
76
+ last_complete = get_json_content(last_complete)
77
+ retry_count = 0
78
+ while not (if_complete == "yes" and finish_reason == "finished"):
79
+ position = last_complete.rfind('}')
80
+ if position != -1:
81
+ last_complete = last_complete[:position+2]
82
+ prompt = f"""
83
+ Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
84
+ The response should be in the following JSON format:
85
+
86
+ The raw table of contents json structure is:
87
+ {toc_content}
88
+
89
+ The incomplete transformed table of contents json structure is:
90
+ {last_complete}
91
+
92
+ Please continue the json structure, directly output the remaining part of the json structure."""
93
+
94
+ new_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
95
+
96
+ if new_complete.startswith('```json'):
97
+ new_complete = get_json_content(new_complete)
98
+ last_complete = last_complete + new_complete
99
+
100
+ if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model)
101
+
102
+ retry_count += 1
103
+ if retry_count > 5:
104
+ break
105
+
106
+ try:
107
+ last_complete = json.loads(last_complete)
108
+ except json.JSONDecodeError:
109
+ last_complete = extract_json(last_complete)
110
+
111
+ if isinstance(last_complete, dict) and 'table_of_contents' in last_complete:
112
+ cleaned_response = convert_page_to_int(last_complete['table_of_contents'])
113
+ elif isinstance(last_complete, list):
114
+ cleaned_response = convert_page_to_int(last_complete)
115
+ else:
116
+ cleaned_response = []
117
+ return cleaned_response
118
+
119
+
120
+ def toc_index_extractor(toc, content, model=None):
121
+ """Add physical_index to TOC items from document pages."""
122
+ print('start toc_index_extractor')
123
+ toc_extractor_prompt = """
124
+ You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.
125
+
126
+ The provided pages contains tags like <physical_index_X> and <physical_index_X> to indicate the physical location of the page X.
127
+
128
+ The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
129
+
130
+ The response should be in the following JSON format:
131
+ [
132
+ {
133
+ "structure": <structure index, "x.x.x" or None> (string),
134
+ "title": <title of the section>,
135
+ "physical_index": "<physical_index_X>" (keep the format)
136
+ },
137
+ ...
138
+ ]
139
+
140
+ Only add the physical_index to the sections that are in the provided pages.
141
+ If the section is not in the provided pages, do not add the physical_index to it.
142
+ Directly return the final JSON structure. Do not output anything else."""
143
+
144
+ prompt = toc_extractor_prompt + '\nTable of contents:\n' + str(toc) + '\nDocument pages:\n' + content
145
+ response = ChatGPT_API(model=model, prompt=prompt)
146
+ json_content = extract_json(response)
147
+ return json_content
148
+
149
+
150
+ def add_page_number_to_toc(part, structure, model=None):
151
+ """Fill page numbers from document pages into TOC structure."""
152
+ fill_prompt_seq = """
153
+ You are given an JSON structure of a document and a partial part of the document. Your task is to check if the title that is described in the structure is started in the partial given document.
154
+
155
+ The provided text contains tags like <physical_index_X> and <physical_index_X> to indicate the physical location of the page X.
156
+
157
+ If the full target section starts in the partial given document, insert the given JSON structure with the "start": "yes", and "start_index": "<physical_index_X>".
158
+
159
+ If the full target section does not start in the partial given document, insert "start": "no", "start_index": None.
160
+
161
+ The response should be in the following format.
162
+ [
163
+ {
164
+ "structure": <structure index, "x.x.x" or None> (string),
165
+ "title": <title of the section>,
166
+ "start": "<yes or no>",
167
+ "physical_index": "<physical_index_X> (keep the format)" or None
168
+ },
169
+ ...
170
+ ]
171
+ The given structure contains the result of the previous part, you need to fill the result of the current part, do not change the previous result.
172
+ Directly return the final JSON structure. Do not output anything else."""
173
+
174
+ if isinstance(part, list):
175
+ part = ''.join(part)
176
+
177
+ prompt = fill_prompt_seq + f"\n\nCurrent Partial Document:\n{part}\n\nGiven Structure\n{json.dumps(structure, indent=2)}\n"
178
+ current_json_raw = ChatGPT_API(model=model, prompt=prompt)
179
+ json_result = extract_json(current_json_raw)
180
+
181
+ if isinstance(json_result, list):
182
+ for item in json_result:
183
+ if 'start' in item:
184
+ del item['start']
185
+ return json_result
186
+
187
+
188
+ def remove_page_number(data):
189
+ """Remove page_number field from data recursively."""
190
+ if isinstance(data, dict):
191
+ data.pop('page_number', None)
192
+ data.pop('page', None)
193
+ for key in list(data.keys()):
194
+ if 'nodes' in key:
195
+ remove_page_number(data[key])
196
+ elif isinstance(data, list):
197
+ for item in data:
198
+ remove_page_number(item)
199
+ return data
200
+
201
+
202
+ def extract_matching_page_pairs(toc_page, toc_physical_index, start_page_index):
203
+ """Match TOC items by title and extract page/physical_index pairs."""
204
+ pairs = []
205
+ for phy_item in toc_physical_index:
206
+ if not isinstance(phy_item, dict):
207
+ continue
208
+ for page_item in toc_page:
209
+ if not isinstance(page_item, dict):
210
+ continue
211
+ if phy_item.get('title') == page_item.get('title'):
212
+ physical_index = phy_item.get('physical_index')
213
+ if physical_index is not None and int(physical_index) >= start_page_index:
214
+ pairs.append({
215
+ 'title': phy_item.get('title'),
216
+ 'page': page_item.get('page'),
217
+ 'physical_index': physical_index
218
+ })
219
+ return pairs
220
+
221
+
222
+ def calculate_page_offset(pairs):
223
+ """Calculate the most common page offset from matching pairs."""
224
+ differences = []
225
+ for pair in pairs:
226
+ try:
227
+ physical_index = pair['physical_index']
228
+ page_number = pair['page']
229
+ difference = physical_index - page_number
230
+ differences.append(difference)
231
+ except (KeyError, TypeError):
232
+ continue
233
+
234
+ if not differences:
235
+ return 0
236
+
237
+ difference_counts = {}
238
+ for diff in differences:
239
+ difference_counts[diff] = difference_counts.get(diff, 0) + 1
240
+
241
+ most_common = max(difference_counts.items(), key=lambda x: x[1])[0]
242
+ return most_common
243
+
244
+
245
+ def add_page_offset_to_toc_json(data, offset):
246
+ """Apply page offset to convert logical page numbers to physical indices."""
247
+ for i in range(len(data)):
248
+ if data[i].get('page') is not None and isinstance(data[i]['page'], int):
249
+ data[i]['physical_index'] = data[i]['page'] + offset
250
+ del data[i]['page']
251
+ return data
252
+
253
+
254
+ def process_none_page_numbers(toc_items, page_list, start_index=1, model=None):
255
+ """Fix items that are missing physical_index by searching nearby pages."""
256
+ for i, item in enumerate(toc_items):
257
+ if "physical_index" not in item:
258
+ # Find previous physical_index
259
+ prev_physical_index = 0
260
+ for j in range(i - 1, -1, -1):
261
+ if toc_items[j].get('physical_index') is not None:
262
+ prev_physical_index = toc_items[j]['physical_index']
263
+ break
264
+
265
+ # Find next physical_index
266
+ next_physical_index = -1
267
+ for j in range(i + 1, len(toc_items)):
268
+ if toc_items[j].get('physical_index') is not None:
269
+ next_physical_index = toc_items[j]['physical_index']
270
+ break
271
+
272
+ if next_physical_index == -1:
273
+ next_physical_index = len(page_list) + start_index - 1
274
+
275
+ page_contents = []
276
+ for page_index in range(prev_physical_index, next_physical_index + 1):
277
+ list_index = page_index - start_index
278
+ if 0 <= list_index < len(page_list):
279
+ page_text = f"<physical_index_{page_index}>\n{page_list[list_index][0]}\n<physical_index_{page_index}>\n\n"
280
+ page_contents.append(page_text)
281
+
282
+ item_copy = copy.deepcopy(item)
283
+ item_copy.pop('page', None)
284
+ result = add_page_number_to_toc(page_contents, item_copy, model)
285
+
286
+ if isinstance(result, list) and len(result) > 0:
287
+ pi = result[0].get('physical_index')
288
+ if isinstance(pi, str) and pi.startswith('<physical_index'):
289
+ item['physical_index'] = int(pi.split('_')[-1].rstrip('>').strip())
290
+ item.pop('page', None)
291
+
292
+ return toc_items
293
+
294
+
295
+ def single_toc_item_index_fixer(section_title, content, model=None):
296
+ """Fix a single TOC item's physical index using LLM."""
297
+ toc_extractor_prompt = """
298
+ You are given a section title and several pages of a document, your job is to find the physical index of the start page of the section in the partial document.
299
+
300
+ The provided pages contains tags like <physical_index_X> and <physical_index_X> to indicate the physical location of the page X.
301
+
302
+ Reply in a JSON format:
303
+ {
304
+ "thinking": <explain which page, started and closed by <physical_index_X>, contains the start of this section>,
305
+ "physical_index": "<physical_index_X>" (keep the format)
306
+ }
307
+ Directly return the final JSON structure. Do not output anything else."""
308
+
309
+ prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content
310
+ response = ChatGPT_API(model=model, prompt=prompt)
311
+ json_content = extract_json(response)
312
+ return convert_physical_index_to_int(json_content.get('physical_index', ''))
src/toc_rag/toc_utils.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ toc_utils.py
3
+ ------------
4
+ Utility functions for the TOC-based Index RAG.
5
+ Handles token counting, JSON extraction, LLM wrappers (Ollama),
6
+ structure helpers, logging, and config management.
7
+
8
+ Adapted from the reference page-indexing repository to use
9
+ local Ollama instead of OpenAI.
10
+ """
11
+
12
+ import os
13
+ import re
14
+ import json
15
+ import copy
16
+ import asyncio
17
+ import tiktoken
18
+ from pathlib import Path
19
+ from io import BytesIO
20
+ from collections import OrderedDict
21
+
22
+ import fitz # PyMuPDF
23
+ import ollama
24
+
25
+ from src.cost_tracker import add_toc_analysis_tokens
26
+
27
+ # ─── Token Counting ───────────────────────────────────────────────────────────
28
+
29
+ _tokenizer = tiktoken.get_encoding("cl100k_base")
30
+
31
+ LLM_MODEL = "gpt-oss:20b-cloud"
32
+
33
+
34
+ def count_tokens(text, model=None):
35
+ """Count tokens in text using tiktoken."""
36
+ if not text:
37
+ return 0
38
+ return len(_tokenizer.encode(str(text)))
39
+
40
+
41
+ # ─── JSON Extraction ──────────────────────────────────────────────────────────
42
+
43
+ def extract_json(text):
44
+ """Extract JSON from LLM response text, handling markdown code blocks."""
45
+ if not text:
46
+ return {}
47
+ text = text.strip()
48
+ # Remove markdown code block wrappers
49
+ if text.startswith("```json"):
50
+ text = text[7:]
51
+ elif text.startswith("```"):
52
+ text = text[3:]
53
+ if text.endswith("```"):
54
+ text = text[:-3]
55
+ text = text.strip()
56
+
57
+ try:
58
+ return json.loads(text)
59
+ except json.JSONDecodeError:
60
+ # Try to find JSON in the text
61
+ # Look for array
62
+ match = re.search(r'\[.*\]', text, re.DOTALL)
63
+ if match:
64
+ try:
65
+ return json.loads(match.group())
66
+ except json.JSONDecodeError:
67
+ pass
68
+ # Look for object
69
+ match = re.search(r'\{.*\}', text, re.DOTALL)
70
+ if match:
71
+ try:
72
+ return json.loads(match.group())
73
+ except json.JSONDecodeError:
74
+ pass
75
+ return {}
76
+
77
+
78
+ def get_json_content(text):
79
+ """Extract raw JSON content string from markdown-wrapped response."""
80
+ if not text:
81
+ return text
82
+ text = text.strip()
83
+ if text.startswith("```json"):
84
+ text = text[7:]
85
+ elif text.startswith("```"):
86
+ text = text[3:]
87
+ if text.endswith("```"):
88
+ text = text[:-3]
89
+ return text.strip()
90
+
91
+
92
+ # ─── LLM Wrappers (Ollama) ────────────────────────────────────────────────────
93
+
94
+ def ChatGPT_API(model=None, prompt=None, chat_history=None):
95
+ """Synchronous LLM call via Ollama. Returns response text."""
96
+ model = model or LLM_MODEL
97
+ messages = []
98
+ if chat_history:
99
+ messages.extend(chat_history)
100
+ messages.append({"role": "user", "content": prompt})
101
+
102
+ response = ollama.chat(model=model, messages=messages)
103
+ content = response.get("message", {}).get("content", "")
104
+
105
+ # Track tokens
106
+ prompt_tokens = response.get("prompt_eval_count", 0)
107
+ completion_tokens = response.get("eval_count", 0)
108
+ add_toc_analysis_tokens(prompt_tokens + completion_tokens)
109
+
110
+ return content
111
+
112
+
113
+ def ChatGPT_API_with_finish_reason(model=None, prompt=None, chat_history=None):
114
+ """Synchronous LLM call that also returns finish reason."""
115
+ model = model or LLM_MODEL
116
+ messages = []
117
+ if chat_history:
118
+ messages.extend(chat_history)
119
+ messages.append({"role": "user", "content": prompt})
120
+
121
+ response = ollama.chat(model=model, messages=messages)
122
+ content = response.get("message", {}).get("content", "")
123
+
124
+ # Track tokens
125
+ prompt_tokens = response.get("prompt_eval_count", 0)
126
+ completion_tokens = response.get("eval_count", 0)
127
+ add_toc_analysis_tokens(prompt_tokens + completion_tokens)
128
+
129
+ # Ollama always finishes (no length limit truncation like OpenAI)
130
+ finish_reason = "finished"
131
+ if response.get("done_reason") == "length":
132
+ finish_reason = "length"
133
+
134
+ return content, finish_reason
135
+
136
+
137
+ async def ChatGPT_API_async(model=None, prompt=None):
138
+ """Async LLM call via Ollama (runs sync call in executor)."""
139
+ loop = asyncio.get_event_loop()
140
+ return await loop.run_in_executor(None, ChatGPT_API, model, prompt, None)
141
+
142
+
143
+ # ─── Type Converters ──────────────────────────────────────────────────────────
144
+
145
+ def convert_page_to_int(data):
146
+ """Convert 'page' fields to int in TOC data."""
147
+ if isinstance(data, list):
148
+ for item in data:
149
+ if isinstance(item, dict) and 'page' in item:
150
+ try:
151
+ if item['page'] is not None:
152
+ item['page'] = int(item['page'])
153
+ except (ValueError, TypeError):
154
+ item['page'] = None
155
+ return data
156
+
157
+
158
+ def convert_physical_index_to_int(data):
159
+ """Convert physical_index from '<physical_index_X>' format to int."""
160
+ if isinstance(data, str):
161
+ # Single value
162
+ match = re.search(r'physical_index_(\d+)', data)
163
+ if match:
164
+ return int(match.group(1))
165
+ try:
166
+ return int(data)
167
+ except (ValueError, TypeError):
168
+ return None
169
+
170
+ if isinstance(data, list):
171
+ for item in data:
172
+ if isinstance(item, dict) and 'physical_index' in item:
173
+ pi = item['physical_index']
174
+ if isinstance(pi, str):
175
+ match = re.search(r'physical_index_(\d+)', pi)
176
+ if match:
177
+ item['physical_index'] = int(match.group(1))
178
+ else:
179
+ try:
180
+ item['physical_index'] = int(pi)
181
+ except (ValueError, TypeError):
182
+ item['physical_index'] = None
183
+ elif isinstance(pi, (int, float)):
184
+ item['physical_index'] = int(pi)
185
+ return data
186
+
187
+
188
+ # ─── Structure Helpers ─────────────────────────────────────────────────────────
189
+
190
+ def post_processing(toc_items, end_index):
191
+ """Convert flat TOC list into tree with start_index/end_index."""
192
+ if not toc_items:
193
+ return []
194
+
195
+ result = []
196
+ for i, item in enumerate(toc_items):
197
+ node = {
198
+ 'title': item.get('title', ''),
199
+ 'start_index': item.get('physical_index'),
200
+ }
201
+ if item.get('structure'):
202
+ node['structure'] = item['structure']
203
+ if item.get('appear_start'):
204
+ node['appear_start'] = item['appear_start']
205
+
206
+ # End index is the start of the next section, or the document end
207
+ if i + 1 < len(toc_items):
208
+ node['end_index'] = toc_items[i + 1].get('physical_index', end_index)
209
+ else:
210
+ node['end_index'] = end_index
211
+
212
+ result.append(node)
213
+
214
+ return result
215
+
216
+
217
+ def add_preface_if_needed(toc_items):
218
+ """Insert a preface node if first section doesn't start at page 1."""
219
+ if not toc_items:
220
+ return toc_items
221
+
222
+ first_pi = toc_items[0].get('physical_index')
223
+ if first_pi is not None and first_pi > 1:
224
+ preface = {
225
+ 'structure': None,
226
+ 'title': 'Preface',
227
+ 'physical_index': 1,
228
+ 'appear_start': 'yes',
229
+ }
230
+ toc_items.insert(0, preface)
231
+
232
+ return toc_items
233
+
234
+
235
+ def write_node_id(structure, counter=None):
236
+ """Assign sequential node IDs to tree nodes."""
237
+ if counter is None:
238
+ counter = [1]
239
+
240
+ for node in structure:
241
+ node['node_id'] = str(counter[0]).zfill(4)
242
+ counter[0] += 1
243
+ if 'nodes' in node and node['nodes']:
244
+ write_node_id(node['nodes'], counter)
245
+
246
+
247
+ def add_node_text(structure, page_list):
248
+ """Populate 'text' field from page_list for each node."""
249
+ for node in structure:
250
+ start = node.get('start_index', 1)
251
+ end = node.get('end_index', start)
252
+ texts = []
253
+ for pi in range(start, end + 1):
254
+ idx = pi - 1
255
+ if 0 <= idx < len(page_list):
256
+ texts.append(page_list[idx][0])
257
+ node['text'] = '\n'.join(texts)
258
+
259
+ if 'nodes' in node and node['nodes']:
260
+ add_node_text(node['nodes'], page_list)
261
+
262
+
263
+ def remove_structure_text(structure):
264
+ """Remove 'text' field from structure."""
265
+ for node in structure:
266
+ node.pop('text', None)
267
+ if 'nodes' in node and node['nodes']:
268
+ remove_structure_text(node['nodes'])
269
+
270
+
271
+ def format_structure(structure, order=None):
272
+ """Reorder dict keys in structure according to given order."""
273
+ if order is None:
274
+ return structure
275
+
276
+ result = []
277
+ for node in structure:
278
+ ordered = OrderedDict()
279
+ for key in order:
280
+ if key == 'nodes':
281
+ if 'nodes' in node and node['nodes']:
282
+ ordered['nodes'] = format_structure(node['nodes'], order)
283
+ elif key in node:
284
+ ordered[key] = node[key]
285
+ result.append(dict(ordered))
286
+ return result
287
+
288
+
289
+ def structure_to_list(structure):
290
+ """Flatten tree structure to a flat list of nodes."""
291
+ result = []
292
+ for node in structure:
293
+ result.append(node)
294
+ if 'nodes' in node and node['nodes']:
295
+ result.extend(structure_to_list(node['nodes']))
296
+ return result
297
+
298
+
299
+ def create_clean_structure_for_description(structure):
300
+ """Create a minimal structure for document description generation."""
301
+ clean = []
302
+ for node in structure:
303
+ item = {'title': node.get('title', '')}
304
+ if 'summary' in node:
305
+ item['summary'] = node['summary']
306
+ if 'prefix_summary' in node:
307
+ item['prefix_summary'] = node['prefix_summary']
308
+ if 'nodes' in node and node['nodes']:
309
+ item['nodes'] = create_clean_structure_for_description(node['nodes'])
310
+ clean.append(item)
311
+ return clean
312
+
313
+
314
+ # ─── Summary Generation ───────────────────────────────────────────────────────
315
+
316
+ async def generate_node_summary(node, model=None):
317
+ """Generate a summary for a single node using LLM."""
318
+ title = node.get('title', '')
319
+ text = node.get('text', '')
320
+
321
+ prompt = f"""Summarize the following section concisely (2-3 sentences max).
322
+
323
+ Section title: {title}
324
+ Section content:
325
+ {text[:8000]}
326
+
327
+ Provide only the summary, no extra formatting."""
328
+
329
+ return await ChatGPT_API_async(model=model, prompt=prompt)
330
+
331
+
332
+ async def generate_summaries_for_structure(structure, model=None):
333
+ """Generate summaries for all nodes in the structure."""
334
+ nodes = structure_to_list(structure)
335
+ nodes_with_text = [n for n in nodes if n.get('text')]
336
+
337
+ if not nodes_with_text:
338
+ return structure
339
+
340
+ summary_tasks = [generate_node_summary(n, model=model) for n in nodes_with_text]
341
+ summaries = await asyncio.gather(*summary_tasks, return_exceptions=True)
342
+
343
+ for node, summary in zip(nodes_with_text, summaries):
344
+ if isinstance(summary, Exception):
345
+ print(f"[TOC-RAG] Summary error for '{node.get('title', '')}': {summary}")
346
+ node['summary'] = ''
347
+ else:
348
+ if not node.get('nodes'):
349
+ node['summary'] = summary
350
+ else:
351
+ node['prefix_summary'] = summary
352
+
353
+ return structure
354
+
355
+
356
+ def generate_doc_description(structure, model=None):
357
+ """Generate a document-level description from the structure."""
358
+ prompt = f"""Based on the following document structure, provide a brief description of
359
+ what this document is about (2-3 sentences).
360
+
361
+ Document structure:
362
+ {json.dumps(structure, indent=2)[:6000]}
363
+
364
+ Provide only the description."""
365
+
366
+ return ChatGPT_API(model=model, prompt=prompt)
367
+
368
+
369
+ # ─── PDF Helpers ───────────────────────────────────────────────────────────────
370
+
371
+ def get_page_tokens(doc):
372
+ """Extract pages from PDF. Returns list of (text, token_count) tuples."""
373
+ if isinstance(doc, str):
374
+ pdf_doc = fitz.open(doc)
375
+ elif isinstance(doc, BytesIO):
376
+ pdf_doc = fitz.open(stream=doc, filetype="pdf")
377
+ else:
378
+ raise ValueError("Expected file path or BytesIO")
379
+
380
+ page_list = []
381
+ for page in pdf_doc:
382
+ text = page.get_text("text").strip()
383
+ tokens = count_tokens(text)
384
+ page_list.append((text, tokens))
385
+
386
+ return page_list
387
+
388
+
389
+ def get_pdf_name(doc):
390
+ """Get the filename from a PDF path or BytesIO."""
391
+ if isinstance(doc, str):
392
+ return Path(doc).stem
393
+ return "document"
394
+
395
+
396
+ # ─── Logger ────────────────────────────────────────────────────────────────────
397
+
398
+ class JsonLogger:
399
+ """Simple logger that prints structured JSON logs."""
400
+
401
+ def __init__(self, doc=None):
402
+ self.doc_name = get_pdf_name(doc) if doc else "unknown"
403
+
404
+ def info(self, msg):
405
+ if isinstance(msg, (dict, list)):
406
+ print(f"[TOC-RAG][{self.doc_name}] {json.dumps(msg, default=str)[:500]}")
407
+ else:
408
+ print(f"[TOC-RAG][{self.doc_name}] {str(msg)[:500]}")
409
+
410
+ def error(self, msg):
411
+ print(f"[TOC-RAG][ERROR][{self.doc_name}] {str(msg)[:500]}")
412
+
413
+
414
+ # ─── Config ───────────────────────────────────────────────────────────────────
415
+
416
+ class _Config:
417
+ """Simple config object with attribute access."""
418
+ def __init__(self, **kwargs):
419
+ for k, v in kwargs.items():
420
+ setattr(self, k, v)
421
+
422
+
423
+ class ConfigLoader:
424
+ """Load config from defaults + user overrides."""
425
+
426
+ DEFAULTS = {
427
+ "model": LLM_MODEL,
428
+ "toc_check_page_num": 20,
429
+ "max_page_num_each_node": 10,
430
+ "max_token_num_each_node": 20000,
431
+ "if_add_node_id": "yes",
432
+ "if_add_node_summary": "yes",
433
+ "if_add_doc_description": "no",
434
+ "if_add_node_text": "no",
435
+ }
436
+
437
+ def load(self, user_opt=None):
438
+ merged = dict(self.DEFAULTS)
439
+ if user_opt:
440
+ merged.update({k: v for k, v in user_opt.items() if v is not None})
441
+ return _Config(**merged)
442
+
443
+
444
+ # ─── Debug Printing ───────────────────────────────────────────────────────────
445
+
446
+ def print_json(data):
447
+ """Pretty print JSON data."""
448
+ print(json.dumps(data, indent=2, ensure_ascii=False, default=str))
449
+
450
+
451
+ def print_toc(structure, indent=0):
452
+ """Print tree table of contents."""
453
+ for node in structure:
454
+ prefix = " " * indent
455
+ title = node.get('title', '')
456
+ node_id = node.get('node_id', '')
457
+ print(f"{prefix}{node_id} {title}")
458
+ if 'nodes' in node and node['nodes']:
459
+ print_toc(node['nodes'], indent + 1)
src/toc_rag/tree_builder.py ADDED
@@ -0,0 +1,716 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ tree_builder.py
3
+ ---------------
4
+ Tree construction, async verification, fix loops, and recursive splitting.
5
+ Main async pipeline for building document structure trees from PDFs.
6
+
7
+ Replica of the reference page-indexing repository,
8
+ adapted for local Ollama LLM.
9
+ """
10
+
11
+ import json
12
+ import copy
13
+ import math
14
+ import random
15
+ import asyncio
16
+
17
+ from .toc_utils import (
18
+ ChatGPT_API, ChatGPT_API_async, ChatGPT_API_with_finish_reason,
19
+ extract_json, count_tokens, convert_physical_index_to_int,
20
+ post_processing, add_preface_if_needed,
21
+ write_node_id, add_node_text, remove_structure_text,
22
+ format_structure, structure_to_list,
23
+ create_clean_structure_for_description,
24
+ generate_summaries_for_structure, generate_doc_description,
25
+ get_page_tokens, get_pdf_name,
26
+ JsonLogger, ConfigLoader,
27
+ )
28
+ from .toc_detector import check_toc
29
+ from .toc_transformer import (
30
+ toc_transformer, toc_index_extractor, add_page_number_to_toc,
31
+ remove_page_number, extract_matching_page_pairs,
32
+ calculate_page_offset, add_page_offset_to_toc_json,
33
+ process_none_page_numbers, single_toc_item_index_fixer,
34
+ )
35
+
36
+
37
+ # ─── Title Verification (async) ───────────────────────────────────────────────
38
+
39
+ async def check_title_appearance(item, page_list, start_index=1, model=None):
40
+ """Check if a section title appears on its indicated page."""
41
+ title = item['title']
42
+ if 'physical_index' not in item or item['physical_index'] is None:
43
+ return {'list_index': item.get('list_index'), 'answer': 'no', 'title': title, 'page_number': None}
44
+
45
+ page_number = item['physical_index']
46
+ idx = page_number - start_index
47
+ if idx < 0 or idx >= len(page_list):
48
+ return {'list_index': item.get('list_index'), 'answer': 'no', 'title': title, 'page_number': page_number}
49
+
50
+ page_text = page_list[idx][0]
51
+
52
+ prompt = f"""
53
+ Your job is to check if the given section appears or starts in the given page_text.
54
+
55
+ Note: do fuzzy matching, ignore any space inconsistency in the page_text.
56
+
57
+ The given section title is {title}.
58
+ The given page_text is {page_text}.
59
+
60
+ Reply format:
61
+ {{
62
+ "thinking": <why do you think the section appears or starts in the page_text>
63
+ "answer": "yes or no" (yes if the section appears or starts in the page_text, no otherwise)
64
+ }}
65
+ Directly return the final JSON structure. Do not output anything else."""
66
+
67
+ response = await ChatGPT_API_async(model=model, prompt=prompt)
68
+ response = extract_json(response)
69
+ answer = response.get('answer', 'no')
70
+ return {'list_index': item.get('list_index'), 'answer': answer, 'title': title, 'page_number': page_number}
71
+
72
+
73
+ async def check_title_appearance_in_start(title, page_text, model=None, logger=None):
74
+ """Check if a section starts at the beginning of a page."""
75
+ prompt = f"""
76
+ You will be given the current section title and the current page_text.
77
+ Your job is to check if the current section starts in the beginning of the given page_text.
78
+ If there are other contents before the current section title, then the current section does not start in the beginning of the given page_text.
79
+ If the current section title is the first content in the given page_text, then the current section starts in the beginning of the given page_text.
80
+
81
+ Note: do fuzzy matching, ignore any space inconsistency in the page_text.
82
+
83
+ The given section title is {title}.
84
+ The given page_text is {page_text}.
85
+
86
+ reply format:
87
+ {{
88
+ "thinking": <why do you think the section appears or starts in the page_text>
89
+ "start_begin": "yes or no" (yes if the section starts in the beginning of the page_text, no otherwise)
90
+ }}
91
+ Directly return the final JSON structure. Do not output anything else."""
92
+
93
+ response = await ChatGPT_API_async(model=model, prompt=prompt)
94
+ response = extract_json(response)
95
+ if logger:
96
+ logger.info(f"Response: {response}")
97
+ return response.get("start_begin", "no")
98
+
99
+
100
+ async def check_title_appearance_in_start_concurrent(structure, page_list, model=None, logger=None):
101
+ """Check all titles for start-of-page appearance concurrently."""
102
+ if logger:
103
+ logger.info("Checking title appearance in start concurrently")
104
+
105
+ for item in structure:
106
+ if item.get('physical_index') is None:
107
+ item['appear_start'] = 'no'
108
+
109
+ tasks = []
110
+ valid_items = []
111
+ for item in structure:
112
+ if item.get('physical_index') is not None:
113
+ idx = item['physical_index'] - 1
114
+ if 0 <= idx < len(page_list):
115
+ page_text = page_list[idx][0]
116
+ tasks.append(check_title_appearance_in_start(item['title'], page_text, model=model, logger=logger))
117
+ valid_items.append(item)
118
+ else:
119
+ item['appear_start'] = 'no'
120
+
121
+ if tasks:
122
+ results = await asyncio.gather(*tasks, return_exceptions=True)
123
+ for item, result in zip(valid_items, results):
124
+ if isinstance(result, Exception):
125
+ if logger:
126
+ logger.error(f"Error checking start for {item['title']}: {result}")
127
+ item['appear_start'] = 'no'
128
+ else:
129
+ item['appear_start'] = result
130
+
131
+ return structure
132
+
133
+
134
+ # ─── Page Grouping ─────────────────────────────────────────────────────────────
135
+
136
+ def page_list_to_group_text(page_contents, token_lengths, max_tokens=20000, overlap_page=1):
137
+ """Group pages into text chunks respecting token limits."""
138
+ num_tokens = sum(token_lengths)
139
+
140
+ if num_tokens <= max_tokens:
141
+ page_text = "".join(page_contents)
142
+ return [page_text]
143
+
144
+ subsets = []
145
+ current_subset = []
146
+ current_token_count = 0
147
+
148
+ expected_parts_num = math.ceil(num_tokens / max_tokens)
149
+ average_tokens_per_part = math.ceil(((num_tokens / expected_parts_num) + max_tokens) / 2)
150
+
151
+ for i, (page_content, page_tokens) in enumerate(zip(page_contents, token_lengths)):
152
+ if current_token_count + page_tokens > average_tokens_per_part:
153
+ subsets.append(''.join(current_subset))
154
+ overlap_start = max(i - overlap_page, 0)
155
+ current_subset = list(page_contents[overlap_start:i])
156
+ current_token_count = sum(token_lengths[overlap_start:i])
157
+
158
+ current_subset.append(page_content)
159
+ current_token_count += page_tokens
160
+
161
+ if current_subset:
162
+ subsets.append(''.join(current_subset))
163
+
164
+ print('divide page_list to groups', len(subsets))
165
+ return subsets
166
+
167
+
168
+ # ─── TOC Generation (no-TOC documents) ────────────────────────────────────────
169
+
170
+ def generate_toc_init(part, model=None):
171
+ """Generate initial TOC structure from document text using LLM."""
172
+ print('start generate_toc_init')
173
+ prompt = """
174
+ You are an expert in extracting hierarchical tree structure, your task is to generate the tree structure of the document.
175
+
176
+ The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
177
+
178
+ For the title, you need to extract the original title from the text, only fix the space inconsistency.
179
+
180
+ The provided text contains tags like <physical_index_X> and <physical_index_X> to indicate the start and end of page X.
181
+
182
+ For the physical_index, you need to extract the physical index of the start of the section from the text. Keep the <physical_index_X> format.
183
+
184
+ The response should be in the following format.
185
+ [
186
+ {
187
+ "structure": <structure index, "x.x.x"> (string),
188
+ "title": <title of the section, keep the original title>,
189
+ "physical_index": "<physical_index_X> (keep the format)"
190
+ },
191
+ ]
192
+
193
+ Directly return the final JSON structure. Do not output anything else."""
194
+
195
+ prompt = prompt + '\nGiven text\n:' + part
196
+ response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
197
+
198
+ if finish_reason == 'finished':
199
+ return extract_json(response)
200
+ else:
201
+ raise Exception(f'finish reason: {finish_reason}')
202
+
203
+
204
+ def generate_toc_continue(toc_content, part, model=None):
205
+ """Continue TOC generation for additional document parts."""
206
+ print('start generate_toc_continue')
207
+ prompt = """
208
+ You are an expert in extracting hierarchical tree structure.
209
+ You are given a tree structure of the previous part and the text of the current part.
210
+ Your task is to continue the tree structure from the previous part to include the current part.
211
+
212
+ The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
213
+
214
+ For the title, you need to extract the original title from the text, only fix the space inconsistency.
215
+
216
+ The provided text contains tags like <physical_index_X> and <physical_index_X> to indicate the start and end of page X.
217
+
218
+ For the physical_index, you need to extract the physical index of the start of the section from the text. Keep the <physical_index_X> format.
219
+
220
+ The response should be in the following format.
221
+ [
222
+ {
223
+ "structure": <structure index, "x.x.x"> (string),
224
+ "title": <title of the section, keep the original title>,
225
+ "physical_index": "<physical_index_X> (keep the format)"
226
+ },
227
+ ...
228
+ ]
229
+
230
+ Directly return the additional part of the final JSON structure. Do not output anything else."""
231
+
232
+ prompt = prompt + '\nGiven text\n:' + part + '\nPrevious tree structure\n:' + json.dumps(toc_content, indent=2)
233
+ response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
234
+ if finish_reason == 'finished':
235
+ return extract_json(response)
236
+ else:
237
+ raise Exception(f'finish reason: {finish_reason}')
238
+
239
+
240
+ # ─── Processing Pipelines ─────────────────────────────────────────────────────
241
+
242
+ def process_no_toc(page_list, start_index=1, model=None, logger=None):
243
+ """Full pipeline for documents without a TOC."""
244
+ page_contents = []
245
+ token_lengths = []
246
+ for page_index in range(start_index, start_index + len(page_list)):
247
+ page_text = f"<physical_index_{page_index}>\n{page_list[page_index-start_index][0]}\n<physical_index_{page_index}>\n\n"
248
+ page_contents.append(page_text)
249
+ token_lengths.append(count_tokens(page_text, model))
250
+ group_texts = page_list_to_group_text(page_contents, token_lengths)
251
+ if logger:
252
+ logger.info(f'len(group_texts): {len(group_texts)}')
253
+
254
+ toc_with_page_number = generate_toc_init(group_texts[0], model)
255
+ for group_text in group_texts[1:]:
256
+ toc_with_page_number_additional = generate_toc_continue(toc_with_page_number, group_text, model)
257
+ toc_with_page_number.extend(toc_with_page_number_additional)
258
+ if logger:
259
+ logger.info(f'generate_toc: {toc_with_page_number}')
260
+
261
+ toc_with_page_number = convert_physical_index_to_int(toc_with_page_number)
262
+ if logger:
263
+ logger.info(f'convert_physical_index_to_int: {toc_with_page_number}')
264
+
265
+ return toc_with_page_number
266
+
267
+
268
+ def process_toc_no_page_numbers(toc_content, toc_page_list, page_list, start_index=1, model=None, logger=None):
269
+ """Pipeline for documents with TOC but no page numbers."""
270
+ page_contents = []
271
+ token_lengths = []
272
+ toc_content = toc_transformer(toc_content, model)
273
+ if logger:
274
+ logger.info(f'toc_transformer: {toc_content}')
275
+ for page_index in range(start_index, start_index + len(page_list)):
276
+ page_text = f"<physical_index_{page_index}>\n{page_list[page_index-start_index][0]}\n<physical_index_{page_index}>\n\n"
277
+ page_contents.append(page_text)
278
+ token_lengths.append(count_tokens(page_text, model))
279
+
280
+ group_texts = page_list_to_group_text(page_contents, token_lengths)
281
+ if logger:
282
+ logger.info(f'len(group_texts): {len(group_texts)}')
283
+
284
+ toc_with_page_number = copy.deepcopy(toc_content)
285
+ for group_text in group_texts:
286
+ toc_with_page_number = add_page_number_to_toc(group_text, toc_with_page_number, model)
287
+ if logger:
288
+ logger.info(f'add_page_number_to_toc: {toc_with_page_number}')
289
+
290
+ toc_with_page_number = convert_physical_index_to_int(toc_with_page_number)
291
+ if logger:
292
+ logger.info(f'convert_physical_index_to_int: {toc_with_page_number}')
293
+
294
+ return toc_with_page_number
295
+
296
+
297
+ def process_toc_with_page_numbers(toc_content, toc_page_list, page_list, toc_check_page_num=None, model=None, logger=None):
298
+ """Pipeline for documents with TOC and page numbers."""
299
+ toc_with_page_number = toc_transformer(toc_content, model)
300
+ if logger:
301
+ logger.info(f'toc_with_page_number: {toc_with_page_number}')
302
+
303
+ toc_no_page_number = remove_page_number(copy.deepcopy(toc_with_page_number))
304
+
305
+ start_page_index = toc_page_list[-1] + 1
306
+ main_content = ""
307
+ for page_index in range(start_page_index, min(start_page_index + (toc_check_page_num or 20), len(page_list))):
308
+ main_content += f"<physical_index_{page_index+1}>\n{page_list[page_index][0]}\n<physical_index_{page_index+1}>\n\n"
309
+
310
+ toc_with_physical_index = toc_index_extractor(toc_no_page_number, main_content, model)
311
+ if logger:
312
+ logger.info(f'toc_with_physical_index: {toc_with_physical_index}')
313
+
314
+ toc_with_physical_index = convert_physical_index_to_int(toc_with_physical_index)
315
+ if logger:
316
+ logger.info(f'toc_with_physical_index: {toc_with_physical_index}')
317
+
318
+ matching_pairs = extract_matching_page_pairs(toc_with_page_number, toc_with_physical_index, start_page_index)
319
+ if logger:
320
+ logger.info(f'matching_pairs: {matching_pairs}')
321
+
322
+ offset = calculate_page_offset(matching_pairs)
323
+ if logger:
324
+ logger.info(f'offset: {offset}')
325
+
326
+ toc_with_page_number = add_page_offset_to_toc_json(toc_with_page_number, offset)
327
+ if logger:
328
+ logger.info(f'toc_with_page_number: {toc_with_page_number}')
329
+
330
+ toc_with_page_number = process_none_page_numbers(toc_with_page_number, page_list, model=model)
331
+ if logger:
332
+ logger.info(f'toc_with_page_number: {toc_with_page_number}')
333
+
334
+ return toc_with_page_number
335
+
336
+
337
+ # ─── Validation ─────────────────────────────────────────────────────────��──────
338
+
339
+ def validate_and_truncate_physical_indices(toc_with_page_number, page_list_length, start_index=1, logger=None):
340
+ """Validate and truncate physical indices exceeding document length."""
341
+ if not toc_with_page_number:
342
+ return toc_with_page_number
343
+
344
+ max_allowed_page = page_list_length + start_index - 1
345
+ truncated_items = []
346
+
347
+ for i, item in enumerate(toc_with_page_number):
348
+ if item.get('physical_index') is not None:
349
+ original_index = item['physical_index']
350
+ if original_index > max_allowed_page:
351
+ item['physical_index'] = None
352
+ truncated_items.append({
353
+ 'title': item.get('title', 'Unknown'),
354
+ 'original_index': original_index
355
+ })
356
+ if logger:
357
+ logger.info(f"Removed physical_index for '{item.get('title', 'Unknown')}' (was {original_index})")
358
+
359
+ if truncated_items and logger:
360
+ logger.info(f"Total removed items: {len(truncated_items)}")
361
+
362
+ print(f"Document validation: {page_list_length} pages, max allowed index: {max_allowed_page}")
363
+ if truncated_items:
364
+ print(f"Truncated {len(truncated_items)} TOC items that exceeded document length")
365
+
366
+ return toc_with_page_number
367
+
368
+
369
+ # ─── Verification ─────────────────────────────────────────────────────────────
370
+
371
+ async def verify_toc(page_list, list_result, start_index=1, N=None, model=None):
372
+ """Verify TOC accuracy by checking title appearances on pages."""
373
+ print('start verify_toc')
374
+ last_physical_index = None
375
+ for item in reversed(list_result):
376
+ if item.get('physical_index') is not None:
377
+ last_physical_index = item['physical_index']
378
+ break
379
+
380
+ if last_physical_index is None or last_physical_index < len(page_list) / 2:
381
+ return 0, []
382
+
383
+ if N is None:
384
+ sample_indices = range(0, len(list_result))
385
+ else:
386
+ N = min(N, len(list_result))
387
+ sample_indices = random.sample(range(0, len(list_result)), N)
388
+
389
+ indexed_sample_list = []
390
+ for idx in sample_indices:
391
+ item = list_result[idx]
392
+ if item.get('physical_index') is not None:
393
+ item_with_index = item.copy()
394
+ item_with_index['list_index'] = idx
395
+ indexed_sample_list.append(item_with_index)
396
+
397
+ tasks = [
398
+ check_title_appearance(item, page_list, start_index, model)
399
+ for item in indexed_sample_list
400
+ ]
401
+ results = await asyncio.gather(*tasks)
402
+
403
+ correct_count = 0
404
+ incorrect_results = []
405
+ for result in results:
406
+ if result['answer'] == 'yes':
407
+ correct_count += 1
408
+ else:
409
+ incorrect_results.append(result)
410
+
411
+ checked_count = len(results)
412
+ accuracy = correct_count / checked_count if checked_count > 0 else 0
413
+ print(f"accuracy: {accuracy*100:.2f}%")
414
+ return accuracy, incorrect_results
415
+
416
+
417
+ # ─── Fix Loops ─────────────────────────────────────────────────────────────────
418
+
419
+ async def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, start_index=1, model=None, logger=None):
420
+ """Fix incorrect TOC entries by searching within nearby pages."""
421
+ print(f'start fix_incorrect_toc with {len(incorrect_results)} incorrect results')
422
+ incorrect_indices = {result['list_index'] for result in incorrect_results}
423
+ end_index = len(page_list) + start_index - 1
424
+
425
+ async def process_and_check_item(incorrect_item):
426
+ list_index = incorrect_item['list_index']
427
+
428
+ if list_index < 0 or list_index >= len(toc_with_page_number):
429
+ return {
430
+ 'list_index': list_index,
431
+ 'title': incorrect_item['title'],
432
+ 'physical_index': incorrect_item.get('physical_index'),
433
+ 'is_valid': False
434
+ }
435
+
436
+ prev_correct = None
437
+ for i in range(list_index - 1, -1, -1):
438
+ if i not in incorrect_indices and 0 <= i < len(toc_with_page_number):
439
+ pi = toc_with_page_number[i].get('physical_index')
440
+ if pi is not None:
441
+ prev_correct = pi
442
+ break
443
+ if prev_correct is None:
444
+ prev_correct = start_index - 1
445
+
446
+ next_correct = None
447
+ for i in range(list_index + 1, len(toc_with_page_number)):
448
+ if i not in incorrect_indices and 0 <= i < len(toc_with_page_number):
449
+ pi = toc_with_page_number[i].get('physical_index')
450
+ if pi is not None:
451
+ next_correct = pi
452
+ break
453
+ if next_correct is None:
454
+ next_correct = end_index
455
+
456
+ page_contents = []
457
+ for page_index in range(prev_correct, next_correct + 1):
458
+ li = page_index - start_index
459
+ if 0 <= li < len(page_list):
460
+ page_text = f"<physical_index_{page_index}>\n{page_list[li][0]}\n<physical_index_{page_index}>\n\n"
461
+ page_contents.append(page_text)
462
+ content_range = ''.join(page_contents)
463
+
464
+ physical_index_int = single_toc_item_index_fixer(incorrect_item['title'], content_range, model)
465
+
466
+ check_item = incorrect_item.copy()
467
+ check_item['physical_index'] = physical_index_int
468
+ check_result = await check_title_appearance(check_item, page_list, start_index, model)
469
+
470
+ return {
471
+ 'list_index': list_index,
472
+ 'title': incorrect_item['title'],
473
+ 'physical_index': physical_index_int,
474
+ 'is_valid': check_result['answer'] == 'yes'
475
+ }
476
+
477
+ tasks = [process_and_check_item(item) for item in incorrect_results]
478
+ results = await asyncio.gather(*tasks, return_exceptions=True)
479
+ results = [r for r in results if not isinstance(r, Exception)]
480
+
481
+ invalid_results = []
482
+ for result in results:
483
+ if result['is_valid']:
484
+ list_idx = result['list_index']
485
+ if 0 <= list_idx < len(toc_with_page_number):
486
+ toc_with_page_number[list_idx]['physical_index'] = result['physical_index']
487
+ else:
488
+ invalid_results.append(result)
489
+ else:
490
+ invalid_results.append(result)
491
+
492
+ if logger:
493
+ logger.info(f'invalid_results: {invalid_results}')
494
+
495
+ return toc_with_page_number, invalid_results
496
+
497
+
498
+ async def fix_incorrect_toc_with_retries(toc_with_page_number, page_list, incorrect_results, start_index=1, max_attempts=3, model=None, logger=None):
499
+ """Fix incorrect TOC with retry loop."""
500
+ print('start fix_incorrect_toc')
501
+ fix_attempt = 0
502
+ current_toc = toc_with_page_number
503
+ current_incorrect = incorrect_results
504
+
505
+ while current_incorrect:
506
+ print(f"Fixing {len(current_incorrect)} incorrect results")
507
+ current_toc, current_incorrect = await fix_incorrect_toc(
508
+ current_toc, page_list, current_incorrect, start_index, model, logger
509
+ )
510
+ fix_attempt += 1
511
+ if fix_attempt >= max_attempts:
512
+ if logger:
513
+ logger.info("Maximum fix attempts reached")
514
+ break
515
+
516
+ return current_toc, current_incorrect
517
+
518
+
519
+ # ─── Main Async Dispatcher ────────────────────────────────────────────────────
520
+
521
+ async def meta_processor(page_list, mode=None, toc_content=None, toc_page_list=None, start_index=1, opt=None, logger=None):
522
+ """Main async dispatcher with fallback chain."""
523
+ print(mode)
524
+ print(f'start_index: {start_index}')
525
+
526
+ if mode == 'process_toc_with_page_numbers':
527
+ toc_with_page_number = process_toc_with_page_numbers(
528
+ toc_content, toc_page_list, page_list,
529
+ toc_check_page_num=opt.toc_check_page_num, model=opt.model, logger=logger
530
+ )
531
+ elif mode == 'process_toc_no_page_numbers':
532
+ toc_with_page_number = process_toc_no_page_numbers(
533
+ toc_content, toc_page_list, page_list, model=opt.model, logger=logger
534
+ )
535
+ else:
536
+ toc_with_page_number = process_no_toc(page_list, start_index=start_index, model=opt.model, logger=logger)
537
+
538
+ toc_with_page_number = [item for item in toc_with_page_number if item.get('physical_index') is not None]
539
+
540
+ toc_with_page_number = validate_and_truncate_physical_indices(
541
+ toc_with_page_number, len(page_list), start_index=start_index, logger=logger
542
+ )
543
+
544
+ accuracy, incorrect_results = await verify_toc(
545
+ page_list, toc_with_page_number, start_index=start_index, model=opt.model
546
+ )
547
+
548
+ if logger:
549
+ logger.info({
550
+ 'mode': mode,
551
+ 'accuracy': accuracy,
552
+ 'incorrect_count': len(incorrect_results)
553
+ })
554
+
555
+ if accuracy == 1.0 and len(incorrect_results) == 0:
556
+ return toc_with_page_number
557
+ if accuracy > 0.6 and len(incorrect_results) > 0:
558
+ toc_with_page_number, _ = await fix_incorrect_toc_with_retries(
559
+ toc_with_page_number, page_list, incorrect_results,
560
+ start_index=start_index, max_attempts=3, model=opt.model, logger=logger
561
+ )
562
+ return toc_with_page_number
563
+ else:
564
+ if mode == 'process_toc_with_page_numbers':
565
+ return await meta_processor(
566
+ page_list, mode='process_toc_no_page_numbers',
567
+ toc_content=toc_content, toc_page_list=toc_page_list,
568
+ start_index=start_index, opt=opt, logger=logger
569
+ )
570
+ elif mode == 'process_toc_no_page_numbers':
571
+ return await meta_processor(
572
+ page_list, mode='process_no_toc',
573
+ start_index=start_index, opt=opt, logger=logger
574
+ )
575
+ else:
576
+ raise Exception('Processing failed')
577
+
578
+
579
+ # ─── Recursive Large Node Splitting ────────────────���──────────────────────────
580
+
581
+ async def process_large_node_recursively(node, page_list, opt=None, logger=None):
582
+ """Recursively split large nodes into sub-nodes."""
583
+ start_idx = node.get('start_index', 1)
584
+ end_idx = node.get('end_index', start_idx)
585
+ node_page_list = page_list[start_idx - 1:end_idx]
586
+ token_num = sum([page[1] for page in node_page_list])
587
+
588
+ if (end_idx - start_idx > opt.max_page_num_each_node and
589
+ token_num >= opt.max_token_num_each_node):
590
+ print(f'large node: {node["title"]} start: {start_idx} end: {end_idx} tokens: {token_num}')
591
+
592
+ node_toc_tree = await meta_processor(
593
+ node_page_list, mode='process_no_toc',
594
+ start_index=start_idx, opt=opt, logger=logger
595
+ )
596
+ node_toc_tree = await check_title_appearance_in_start_concurrent(
597
+ node_toc_tree, page_list, model=opt.model, logger=logger
598
+ )
599
+
600
+ valid_items = [item for item in node_toc_tree if item.get('physical_index') is not None]
601
+
602
+ if valid_items and node['title'].strip() == valid_items[0]['title'].strip():
603
+ node['nodes'] = post_processing(valid_items[1:], end_idx)
604
+ node['end_index'] = valid_items[1]['start_index'] if len(valid_items) > 1 else end_idx
605
+ else:
606
+ node['nodes'] = post_processing(valid_items, end_idx)
607
+ node['end_index'] = valid_items[0]['start_index'] if valid_items else end_idx
608
+
609
+ if 'nodes' in node and node['nodes']:
610
+ tasks = [
611
+ process_large_node_recursively(child, page_list, opt, logger=logger)
612
+ for child in node['nodes']
613
+ ]
614
+ await asyncio.gather(*tasks)
615
+
616
+ return node
617
+
618
+
619
+ # ─── Top-Level Entry Points ───────────────────────────────────────────────────
620
+
621
+ async def tree_parser(page_list, opt, doc=None, logger=None):
622
+ """Top-level async entry point for building the document tree."""
623
+ check_toc_result = check_toc(page_list, opt)
624
+ if logger:
625
+ logger.info(check_toc_result)
626
+
627
+ if (check_toc_result.get("toc_content") and
628
+ check_toc_result["toc_content"].strip() and
629
+ check_toc_result["page_index_given_in_toc"] == "yes"):
630
+ toc_with_page_number = await meta_processor(
631
+ page_list, mode='process_toc_with_page_numbers',
632
+ start_index=1,
633
+ toc_content=check_toc_result['toc_content'],
634
+ toc_page_list=check_toc_result['toc_page_list'],
635
+ opt=opt, logger=logger
636
+ )
637
+ else:
638
+ toc_with_page_number = await meta_processor(
639
+ page_list, mode='process_no_toc',
640
+ start_index=1, opt=opt, logger=logger
641
+ )
642
+
643
+ toc_with_page_number = add_preface_if_needed(toc_with_page_number)
644
+ toc_with_page_number = await check_title_appearance_in_start_concurrent(
645
+ toc_with_page_number, page_list, model=opt.model, logger=logger
646
+ )
647
+
648
+ valid_toc_items = [item for item in toc_with_page_number if item.get('physical_index') is not None]
649
+
650
+ toc_tree = post_processing(valid_toc_items, len(page_list))
651
+ tasks = [
652
+ process_large_node_recursively(node, page_list, opt, logger=logger)
653
+ for node in toc_tree
654
+ ]
655
+ await asyncio.gather(*tasks)
656
+
657
+ return toc_tree
658
+
659
+
660
+ def page_index_main(doc, opt=None):
661
+ """Main synchronous entry point for PDF page indexing."""
662
+ from io import BytesIO
663
+ logger = JsonLogger(doc)
664
+
665
+ is_valid_pdf = (
666
+ (isinstance(doc, str) and os.path.isfile(doc) and doc.lower().endswith(".pdf")) or
667
+ isinstance(doc, BytesIO)
668
+ )
669
+ if not is_valid_pdf:
670
+ raise ValueError("Unsupported input type. Expected a PDF file path or BytesIO object.")
671
+
672
+ print('Parsing PDF...')
673
+ page_list = get_page_tokens(doc)
674
+
675
+ logger.info({'total_page_number': len(page_list)})
676
+ logger.info({'total_token': sum([page[1] for page in page_list])})
677
+
678
+ async def page_index_builder():
679
+ structure = await tree_parser(page_list, opt, doc=doc, logger=logger)
680
+ if opt.if_add_node_id == 'yes':
681
+ write_node_id(structure)
682
+ if opt.if_add_node_text == 'yes':
683
+ add_node_text(structure, page_list)
684
+ if opt.if_add_node_summary == 'yes':
685
+ if opt.if_add_node_text == 'no':
686
+ add_node_text(structure, page_list)
687
+ await generate_summaries_for_structure(structure, model=opt.model)
688
+ if opt.if_add_node_text == 'no':
689
+ remove_structure_text(structure)
690
+ if opt.if_add_doc_description == 'yes':
691
+ clean_structure = create_clean_structure_for_description(structure)
692
+ doc_description = generate_doc_description(clean_structure, model=opt.model)
693
+ return {
694
+ 'doc_name': get_pdf_name(doc),
695
+ 'doc_description': doc_description,
696
+ 'structure': structure,
697
+ }
698
+ return {
699
+ 'doc_name': get_pdf_name(doc),
700
+ 'structure': structure,
701
+ }
702
+
703
+ return asyncio.run(page_index_builder())
704
+
705
+
706
+ def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None,
707
+ max_token_num_each_node=None, if_add_node_id=None, if_add_node_summary=None,
708
+ if_add_doc_description=None, if_add_node_text=None):
709
+ """Public API: index a document and return its tree structure."""
710
+ import os
711
+ user_opt = {
712
+ arg: value for arg, value in locals().items()
713
+ if arg != "doc" and value is not None
714
+ }
715
+ opt = ConfigLoader().load(user_opt)
716
+ return page_index_main(doc, opt)
static/index.html CHANGED
@@ -36,6 +36,8 @@
36
  </div>
37
  <div><span class="cost-label">Response:</span> <span id="stat-response" class="cost-value">0</span>
38
  </div>
 
 
39
  <div class="cost-total-row">
40
  <span class="cost-label" style="font-weight: 600;">Total Tokens:</span>
41
  <span id="stat-total" class="cost-value cost-total">0</span>
@@ -62,7 +64,7 @@
62
  <main class="main-content">
63
  <header class="top-nav">
64
  <h1>RAG Comparator</h1>
65
- <p>Vector RAG vs Page Index RAG</p>
66
  </header>
67
 
68
  <!-- Loading Overlay -->
@@ -113,12 +115,16 @@
113
  <div class="column-header">Vector RAG (Sliding Window)</div>
114
  <div id="vector-chunks-list" class="chunks-list"></div>
115
  </div>
 
 
 
 
116
  </div>
117
  </div>
118
  </div>
119
  </div>
120
 
121
- <script src="/static/script.js?v=8"></script>
122
  </body>
123
 
124
  </html>
 
36
  </div>
37
  <div><span class="cost-label">Response:</span> <span id="stat-response" class="cost-value">0</span>
38
  </div>
39
+ <div><span class="cost-label">TOC Analysis:</span> <span id="stat-toc" class="cost-value">0</span>
40
+ </div>
41
  <div class="cost-total-row">
42
  <span class="cost-label" style="font-weight: 600;">Total Tokens:</span>
43
  <span id="stat-total" class="cost-value cost-total">0</span>
 
64
  <main class="main-content">
65
  <header class="top-nav">
66
  <h1>RAG Comparator</h1>
67
+ <p>Vector RAG vs Page Index RAG vs TOC Index RAG</p>
68
  </header>
69
 
70
  <!-- Loading Overlay -->
 
115
  <div class="column-header">Vector RAG (Sliding Window)</div>
116
  <div id="vector-chunks-list" class="chunks-list"></div>
117
  </div>
118
+ <div class="chunk-column">
119
+ <div class="column-header">TOC Index RAG (Tree Nodes)</div>
120
+ <div id="toc-chunks-list" class="chunks-list"></div>
121
+ </div>
122
  </div>
123
  </div>
124
  </div>
125
  </div>
126
 
127
+ <script src="/static/script.js?v=9"></script>
128
  </body>
129
 
130
  </html>
static/script.js CHANGED
@@ -209,6 +209,7 @@ document.addEventListener('DOMContentLoaded', () => {
209
  document.getElementById('stat-embed').textContent = data.tokens.embedding_tokens.toLocaleString();
210
  document.getElementById('stat-prompt').textContent = data.tokens.prompt_tokens.toLocaleString();
211
  document.getElementById('stat-response').textContent = data.tokens.completion_tokens.toLocaleString();
 
212
  document.getElementById('stat-total').textContent = data.total_tokens.toLocaleString();
213
  } catch (error) {
214
  console.error('Error fetching stats:', error);
@@ -269,9 +270,11 @@ document.addEventListener('DOMContentLoaded', () => {
269
  // Create Cards
270
  const vectorCard = createResponseCard('Vector RAG');
271
  const pageCard = createResponseCard('Page Index RAG');
 
272
 
273
  responsesGrid.appendChild(vectorCard);
274
  responsesGrid.appendChild(pageCard);
 
275
  sysMsg.appendChild(responsesGrid);
276
  chatContainer.appendChild(sysMsg);
277
 
@@ -279,9 +282,10 @@ document.addEventListener('DOMContentLoaded', () => {
279
 
280
  const topK = parseInt(topKInput.value) || 5;
281
 
282
- // Trigger both SSE calls
283
  fetchSSE('/api/chat/vector', { query, top_k: topK }, vectorCard);
284
  fetchSSE('/api/chat/page', { query, top_k: topK }, pageCard);
 
285
  }
286
 
287
  async function fetchSSE(url, payload, cardElement) {
@@ -334,7 +338,7 @@ document.addEventListener('DOMContentLoaded', () => {
334
  // Render sources
335
  let sourcesHtml = '';
336
  data.sources.forEach((s, idx) => {
337
- let label = s.chunk_index !== undefined ? `Chunk #${s.chunk_index}` : `Page ${s.page_num}`;
338
  sourcesHtml += `
339
  <div class="source-item">
340
  <span class="source-meta">[${idx + 1}] ${s.source} — ${label} (score: ${s.score})</span>
@@ -382,6 +386,7 @@ document.addEventListener('DOMContentLoaded', () => {
382
  const modalDocTitle = document.getElementById('modal-doc-title');
383
  const pageChunksList = document.getElementById('page-chunks-list');
384
  const vectorChunksList = document.getElementById('vector-chunks-list');
 
385
 
386
  closeModalBtn.addEventListener('click', () => {
387
  chunksModal.style.display = 'none';
@@ -398,6 +403,7 @@ document.addEventListener('DOMContentLoaded', () => {
398
  modalDocTitle.textContent = filename;
399
  pageChunksList.innerHTML = '<div class="spinner large-spinner"></div>';
400
  vectorChunksList.innerHTML = '<div class="spinner large-spinner"></div>';
 
401
  chunksModal.style.display = 'flex';
402
 
403
  try {
@@ -470,10 +476,44 @@ document.addEventListener('DOMContentLoaded', () => {
470
  vectorChunksList.innerHTML = '<div style="padding:16px;">No vector chunks found.</div>';
471
  }
472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  } catch (error) {
474
  console.error(error);
475
  pageChunksList.innerHTML = '<div style="color:red; padding:16px;">Error loading data.</div>';
476
  vectorChunksList.innerHTML = '<div style="color:red; padding:16px;">Error loading data.</div>';
 
477
  }
478
  }
479
  });
 
209
  document.getElementById('stat-embed').textContent = data.tokens.embedding_tokens.toLocaleString();
210
  document.getElementById('stat-prompt').textContent = data.tokens.prompt_tokens.toLocaleString();
211
  document.getElementById('stat-response').textContent = data.tokens.completion_tokens.toLocaleString();
212
+ document.getElementById('stat-toc').textContent = (data.tokens.toc_analysis_tokens || 0).toLocaleString();
213
  document.getElementById('stat-total').textContent = data.total_tokens.toLocaleString();
214
  } catch (error) {
215
  console.error('Error fetching stats:', error);
 
270
  // Create Cards
271
  const vectorCard = createResponseCard('Vector RAG');
272
  const pageCard = createResponseCard('Page Index RAG');
273
+ const tocCard = createResponseCard('TOC Index RAG');
274
 
275
  responsesGrid.appendChild(vectorCard);
276
  responsesGrid.appendChild(pageCard);
277
+ responsesGrid.appendChild(tocCard);
278
  sysMsg.appendChild(responsesGrid);
279
  chatContainer.appendChild(sysMsg);
280
 
 
282
 
283
  const topK = parseInt(topKInput.value) || 5;
284
 
285
+ // Trigger all three SSE calls
286
  fetchSSE('/api/chat/vector', { query, top_k: topK }, vectorCard);
287
  fetchSSE('/api/chat/page', { query, top_k: topK }, pageCard);
288
+ fetchSSE('/api/chat/toc', { query, top_k: topK }, tocCard);
289
  }
290
 
291
  async function fetchSSE(url, payload, cardElement) {
 
338
  // Render sources
339
  let sourcesHtml = '';
340
  data.sources.forEach((s, idx) => {
341
+ let label = s.chunk_index !== undefined ? `Chunk #${s.chunk_index}` : s.node_id !== undefined ? `Node ${s.node_id}: ${s.title || ''}` : `Page ${s.page_num}`;
342
  sourcesHtml += `
343
  <div class="source-item">
344
  <span class="source-meta">[${idx + 1}] ${s.source} — ${label} (score: ${s.score})</span>
 
386
  const modalDocTitle = document.getElementById('modal-doc-title');
387
  const pageChunksList = document.getElementById('page-chunks-list');
388
  const vectorChunksList = document.getElementById('vector-chunks-list');
389
+ const tocChunksList = document.getElementById('toc-chunks-list');
390
 
391
  closeModalBtn.addEventListener('click', () => {
392
  chunksModal.style.display = 'none';
 
403
  modalDocTitle.textContent = filename;
404
  pageChunksList.innerHTML = '<div class="spinner large-spinner"></div>';
405
  vectorChunksList.innerHTML = '<div class="spinner large-spinner"></div>';
406
+ tocChunksList.innerHTML = '<div class="spinner large-spinner"></div>';
407
  chunksModal.style.display = 'flex';
408
 
409
  try {
 
476
  vectorChunksList.innerHTML = '<div style="padding:16px;">No vector chunks found.</div>';
477
  }
478
 
479
+ // Render TOC Chunks
480
+ tocChunksList.innerHTML = '';
481
+ let totalTocTokens = 0;
482
+ if (data.toc_chunks && data.toc_chunks.length > 0) {
483
+ data.toc_chunks.forEach(chunk => {
484
+ const el = document.createElement('div');
485
+ el.className = 'chunk-item';
486
+ el.innerHTML = `
487
+ <span class="chunk-meta" style="display:flex; justify-content:space-between;">
488
+ <span>Node ${chunk.node_id}: ${chunk.title || ''}</span>
489
+ <span style="color:var(--text-secondary); font-size:0.85em;">Tokens: ${chunk.tokens || 0}</span>
490
+ </span>
491
+ <div class="chunk-text">${chunk.text}</div>
492
+ `;
493
+ tocChunksList.appendChild(el);
494
+ totalTocTokens += (chunk.tokens || 0);
495
+ });
496
+
497
+ const totalEl = document.createElement('div');
498
+ totalEl.style.marginTop = '16px';
499
+ totalEl.style.padding = '12px';
500
+ totalEl.style.backgroundColor = 'var(--bg-secondary)';
501
+ totalEl.style.borderRadius = '6px';
502
+ totalEl.style.fontWeight = 'bold';
503
+ totalEl.style.display = 'flex';
504
+ totalEl.style.justifyContent = 'space-between';
505
+ totalEl.style.border = '1px solid var(--border-color)';
506
+ totalEl.innerHTML = `<span>Total TOC RAG Tokens:</span> <span>${totalTocTokens.toLocaleString()}</span>`;
507
+ tocChunksList.appendChild(totalEl);
508
+ } else {
509
+ tocChunksList.innerHTML = '<div style="padding:16px;">No TOC nodes found.</div>';
510
+ }
511
+
512
  } catch (error) {
513
  console.error(error);
514
  pageChunksList.innerHTML = '<div style="color:red; padding:16px;">Error loading data.</div>';
515
  vectorChunksList.innerHTML = '<div style="color:red; padding:16px;">Error loading data.</div>';
516
+ tocChunksList.innerHTML = '<div style="color:red; padding:16px;">Error loading data.</div>';
517
  }
518
  }
519
  });
static/style.css CHANGED
@@ -188,7 +188,7 @@ body {
188
 
189
  .system-responses {
190
  display: grid;
191
- grid-template-columns: 1fr 1fr;
192
  gap: 24px;
193
  width: 100%;
194
  }
 
188
 
189
  .system-responses {
190
  display: grid;
191
+ grid-template-columns: 1fr 1fr 1fr;
192
  gap: 24px;
193
  width: 100%;
194
  }