mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
119 lines
4.7 KiB
Python
119 lines
4.7 KiB
Python
import re
|
|
from typing import List, Tuple, Union
|
|
import logging
|
|
from application.parser.schema.base import Document
|
|
from application.utils import get_encoding
|
|
|
|
# Module-level logger, keyed by this module's import name.
logger = logging.getLogger(__name__)
|
|
|
|
class Chunker:
    """Split or merge ``Document`` objects so each chunk fits a token budget.

    Token counts are measured with the encoding returned by ``get_encoding()``.
    Only the ``"classic_chunk"`` strategy is currently supported.
    """

    # A document's "header" is its first three lines, each ending in "\n".
    _HEADER_PATTERN = re.compile(r"^(.*?\n){3}")

    def __init__(
        self,
        chunking_strategy: str = "classic_chunk",
        max_tokens: int = 2000,
        min_tokens: int = 150,
        duplicate_headers: bool = False,
    ):
        """Configure the chunker.

        Args:
            chunking_strategy: Strategy name; only ``"classic_chunk"`` is accepted.
            max_tokens: Upper bound on tokens per emitted chunk; must be >= 1.
            min_tokens: Documents with fewer tokens are merged with the next
                document when the merged text still fits in ``max_tokens``.
            duplicate_headers: When True, the header (first three lines) of an
                oversized document is repeated at the start of every split part;
                otherwise it appears only in the first part.

        Raises:
            ValueError: If the strategy is unsupported or ``max_tokens`` < 1.
        """
        if chunking_strategy not in ["classic_chunk"]:
            raise ValueError(f"Unsupported chunking strategy: {chunking_strategy}")
        if max_tokens < 1:
            # No non-empty chunk can ever satisfy a non-positive budget.
            raise ValueError(f"max_tokens must be at least 1, got {max_tokens}")
        self.chunking_strategy = chunking_strategy
        self.max_tokens = max_tokens
        self.min_tokens = min_tokens
        self.duplicate_headers = duplicate_headers
        self.encoding = get_encoding()

    def separate_header_and_body(self, text: str) -> Tuple[str, str]:
        """Split *text* into its first three lines and the remainder.

        Returns:
            ``(header, body)`` where ``header`` is the first three lines including
            their trailing newlines and ``body`` is everything after them. If
            *text* contains fewer than three newlines, returns ``("", text)``.
        """
        match = self._HEADER_PATTERN.match(text)
        if match is None:
            return "", text  # No header, treat entire text as body
        header = match.group(0)
        return header, text[len(header):]

    def combine_documents(self, doc: "Document", next_doc: "Document") -> "Document":
        """Merge two documents into a new one, joining their text with a space.

        The result keeps ``doc``'s ``doc_id`` and ``embedding`` and a copy of
        ``doc``'s ``extra_info`` with ``token_count`` set to the token count of
        the merged text. ``next_doc``'s id, embedding and metadata are dropped.
        """
        combined_text = doc.text + " " + next_doc.text
        combined_token_count = len(self.encoding.encode(combined_text))
        return Document(
            text=combined_text,
            doc_id=doc.doc_id,
            embedding=doc.embedding,
            extra_info={**(doc.extra_info or {}), "token_count": combined_token_count},
        )

    def split_document(self, doc: "Document") -> List["Document"]:
        """Split an oversized document into parts of at most ``max_tokens`` tokens.

        Parts get ids ``"{doc_id}-{index}"`` and inherit ``doc``'s embedding plus
        a copy of its ``extra_info`` carrying the part's ``token_count``. The
        header (first three lines) is prefixed to the first part, and to every
        part when ``duplicate_headers`` is True.
        """
        header, body = self.separate_header_and_body(doc.text)
        header_tokens = self.encoding.encode(header) if header else []
        body_tokens = self.encoding.encode(body)

        # If the header alone leaves no room for body text, or there is no body,
        # treat the whole text as body: nothing is dropped and every iteration
        # below is guaranteed to make progress.
        if len(header_tokens) >= self.max_tokens or not body_tokens:
            body_tokens = header_tokens + body_tokens
            header_tokens = []

        split_docs = []
        position = 0
        part_index = 0
        while position < len(body_tokens):
            # Decide per part whether the header is repeated; the body budget
            # shrinks by the header length only for parts that carry it.
            prefix = header_tokens if (self.duplicate_headers or part_index == 0) else []
            end = position + max(1, self.max_tokens - len(prefix))
            chunk_tokens = prefix + body_tokens[position:end]
            split_docs.append(
                Document(
                    text=self.encoding.decode(chunk_tokens),
                    doc_id=f"{doc.doc_id}-{part_index}",
                    embedding=doc.embedding,
                    extra_info={**(doc.extra_info or {}), "token_count": len(chunk_tokens)},
                )
            )
            position = end
            part_index += 1
        return split_docs

    def classic_chunk(self, documents: List["Document"]) -> List["Document"]:
        """Normalize *documents* to the configured token budget, in order.

        - More than ``max_tokens`` tokens: replaced by its ``split_document`` parts.
        - Fewer than ``min_tokens`` tokens: merged with the immediately following
          document when the merged text fits in ``max_tokens`` (the following
          document is then consumed); otherwise kept as is.
        - Otherwise: kept as is.

        Kept documents are the same objects that were passed in. Their
        ``extra_info`` is rebound to a fresh dict that includes ``token_count``,
        so the caller's original dict is not mutated.
        """
        processed_docs = []
        i = 0
        while i < len(documents):
            doc = documents[i]
            token_count = len(self.encoding.encode(doc.text))

            if token_count > self.max_tokens:
                # Split large documents
                processed_docs.extend(self.split_document(doc))
                i += 1
                continue

            if token_count < self.min_tokens and i + 1 < len(documents):
                # Combine small documents, judged by the real token count of the
                # merged text (joining space and re-tokenization included).
                combined_doc = self.combine_documents(doc, documents[i + 1])
                if combined_doc.extra_info["token_count"] <= self.max_tokens:
                    processed_docs.append(combined_doc)
                    i += 2
                    continue

            # In range, or too small with nothing suitable to merge into.
            # Copy rather than mutate: the input dict may be shared between
            # documents (an assumption about callers; harmless either way).
            doc.extra_info = {**(doc.extra_info or {}), "token_count": token_count}
            processed_docs.append(doc)
            i += 1
        return processed_docs

    def chunk(self, documents: List["Document"]) -> List["Document"]:
        """Apply the configured chunking strategy to *documents*.

        Raises:
            ValueError: If ``chunking_strategy`` is not ``"classic_chunk"``.
        """
        if self.chunking_strategy == "classic_chunk":
            return self.classic_chunk(documents)
        raise ValueError("Unsupported chunking strategy")
|