From 8e477c9d16b33d1dfaf0bcfadf12fea465754e96 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 15 Mar 2023 00:23:51 +0000 Subject: [PATCH] update worker --- application/parser/file/rst_parser.py | 17 +------ application/parser/token_func.py | 70 +++++++++++++++++++++++++++ application/worker.py | 19 +++++--- 3 files changed, 85 insertions(+), 21 deletions(-) create mode 100644 application/parser/token_func.py diff --git a/application/parser/file/rst_parser.py b/application/parser/file/rst_parser.py index 0a4724fc..1719b84c 100644 --- a/application/parser/file/rst_parser.py +++ b/application/parser/file/rst_parser.py @@ -29,7 +29,6 @@ class RstParser(BaseParser): remove_whitespaces_excess: bool = True, #Be carefull with remove_characters_excess, might cause data loss remove_characters_excess: bool = True, - max_tokens: int = 2048, **kwargs: Any, ) -> None: """Init params.""" @@ -41,18 +40,6 @@ class RstParser(BaseParser): self._remove_directives = remove_directives self._remove_whitespaces_excess = remove_whitespaces_excess self._remove_characters_excess = remove_characters_excess - self._max_tokens = max_tokens - - def tups_chunk_append(self, tups: List[Tuple[Optional[str], str]], current_header: Optional[str], current_text: str): - """Append to tups chunk.""" - num_tokens = len(tiktoken.get_encoding("cl100k_base").encode(current_text)) - if num_tokens > self._max_tokens: - chunks = [current_text[i:i + self._max_tokens] for i in range(0, len(current_text), self._max_tokens)] - for chunk in chunks: - tups.append((current_header, chunk)) - else: - tups.append((current_header, current_text)) - return tups def rst_to_tups(self, rst_text: str) -> List[Tuple[Optional[str], str]]: @@ -76,14 +63,14 @@ class RstParser(BaseParser): # removes the next heading from current Document if current_text.endswith(lines[i - 1] + "\n"): current_text = current_text[:len(current_text) - len(lines[i - 1] + "\n")] - rst_tups = self.tups_chunk_append(rst_tups, current_header, current_text) + rst_tups.append((current_header, current_text)) current_header = lines[i - 1] current_text = "" else: current_text += line + "\n" - rst_tups = self.tups_chunk_append(rst_tups, current_header, current_text) + rst_tups.append((current_header, current_text)) #TODO: Format for rst # diff --git a/application/parser/token_func.py b/application/parser/token_func.py new file mode 100644 index 00000000..95b318b9 --- /dev/null +++ b/application/parser/token_func.py @@ -0,0 +1,70 @@ +import re +import tiktoken + +from typing import List +from parser.schema.base import Document +from math import ceil + + +def separate_header_and_body(text): + header_pattern = r"^(.*?\n){3}" + match = re.match(header_pattern, text) + header = match.group(0) + body = text[len(header):] + return header, body + +def group_documents(documents: List[Document], min_tokens: int, max_tokens: int) -> List[Document]: + docs = [] + current_group = None + + for doc in documents: + doc_len = len(tiktoken.get_encoding("cl100k_base").encode(doc.text)) + + if current_group is None: + current_group = Document(text=doc.text, doc_id=doc.doc_id, embedding=doc.embedding, + extra_info=doc.extra_info) + elif len(tiktoken.get_encoding("cl100k_base").encode(current_group.text)) + doc_len < max_tokens and doc_len >= min_tokens: + current_group.text += " " + doc.text + else: + docs.append(current_group) + current_group = Document(text=doc.text, doc_id=doc.doc_id, embedding=doc.embedding, + extra_info=doc.extra_info) + + if current_group is not None: + docs.append(current_group) + + return docs + +def split_documents(documents: List[Document], max_tokens: int) -> List[Document]: + docs = [] + for doc in documents: + token_length = len(tiktoken.get_encoding("cl100k_base").encode(doc.text)) + if token_length <= max_tokens: + docs.append(doc) + else: + header, body = separate_header_and_body(doc.text) + num_body_parts = ceil(token_length / max_tokens) + part_length = ceil(len(body) / num_body_parts) + body_parts = [body[i:i + part_length] for i in range(0, len(body), part_length)] + for i, body_part in enumerate(body_parts): + new_doc = Document(text=header + body_part.strip(), + doc_id=f"{doc.doc_id}-{i}", + embedding=doc.embedding, + extra_info=doc.extra_info) + docs.append(new_doc) + return docs + +def group_split(documents: List[Document], max_tokens: int = 2000, min_tokens: int = 150, token_check: bool = True): + if token_check == False: + return documents + print("Grouping small documents") + try: + documents = group_documents(documents=documents, min_tokens=min_tokens, max_tokens=max_tokens) + except: + print("Grouping failed, try running without token_check") + print("Separating large documents") + try: + documents = split_documents(documents=documents, max_tokens=max_tokens) + except: + print("Grouping failed, try running without token_check") + return documents diff --git a/application/worker.py b/application/worker.py index 1a538f7f..268b829a 100644 --- a/application/worker.py +++ b/application/worker.py @@ -1,11 +1,11 @@ import requests import nltk import os -from langchain.text_splitter import RecursiveCharacterTextSplitter from parser.file.bulk import SimpleDirectoryReader from parser.schema.base import Document from parser.open_ai_func import call_openai_api +from parser.token_func import group_split from celery import current_task @@ -33,6 +33,10 @@ def ingest_worker(self, directory, formats, name_job, filename, user): # name_job = 'job1' # filename = 'install.rst' # user = 'local' + sample = False + token_check = True + min_tokens = 150 + max_tokens = 2000 full_path = directory + '/' + user + '/' + name_job # check if API_URL env variable is set if not os.environ.get('API_URL'): @@ -61,14 +65,17 @@ def ingest_worker(self, directory, formats, name_job, filename, user): raw_docs = SimpleDirectoryReader(input_dir=full_path, input_files=input_files, recursive=recursive, required_exts=formats, num_files_limit=limit, exclude_hidden=exclude).load_data() - raw_docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs] - # Here we split the documents, as needed, into smaller chunks. - # We do this due to the context limits of the LLMs. - text_splitter = RecursiveCharacterTextSplitter() - docs = text_splitter.split_documents(raw_docs) + raw_docs = group_split(documents=raw_docs, min_tokens=min_tokens, max_tokens=max_tokens, token_check=token_check) + + docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs] + call_openai_api(docs, full_path, self) self.update_state(state='PROGRESS', meta={'current': 100}) + if sample == True: + for i in range(min(5, len(raw_docs))): + print(raw_docs[i].text) + # get files from outputs/inputs/index.faiss and outputs/inputs/index.pkl # and send them to the server (provide user and name in form) if not os.environ.get('API_URL'):