feat: agent-retriever workflow + query rephrase

This commit is contained in:
Siddhant Rai
2025-02-24 16:41:57 +05:30
parent 5924693e90
commit 6fed84958e
5 changed files with 292 additions and 230 deletions

View File

@@ -1,28 +1,25 @@
import uuid
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever
from application.tools.agent import Agent
from application.vectorstore.vector_creator import VectorCreator
class ClassicRAG(BaseRetriever):
def __init__(
self,
question,
source,
chat_history,
prompt,
chat_history=None,
prompt="",
chunks=2,
token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
llm_name=settings.LLM_NAME,
api_key=settings.API_KEY,
):
self.question = question
self.vectorstore = source["active_docs"] if "active_docs" in source else None
self.chat_history = chat_history
self.original_question = ""
self.chat_history = chat_history if chat_history is not None else []
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
@@ -37,12 +34,35 @@ class ClassicRAG(BaseRetriever):
)
)
self.user_api_key = user_api_key
self.agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
self.llm_name = llm_name
self.api_key = api_key
self.llm = LLMCreator.create_llm(
self.llm_name, api_key=self.api_key, user_api_key=self.user_api_key
)
self.question = self._rephrase_query()
self.vectorstore = source["active_docs"] if "active_docs" in source else None
def _rephrase_query(self):
if not self.chat_history or self.chat_history == []:
return self.original_question
prompt = f"""Given the following conversation history:
{self.chat_history}
Rephrase the following user question to be a standalone search query
that captures all relevant context from the conversation:
{self.original_question}
"""
messages = [{"role": "system", "content": prompt}]
try:
rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages)
print(f"Rephrased query: {rephrased_query}")
return rephrased_query if rephrased_query else self.original_question
except Exception as e:
print(f"Error rephrasing query: {e}")
return self.original_question
def _get_data(self):
if self.chunks == 0:
@@ -69,68 +89,20 @@ class ClassicRAG(BaseRetriever):
return docs
def gen(self):
docs = self._get_data()
def gen():
pass
# join all page_content together with a newline
docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 0:
for i in self.chat_history:
if "prompt" in i and "response" in i:
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id")
if call_id is None or call_id == "None":
call_id = str(uuid.uuid4())
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages_combine.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages_combine.append(
{"role": "tool", "content": [function_response_dict]}
)
messages_combine.append({"role": "user", "content": self.question})
completion = self.agent.gen(messages_combine)
for line in completion:
yield {"answer": str(line)}
yield {"tool_calls": self.agent.tool_calls.copy()}
def search(self):
def search(self, query: str = ""):
if query:
self.original_question = query
self.question = self._rephrase_query()
return self._get_data()
def get_params(self):
return {
"question": self.question,
"question": self.original_question,
"rephrased_question": self.question,
"source": self.vectorstore,
"chat_history": self.chat_history,
"prompt": self.prompt,
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,