Merge pull request #169 from arc53/min-max-tokens

Min max tokens
This commit is contained in:
Alex
2023-03-14 14:02:06 +00:00
committed by GitHub
3 changed files with 96 additions and 27 deletions

View File

@@ -3,14 +3,10 @@ import sys
import nltk
import dotenv
import typer
import ast
from collections import defaultdict
from pathlib import Path
from typing import List, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from parser.file.bulk import SimpleDirectoryReader
from parser.schema.base import Document
from parser.open_ai_func import call_openai_api, get_user_permission
@@ -18,6 +14,7 @@ from parser.py2doc import transform_to_docs
from parser.py2doc import extract_functions_and_classes as extract_py
from parser.js2doc import extract_functions_and_classes as extract_js
from parser.java2doc import extract_functions_and_classes as extract_java
from parser.token_func import group_split
dotenv.load_dotenv()
@@ -38,14 +35,17 @@ def ingest(yes: bool = typer.Option(False, "-y", "--yes", prompt=False,
file: Optional[List[str]] = typer.Option(None,
help="""File paths to use (Optional; overrides dir).
E.g. --file inputs/1.md --file inputs/2.md"""),
recursive: Optional[bool] = typer.Option(True,
help="Whether to recursively search in subdirectories."),
limit: Optional[int] = typer.Option(None,
help="Maximum number of files to read."),
recursive: Optional[bool] = typer.Option(True, help="Whether to recursively search in subdirectories."),
limit: Optional[int] = typer.Option(None, help="Maximum number of files to read."),
formats: Optional[List[str]] = typer.Option([".rst", ".md"],
help="""List of required extensions (list with .)
Currently supported: .rst, .md, .pdf, .docx, .csv, .epub, .html, .mdx"""),
exclude: Optional[bool] = typer.Option(True, help="Whether to exclude hidden files (dotfiles).")):
exclude: Optional[bool] = typer.Option(True, help="Whether to exclude hidden files (dotfiles)."),
sample: Optional[bool] = typer.Option(False, help="Whether to output sample of the first 5 split documents."),
token_check: Optional[bool] = typer.Option(True, help="Whether to group small documents and split large."),
min_tokens: Optional[int] = typer.Option(150, help="Minimum number of tokens to not group."),
max_tokens: Optional[int] = typer.Option(2000, help="Maximum number of tokens to not split."),
):
"""
Creates index from specified location or files.
@@ -56,11 +56,21 @@ def ingest(yes: bool = typer.Option(False, "-y", "--yes", prompt=False,
raw_docs = SimpleDirectoryReader(input_dir=directory, input_files=file, recursive=recursive,
required_exts=formats, num_files_limit=limit,
exclude_hidden=exclude).load_data()
raw_docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
# Here we split the documents, as needed, into smaller chunks.
# We do this due to the context limits of the LLMs.
text_splitter = RecursiveCharacterTextSplitter()
docs = text_splitter.split_documents(raw_docs)
raw_docs = group_split(documents=raw_docs, min_tokens=min_tokens, max_tokens=max_tokens, token_check=token_check)
#Old method
# text_splitter = RecursiveCharacterTextSplitter()
# docs = text_splitter.split_documents(raw_docs)
#Sample feature
if sample == True:
for i in range(min(5, len(raw_docs))):
print(raw_docs[i].text)
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
# Here we check for command line arguments for bot calls.
# If no argument exists or the yes is not True, then the
@@ -109,3 +119,5 @@ def convert(dir: Optional[str] = typer.Option("inputs",
transform_to_docs(functions_dict, classes_dict, formats, dir)
if __name__ == "__main__":
app()

View File

@@ -29,7 +29,6 @@ class RstParser(BaseParser):
remove_whitespaces_excess: bool = True,
#Be carefull with remove_characters_excess, might cause data loss
remove_characters_excess: bool = True,
max_tokens: int = 2048,
**kwargs: Any,
) -> None:
"""Init params."""
@@ -41,18 +40,6 @@ class RstParser(BaseParser):
self._remove_directives = remove_directives
self._remove_whitespaces_excess = remove_whitespaces_excess
self._remove_characters_excess = remove_characters_excess
self._max_tokens = max_tokens
def tups_chunk_append(self, tups: List[Tuple[Optional[str], str]], current_header: Optional[str], current_text: str):
"""Append to tups chunk."""
num_tokens = len(tiktoken.get_encoding("cl100k_base").encode(current_text))
if num_tokens > self._max_tokens:
chunks = [current_text[i:i + self._max_tokens] for i in range(0, len(current_text), self._max_tokens)]
for chunk in chunks:
tups.append((current_header, chunk))
else:
tups.append((current_header, current_text))
return tups
def rst_to_tups(self, rst_text: str) -> List[Tuple[Optional[str], str]]:
@@ -76,14 +63,14 @@ class RstParser(BaseParser):
# removes the next heading from current Document
if current_text.endswith(lines[i - 1] + "\n"):
current_text = current_text[:len(current_text) - len(lines[i - 1] + "\n")]
rst_tups = self.tups_chunk_append(rst_tups, current_header, current_text)
rst_tups.append((current_header, current_text))
current_header = lines[i - 1]
current_text = ""
else:
current_text += line + "\n"
rst_tups = self.tups_chunk_append(rst_tups, current_header, current_text)
rst_tups.append((current_header, current_text))
#TODO: Format for rst
#

View File

@@ -0,0 +1,70 @@
import re
import tiktoken
from typing import List
from parser.schema.base import Document
from math import ceil
def separate_header_and_body(text):
header_pattern = r"^(.*?\n){3}"
match = re.match(header_pattern, text)
header = match.group(0)
body = text[len(header):]
return header, body
def group_documents(documents: List[Document], min_tokens: int, max_tokens: int) -> List[Document]:
docs = []
current_group = None
for doc in documents:
doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
if current_group is None:
current_group = Document(text=doc.text, doc_id=doc.doc_id, embedding=doc.embedding,
extra_info=doc.extra_info)
elif len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and doc_len >= min_tokens:
current_group.text += " " + doc.text
else:
docs.append(current_group)
current_group = Document(text=doc.text, doc_id=doc.doc_id, embedding=doc.embedding,
extra_info=doc.extra_info)
if current_group is not None:
docs.append(current_group)
return docs
def split_documents(documents: List[Document], max_tokens: int) -> List[Document]:
docs = []
for doc in documents:
token_length = len(tiktoken.get_encoding("cl100k_base").encode(doc.text))
if token_length <= max_tokens:
docs.append(doc)
else:
header, body = separate_header_and_body(doc.text)
num_body_parts = ceil(token_length / max_tokens)
part_length = ceil(len(body) / num_body_parts)
body_parts = [body[i:i + part_length] for i in range(0, len(body), part_length)]
for i, body_part in enumerate(body_parts):
new_doc = Document(text=header + body_part.strip(),
doc_id=f"{doc.doc_id}-{i}",
embedding=doc.embedding,
extra_info=doc.extra_info)
docs.append(new_doc)
return docs
def group_split(documents: List[Document], max_tokens: int = 2000, min_tokens: int = 150, token_check: bool = True):
if token_check == False:
return documents
print("Grouping small documents")
try:
documents = group_documents(documents=documents, min_tokens=min_tokens, max_tokens=max_tokens)
except:
print("Grouping failed, try running without token_check")
print("Separating large documents")
try:
documents = split_documents(documents=documents, max_tokens=max_tokens)
except:
print("Grouping failed, try running without token_check")
return documents