fix: api_key capturing + pytest errors

This commit is contained in:
Siddhant Rai
2024-04-15 22:32:24 +05:30
parent 60a670ce29
commit 77991896b4
11 changed files with 276 additions and 145 deletions

View File

@@ -7,21 +7,30 @@ from application.llm.llm_creator import LLMCreator
from application.utils import count_tokens
class ClassicRAG(BaseRetriever):
def __init__(self, question, source, chat_history, prompt, chunks=2, gpt_model='docsgpt'):
def __init__(
self,
question,
source,
chat_history,
prompt,
chunks=2,
gpt_model="docsgpt",
api_key=None,
):
self.question = question
self.vectorstore = self._get_vectorstore(source=source)
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
self.api_key = api_key
def _get_vectorstore(self, source):
if "active_docs" in source:
if source["active_docs"].split("/")[0] == "default":
vectorstore = ""
vectorstore = ""
elif source["active_docs"].split("/")[0] == "local":
vectorstore = "indexes/" + source["active_docs"]
else:
@@ -33,32 +42,33 @@ class ClassicRAG(BaseRetriever):
vectorstore = os.path.join("application", vectorstore)
return vectorstore
def _get_data(self):
if self.chunks == 0:
docs = []
else:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
self.vectorstore,
settings.EMBEDDINGS_KEY
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
)
docs_temp = docsearch.search(self.question, k=self.chunks)
docs = [
{
"title": i.metadata['title'].split('/')[-1] if i.metadata else i.page_content,
"text": i.page_content
}
"title": (
i.metadata["title"].split("/")[-1]
if i.metadata
else i.page_content
),
"text": i.page_content,
}
for i in docs_temp
]
if settings.LLM_NAME == "llama.cpp":
docs = [docs[0]]
return docs
def gen(self):
docs = self._get_data()
# join all page_content together with a newline
docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
@@ -72,20 +82,27 @@ class ClassicRAG(BaseRetriever):
self.chat_history.reverse()
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(i["response"])
if tokens_current_history + tokens_batch < settings.TOKENS_MAX_HISTORY:
tokens_batch = count_tokens(i["prompt"]) + count_tokens(
i["response"]
)
if (
tokens_current_history + tokens_batch
< settings.TOKENS_MAX_HISTORY
):
tokens_current_history += tokens_batch
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "system", "content": i["response"]})
messages_combine.append(
{"role": "user", "content": i["prompt"]}
)
messages_combine.append(
{"role": "system", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=settings.API_KEY)
llm = LLMCreator.create_llm(settings.LLM_NAME, api_key=self.api_key)
completion = llm.gen_stream(model=self.gpt_model,
messages=messages_combine)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self):
return self._get_data()