fix: python lint errors

This commit is contained in:
Siddhant Rai
2024-12-19 10:06:06 +05:30
parent c3f538c2f6
commit daa332aa20
4 changed files with 33 additions and 23 deletions

View File

@@ -1,10 +1,9 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever
from application.tools.agent import Agent
from application.utils import num_tokens_from_string
from application.vectorstore.vector_creator import VectorCreator
class ClassicRAG(BaseRetriever):
@@ -21,7 +20,7 @@ class ClassicRAG(BaseRetriever):
user_api_key=None,
):
self.question = question
self.vectorstore = source['active_docs'] if 'active_docs' in source else None
self.vectorstore = source["active_docs"] if "active_docs" in source else None
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
@@ -78,9 +77,9 @@ class ClassicRAG(BaseRetriever):
# count tokens in history
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
tokens_batch = num_tokens_from_string(
i["prompt"]
) + num_tokens_from_string(i["response"])
if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
messages_combine.append(
@@ -95,14 +94,19 @@ class ClassicRAG(BaseRetriever):
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
# )
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
agent = Agent(llm_name=settings.LLM_NAME,gpt_model=self.gpt_model, api_key=settings.API_KEY, user_api_key=self.user_api_key)
agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
)
completion = agent.gen(messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self):
return self._get_data()
def get_params(self):
return {
"question": self.question,
@@ -112,5 +116,5 @@ class ClassicRAG(BaseRetriever):
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
"user_api_key": self.user_api_key
"user_api_key": self.user_api_key,
}