File size: 4,406 Bytes
31c93b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
 
 
 
 
31c93b1
 
 
4aaae80
31c93b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aaae80
 
 
 
 
31c93b1
 
4aaae80
31c93b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Chunk:
    """A piece of document text paired with its descriptive metadata.

    The dataclass is declared frozen, so the ``text`` and ``metadata``
    attributes cannot be rebound after construction. The ``metadata`` value
    is a plain ``dict`` and is not itself protected from mutation.
    """

    # The chunk's text content.
    text: str
    # Descriptive keys for this chunk; the intended key set noted by the
    # author is {doc_cid, doc_title, page, chunk_index, language}.
    metadata: dict


def _overlap_tail(chunk: str, overlap: int) -> str:
    """Return the trailing ``overlap`` approximate tokens of ``chunk``.

    The tail is cut ``overlap * 4`` characters from the end (or is the whole
    chunk when it is shorter than that) and then advanced past its first
    space so that it does not begin mid-word. If the tail contains no space
    it is returned as cut. A non-positive ``overlap`` yields ``""``.
    """
    if overlap <= 0:
        # ``chunk[-0:]`` is the *whole* string, not an empty tail, so zero
        # (or negative) overlap must be handled before slicing.
        return ""
    overlap_chars = overlap * 4
    tail = chunk[-overlap_chars:] if overlap_chars < len(chunk) else chunk
    space_idx = tail.find(" ")
    if space_idx != -1:
        tail = tail[space_idx + 1 :]
    return tail


def chunk_text(
    text: str,
    *,
    chunk_size: int = 512,
    overlap: int = 64,
    metadata: dict | None = None,
) -> list[Chunk]:
    """Split text using a sliding window measured in approximate tokens.

    Token counts are estimated as ``len(s) // 4``. Paragraphs (separated by
    a double newline) are packed into chunks; a paragraph that is itself
    larger than ``chunk_size`` is split on single spaces instead. After each
    emitted chunk, roughly ``overlap`` tokens from its end are carried into
    the start of the next chunk.

    Args:
        text: The text to split.
        chunk_size: Target chunk size in approximate tokens. A single word
            longer than this is never broken, so a chunk may exceed it.
        overlap: Approximate tokens repeated between consecutive chunks.
            Zero or a negative value disables the overlap. Intended to be
            smaller than ``chunk_size``.
        metadata: Mapping attached to every chunk. Each chunk receives its
            own shallow copy.

    Returns:
        A list with at least one ``Chunk``. Text whose estimated size does
        not exceed ``chunk_size`` is returned unchanged as a single chunk.
    """
    meta = metadata or {}

    def make_chunk(body: str) -> Chunk:
        # Shallow-copy so chunks do not share one mutable dict with each
        # other or with the caller.
        return Chunk(text=body, metadata=dict(meta))

    if len(text) // 4 <= chunk_size:
        return [make_chunk(text)]

    chunks: list[Chunk] = []
    pending: list[str] = []  # paragraphs buffered for the next chunk
    pending_tokens = 0

    for para in text.split("\n\n"):
        para_tokens = len(para) // 4

        # Adding this paragraph would overflow: emit what is buffered and
        # restart the buffer from the overlap tail.
        if pending and pending_tokens + para_tokens > chunk_size:
            body = "\n\n".join(pending).strip()
            if body:
                chunks.append(make_chunk(body))
            tail = _overlap_tail(body, overlap)
            pending = [tail] if tail else []
            pending_tokens = len(tail) // 4

        if para_tokens <= chunk_size:
            pending.append(para)
            pending_tokens += para_tokens
            continue

        # The paragraph alone is too large: split it at word boundaries.
        # Anything still buffered here is the overlap tail of the chunk just
        # emitted; prepend it so it leads the paragraph's first chunk rather
        # than resurfacing, out of order, next to the paragraph's remainder.
        carried = "\n\n".join(pending).strip()
        source = f"{carried}\n\n{para}" if carried else para
        pending = []
        pending_tokens = 0

        word_buf: list[str] = []
        word_tokens = 0
        for word in source.split(" "):
            # +1 accounts for the joining space; every word costs >= 1.
            wt = (len(word) + 1) // 4 or 1
            if word_buf and word_tokens + wt > chunk_size:
                body = " ".join(word_buf).strip()
                if body:
                    chunks.append(make_chunk(body))
                tail = _overlap_tail(body, overlap)
                word_buf = tail.split(" ") if tail else []
                word_tokens = len(tail) // 4
            word_buf.append(word)
            word_tokens += wt

        # Leftover words stay buffered so following paragraphs can join them.
        remaining = " ".join(word_buf).strip()
        if remaining:
            pending.append(remaining)
            pending_tokens += len(remaining) // 4

    # Flush remainder
    if pending:
        body = "\n\n".join(pending).strip()
        if body:
            chunks.append(make_chunk(body))

    # Whitespace-only input can produce no chunks; fall back to the raw text.
    return chunks if chunks else [make_chunk(text)]


def chunk_pdf(pdf_bytes: bytes, *, doc_metadata: dict) -> list[Chunk]:
    """Extract text per page using pypdf, then chunk_text per page.

    Falls back to treating the bytes as UTF-8 plain text (undecodable bytes
    replaced) if pypdf is not installed.

    Args:
        pdf_bytes: Raw content of the PDF file.
        doc_metadata: Document-level metadata copied into every chunk. On
            the pypdf path each chunk additionally gets ``"page"`` (the
            zero-based page index) and, unless ``doc_metadata`` already
            provides one, ``"language": "unknown"``.

    Returns:
        The chunks of all pages in page order. Pages whose extracted text is
        empty or whitespace-only contribute no chunks.
    """
    import io

    # Guard only the import: an ImportError raised later, while parsing,
    # must not be mistaken for "pypdf is not installed" and silently turn
    # into the plain-text fallback.
    try:
        import pypdf  # type: ignore[import-untyped]
    except ImportError:
        # Fallback: treat bytes as UTF-8 text
        text = pdf_bytes.decode("utf-8", errors="replace")
        return chunk_text(text, metadata=doc_metadata)

    reader = pypdf.PdfReader(io.BytesIO(pdf_bytes))
    all_chunks: list[Chunk] = []
    for page_num, page in enumerate(reader.pages):
        page_text = page.extract_text() or ""
        if not page_text.strip():
            continue
        meta = {**doc_metadata, "page": page_num}
        # Default only: keep a language the caller supplied instead of
        # overwriting it, matching what the fallback path passes through.
        meta.setdefault("language", "unknown")
        all_chunks.extend(chunk_text(page_text, metadata=meta))
    return all_chunks