mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 00:23:17 +00:00
Update application files and fix LLM models, create new retriever class
This commit is contained in:
83
application/retriever/classic_rag.py
Normal file
83
application/retriever/classic_rag.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
import json
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
from application.utils import count_tokens
|
||||
|
||||
|
||||
|
||||
class ClassicRAG(BaseRetriever):
    """Retriever that answers a question from vector-store chunks and chat history.

    The retriever looks up ``chunks`` documents for ``question`` in the vector
    store selected by ``source``, substitutes their text into ``prompt`` and
    streams the LLM completion.
    """

    def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
        """Store the request parameters and resolve the vector-store path.

        Args:
            question: The user question to search for and to send to the LLM.
            source: Mapping that may contain an ``"active_docs"`` entry
                selecting the vector store.
            chat_history: List of mappings; entries with both ``"prompt"``
                and ``"response"`` keys are replayed to the LLM.
            prompt: System prompt template; ``"{summaries}"`` is replaced by
                the retrieved document text.
            chunks: Number of documents to retrieve; ``0`` disables retrieval.
            gpt_model: Model name passed to the LLM's ``gen_stream``.
        """
        self.question = question
        self.vectorstore = self._get_vectorstore(source=source)
        self.chat_history = chat_history
        self.prompt = prompt
        self.chunks = chunks
        self.gpt_model = gpt_model

    def _get_vectorstore(self, source):
        """Return the vector-store path, rooted at ``application``, for *source*.

        ``"default"`` (or a missing ``"active_docs"`` entry) maps to the
        application root, ``"local/..."`` maps under ``indexes/`` and anything
        else maps under ``vectors/``.
        """
        relative_path = ""
        if "active_docs" in source:
            active_docs = source["active_docs"]
            prefix = active_docs.split("/")[0]
            if prefix == "local":
                relative_path = "indexes/" + active_docs
            elif prefix != "default":
                relative_path = "vectors/" + active_docs
        return os.path.join("application", relative_path)

    def _get_data(self):
        """Return the documents retrieved for the question.

        Returns an empty list when ``chunks`` is ``0``. When the configured
        LLM is ``llama.cpp`` at most one document is returned.
        """
        if self.chunks == 0:
            return []

        docsearch = VectorCreator.create_vectorstore(
            settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
        )
        docs = docsearch.search(self.question, k=self.chunks)
        if settings.LLM_NAME == "llama.cpp":
            # Slice rather than index so an empty search result does not
            # raise IndexError.
            docs = docs[:1]
        return docs

    def gen(self):
        """Yield the retrieved sources, then the streamed answer.

        Yields:
            One ``{"source": {"title": ..., "text": ...}}`` dict per retrieved
            document, followed by one ``{"answer": ...}`` dict per item
            produced by the LLM stream.
        """
        docs = self._get_data()

        # Join all page_content together with a newline.
        docs_together = "\n".join(doc.page_content for doc in docs)
        p_chat_combine = self.prompt.replace("{summaries}", docs_together)
        messages_combine = [{"role": "system", "content": p_chat_combine}]

        for doc in docs:
            # Fall back to the page content when there is no metadata or the
            # metadata carries no title, instead of raising KeyError.
            title = doc.metadata.get("title") if doc.metadata else None
            if title is not None:
                title = title.split('/')[-1]
            else:
                title = doc.page_content
            yield {"source": {"title": title, "text": doc.page_content}}

        if len(self.chat_history) > 1:
            tokens_current_history = 0
            # Walk the history newest-first without mutating the caller's
            # list, so repeated gen() calls see the same order.
            # NOTE(review): exchanges are appended newest-first and responses
            # use the "system" role, as in the original - confirm intended.
            for entry in reversed(self.chat_history):
                if "prompt" in entry and "response" in entry:
                    tokens_batch = count_tokens(entry["prompt"]) + count_tokens(entry["response"])
                    # Only include exchanges that fit the history token budget.
                    if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
                        tokens_current_history += tokens_batch
                        messages_combine.append({"role": "user", "content": entry["prompt"]})
                        messages_combine.append({"role": "system", "content": entry["response"]})
        messages_combine.append({"role": "user", "content": self.question})

        llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)

        completion = llm.gen_stream(model=self.gpt_model,
                                    messages=messages_combine)
        for line in completion:
            yield {"answer": str(line)}

    def search(self):
        """Return the retrieved documents without calling the LLM."""
        return self._get_data()
|
||||
|
||||
Reference in New Issue
Block a user