diff --git a/application/parser/token_func.py b/application/parser/token_func.py index e77376f5..80a55674 100644 --- a/application/parser/token_func.py +++ b/application/parser/token_func.py @@ -46,6 +46,10 @@ def split_documents(documents: List[Document], max_tokens: int) -> List[Document docs.append(doc) else: header, body = separate_header_and_body(doc.text) + if len(tiktoken.get_encoding("cl100k_base").encode(header)) > max_tokens: + print("header too long, skipping", file=sys.stderr) + body = doc.text + header = "" num_body_parts = ceil(token_length / max_tokens) part_length = ceil(len(body) / num_body_parts) body_parts = [body[i:i + part_length] for i in range(0, len(body), part_length)] diff --git a/scripts/parser/token_func.py b/scripts/parser/token_func.py index e77376f5..bb386f2e 100644 --- a/scripts/parser/token_func.py +++ b/scripts/parser/token_func.py @@ -4,7 +4,7 @@ from typing import List import tiktoken from parser.schema.base import Document - +import sys def separate_header_and_body(text): header_pattern = r"^(.*?\n){3}" @@ -17,6 +17,7 @@ def separate_header_and_body(text): def group_documents(documents: List[Document], min_tokens: int, max_tokens: int) -> List[Document]: docs = [] current_group = None + print("Grouping", len(documents), "documents", file=sys.stderr) for doc in documents: doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text)) @@ -46,6 +47,9 @@ def split_documents(documents: List[Document], max_tokens: int) -> List[Document docs.append(doc) else: header, body = separate_header_and_body(doc.text) + if len(tiktoken.get_encoding("cl100k_base").encode(header)) > max_tokens: + body = doc.text + header = "" num_body_parts = ceil(token_length / max_tokens) part_length = ceil(len(body) / num_body_parts) body_parts = [body[i:i + part_length] for i in range(0, len(body), part_length)] @@ -55,6 +59,7 @@ def split_documents(documents: List[Document], max_tokens: int) -> List[Document embedding=doc.embedding, extra_info=doc.extra_info) docs.append(new_doc) + print("split into", len(docs), "documents", file=sys.stderr) return docs