mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-12-01 01:23:14 +00:00
Update application files, fix LLM models, and create new retriever class
This commit is contained in:
@@ -103,7 +103,7 @@ def is_azure_configured():
|
||||
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
|
||||
|
||||
def save_conversation(conversation_id, question, response, source_log_docs, llm):
|
||||
if conversation_id is not None:
|
||||
if conversation_id is not None and conversation_id != "None":
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{"$push": {"queries": {"prompt": question, "response": response, "sources": source_log_docs}}},
|
||||
@@ -129,6 +129,7 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm)
|
||||
"name": completion,
|
||||
"queries": [{"prompt": question, "response": response, "sources": source_log_docs}]}
|
||||
).inserted_id
|
||||
return conversation_id
|
||||
|
||||
def get_prompt(prompt_id):
|
||||
if prompt_id == 'default':
|
||||
@@ -293,12 +294,5 @@ def api_search():
|
||||
source=source, chat_history=[], prompt="default", chunks=chunks, gpt_model=gpt_model
|
||||
)
|
||||
docs = retriever.search()
|
||||
|
||||
source_log_docs = []
|
||||
for doc in docs:
|
||||
if doc.metadata:
|
||||
source_log_docs.append({"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content})
|
||||
else:
|
||||
source_log_docs.append({"title": doc.page_content, "text": doc.page_content})
|
||||
return source_log_docs
|
||||
return docs
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ boto3==1.34.6
|
||||
celery==5.3.6
|
||||
dataclasses_json==0.6.3
|
||||
docx2txt==0.8
|
||||
duckduckgo-search=5.3.0
|
||||
EbookLib==0.18
|
||||
elasticsearch==8.12.0
|
||||
escodegen==1.0.11
|
||||
|
||||
@@ -8,3 +8,7 @@ class BaseRetriever(ABC):
|
||||
@abstractmethod
|
||||
def gen(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -40,23 +40,22 @@ class ClassicRAG(BaseRetriever):
|
||||
docs = []
|
||||
else:
|
||||
docsearch = VectorCreator.create_vectorstore(settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY)
|
||||
docs = docsearch.search(self.question, k=self.chunks)
|
||||
docs_temp = docsearch.search(self.question, k=self.chunks)
|
||||
docs = [{"title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content, "text": i.page_content} for i in docs_temp]
|
||||
if settings.LLM_NAME == "llama.cpp":
|
||||
docs = [docs[0]]
|
||||
|
||||
return docs
|
||||
|
||||
def gen(self):
|
||||
docs = self._get_data()
|
||||
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc.page_content for doc in docs])
|
||||
docs_together = "\n".join([doc["text"] for doc in docs])
|
||||
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
for doc in docs:
|
||||
if doc.metadata:
|
||||
yield {"source": {"title": doc.metadata['title'].split('/')[-1], "text": doc.page_content}}
|
||||
else:
|
||||
yield {"source": {"title": doc.page_content, "text": doc.page_content}}
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 1:
|
||||
tokens_current_history = 0
|
||||
|
||||
97
application/retriever/duckduck_search.py
Normal file
97
application/retriever/duckduck_search.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import json
|
||||
import ast
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import count_tokens
|
||||
from langchain_community.tools import DuckDuckGoSearchResults
|
||||
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
|
||||
|
||||
|
||||
|
||||
class DuckDuckSearch(BaseRetriever):
|
||||
|
||||
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
|
||||
self.question = question
|
||||
self.source = source
|
||||
self.chat_history = chat_history
|
||||
self.prompt = prompt
|
||||
self.chunks = chunks
|
||||
self.gpt_model = gpt_model
|
||||
|
||||
def _parse_lang_string(self, input_string):
|
||||
result = []
|
||||
current_item = ""
|
||||
inside_brackets = False
|
||||
for char in input_string:
|
||||
if char == "[":
|
||||
inside_brackets = True
|
||||
elif char == "]":
|
||||
inside_brackets = False
|
||||
result.append(current_item)
|
||||
current_item = ""
|
||||
elif inside_brackets:
|
||||
current_item += char
|
||||
|
||||
# Check if there is an unmatched opening bracket at the end
|
||||
if inside_brackets:
|
||||
result.append(current_item)
|
||||
|
||||
return result
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0:
|
||||
docs = []
|
||||
else:
|
||||
wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks)
|
||||
search = DuckDuckGoSearchResults(api_wrapper=wrapper)
|
||||
results = search.run(self.question)
|
||||
results = self._parse_lang_string(results)
|
||||
|
||||
docs = []
|
||||
for i in results:
|
||||
try:
|
||||
text = i.split("title:")[0]
|
||||
title = i.split("title:")[1].split("link:")[0]
|
||||
link = i.split("link:")[1]
|
||||
docs.append({"text": text, "title": title, "link": link})
|
||||
except IndexError:
|
||||
pass
|
||||
if settings.LLM_NAME == "llama.cpp":
|
||||
docs = [docs[0]]
|
||||
|
||||
return docs
|
||||
|
||||
def gen(self):
|
||||
docs = self._get_data()
|
||||
|
||||
# join all page_content together with a newline
|
||||
docs_together = "\n".join([doc["text"] for doc in docs])
|
||||
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
for doc in docs:
|
||||
yield {"source": doc}
|
||||
|
||||
if len(self.chat_history) > 1:
|
||||
tokens_current_history = 0
|
||||
# count tokens in history
|
||||
self.chat_history.reverse()
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
|
||||
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
|
||||
tokens_current_history += tokens_batch
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append({"role": "system", "content": i["response"]})
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
|
||||
|
||||
completion = llm.gen_stream(model=self.gpt_model,
|
||||
messages=messages_combine)
|
||||
for line in completion:
|
||||
yield {"answer": str(line)}
|
||||
|
||||
def search(self):
|
||||
return self._get_data()
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from application.retriever.classic_rag import ClassicRAG
|
||||
from application.retriever.duckduck_search import DuckDuckSearch
|
||||
|
||||
|
||||
|
||||
class RetrieverCreator:
|
||||
retievers = {
|
||||
'classic': ClassicRAG,
|
||||
'duckduck': DuckDuckSearch
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user