From 1e26943c3e2cf1ccb4526a9686d79968e63ba6bc Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 9 Apr 2024 15:45:24 +0100 Subject: [PATCH] Update application files, fix LLM models, and create new retriever class --- application/api/answer/routes.py | 12 +-- application/requirements.txt | 1 + application/retriever/base.py | 4 + application/retriever/classic_rag.py | 11 ++- application/retriever/duckduck_search.py | 97 ++++++++++++++++++++++ application/retriever/retriever_creator.py | 2 + 6 files changed, 112 insertions(+), 15 deletions(-) create mode 100644 application/retriever/duckduck_search.py diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 1b3c9b97..11e6c4ad 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -103,7 +103,7 @@ def is_azure_configured(): return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME def save_conversation(conversation_id, question, response, source_log_docs, llm): - if conversation_id is not None: + if conversation_id is not None and conversation_id != "None": conversations_collection.update_one( {"_id": ObjectId(conversation_id)}, {"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}}, @@ -129,6 +129,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm) "name": completion, "queries": [{"prompt": question, "response": response, "sources": source_log_docs}]} ).inserted_id + return conversation_id def get_prompt(prompt_id): if prompt_id == 'default': @@ -293,12 +294,5 @@ def api_search(): source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model ) docs = retriever.search() - - source_log_docs = [] - for doc in docs: - if doc.metadata: - source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}) - else: - source_log_docs.append({"title": doc.page_content, "text": doc.page_content}) - return 
source_log_docs + return docs diff --git a/application/requirements.txt b/application/requirements.txt index 0874a7c9..c984534a 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -3,6 +3,7 @@ boto3==1.34.6 celery==5.3.6 dataclasses_json==0.6.3 docx2txt==0.8 +duckduckgo-search==5.3.0 EbookLib==0.18 elasticsearch==8.12.0 escodegen==1.0.11 diff --git a/application/retriever/base.py b/application/retriever/base.py index 3bfaa5e1..4a37e810 100644 --- a/application/retriever/base.py +++ b/application/retriever/base.py @@ -8,3 +8,7 @@ class BaseRetriever(ABC): @abstractmethod def gen(self, *args, **kwargs): pass + + @abstractmethod + def search(self, *args, **kwargs): + pass diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index dc757946..a5bf8e3c 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -40,23 +40,22 @@ class ClassicRAG(BaseRetriever): docs = [] else: docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY) - docs = docsearch.search(self.question, k=self.chunks) + docs_temp = docsearch.search(self.question, k=self.chunks) + docs = [{"title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, "text": i.page_content} for i in docs_temp] if settings.LLM_NAME == "llama.cpp": docs = [docs[0]] + return docs def gen(self): docs = self._get_data() # join all page_content together with a newline - docs_together = "\n".join([doc.page_content for doc in docs]) + docs_together = "\n".join([doc["text"] for doc in docs]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) messages_combine = [{"role": "system", "content": p_chat_combine}] for doc in docs: - if doc.metadata: - yield {"source": {"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}} - else: - yield {"source": {"title": doc.page_content, "text": doc.page_content}} + yield {"source": doc} if 
len(self.chat_history) > 1: tokens_current_history = 0 diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py new file mode 100644 index 00000000..778313c9 --- /dev/null +++ b/application/retriever/duckduck_search.py @@ -0,0 +1,97 @@ +import json +import ast +from application.retriever.base import BaseRetriever +from application.core.settings import settings +from application.llm.llm_creator import LLMCreator +from application.utils import count_tokens +from langchain_community.tools import DuckDuckGoSearchResults +from langchain_community.utilities import DuckDuckGoSearchAPIWrapper + + + +class DuckDuckSearch(BaseRetriever): + + def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'): + self.question = question + self.source = source + self.chat_history = chat_history + self.prompt = prompt + self.chunks = chunks + self.gpt_model = gpt_model + + def _parse_lang_string(self, input_string): + result = [] + current_item = "" + inside_brackets = False + for char in input_string: + if char == "[": + inside_brackets = True + elif char == "]": + inside_brackets = False + result.append(current_item) + current_item = "" + elif inside_brackets: + current_item += char + + # Check if there is an unmatched opening bracket at the end + if inside_brackets: + result.append(current_item) + + return result + + def _get_data(self): + if self.chunks == 0: + docs = [] + else: + wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks) + search = DuckDuckGoSearchResults(api_wrapper=wrapper) + results = search.run(self.question) + results = self._parse_lang_string(results) + + docs = [] + for i in results: + try: + text = i.split("title:")[0] + title = i.split("title:")[1].split("link:")[0] + link = i.split("link:")[1] + docs.append({"text": text, "title": title, "link": link}) + except IndexError: + pass + if settings.LLM_NAME == "llama.cpp": + docs = [docs[0]] + + return docs + + def gen(self): + docs = 
self._get_data() + + # join all page_content together with a newline + docs_together = "\n".join([doc["text"] for doc in docs]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + for doc in docs: + yield {"source": doc} + + if len(self.chat_history) > 1: + tokens_current_history = 0 + # count tokens in history + self.chat_history.reverse() + for i in self.chat_history: + if "prompt" in i and "response" in i: + tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"]) + if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY: + tokens_current_history += tokens_batch + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append({"role": "system", "content": i["response"]}) + messages_combine.append({"role": "user", "content": self.question}) + + llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY) + + completion = llm.gen_stream(model=self.gpt_model, + messages=messages_combine) + for line in completion: + yield {"answer": str(line)} + + def search(self): + return self._get_data() + diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py index 9255f4e6..892d63ab 100644 --- a/application/retriever/retriever_creator.py +++ b/application/retriever/retriever_creator.py @@ -1,10 +1,12 @@ from application.retriever.classic_rag import ClassicRAG +from application.retriever.duckduck_search import DuckDuckSearch class RetrieverCreator: retievers = { 'classic': ClassicRAG, + 'duckduck': DuckDuckSearch } @classmethod