From 6fed84958e8757cf7d49b16aaaf442a0aa594a64 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 24 Feb 2025 16:41:57 +0530 Subject: [PATCH] feat: agent-retriever workflow + query rephrase --- application/api/answer/routes.py | 23 +- application/retriever/classic_rag.py | 114 ++++----- application/tools/agent.py | 222 ++++++------------ application/tools/base_agent.py | 140 +++++++++++ application/tools/implementations/api_tool.py | 23 +- 5 files changed, 292 insertions(+), 230 deletions(-) create mode 100644 application/tools/base_agent.py diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 34e6abca..f83f0c7e 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -17,6 +17,7 @@ from application.error import bad_request from application.extensions import api from application.llm.llm_creator import LLMCreator from application.retriever.retriever_creator import RetrieverCreator +from application.tools.agent import ClassicAgent from application.utils import check_required_fields, limit_chat_history logger = logging.getLogger(__name__) @@ -199,15 +200,21 @@ def get_prompt(prompt_id): def complete_stream( - question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None + question, + agent, + retriever, + conversation_id, + user_api_key, + isNoneDoc=False, + index=None, ): try: response_full = "" source_log_docs = [] tool_calls = [] - answer = retriever.gen() - sources = retriever.search() + answer = agent.gen(question, retriever) + sources = retriever.search(question) for source in sources: if "text" in source: source["text"] = source["text"][:100].strip() + "..." 
@@ -361,9 +368,16 @@ class Stream(Resource): prompt = get_prompt(prompt_id) if "isNoneDoc" in data and data["isNoneDoc"] is True: chunks = 0 + agent = ClassicAgent( + settings.LLM_NAME, + gpt_model, + settings.API_KEY, + user_api_key=user_api_key, + prompt=prompt, + chat_history=history, + ) retriever = RetrieverCreator.create_retriever( retriever_name, - question=question, source=source, chat_history=history, prompt=prompt, @@ -376,6 +390,7 @@ class Stream(Resource): return Response( complete_stream( question=question, + agent=agent, retriever=retriever, conversation_id=conversation_id, user_api_key=user_api_key, diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index ca40f966..8a1406f4 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,28 +1,25 @@ -import uuid - from application.core.settings import settings +from application.llm.llm_creator import LLMCreator from application.retriever.base import BaseRetriever -from application.tools.agent import Agent from application.vectorstore.vector_creator import VectorCreator class ClassicRAG(BaseRetriever): - def __init__( self, - question, source, - chat_history, - prompt, + chat_history=None, + prompt="", chunks=2, token_limit=150, gpt_model="docsgpt", user_api_key=None, + llm_name=settings.LLM_NAME, + api_key=settings.API_KEY, ): - self.question = question - self.vectorstore = source["active_docs"] if "active_docs" in source else None - self.chat_history = chat_history + self.original_question = "" + self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt self.chunks = chunks self.gpt_model = gpt_model @@ -37,12 +34,35 @@ class ClassicRAG(BaseRetriever): ) ) self.user_api_key = user_api_key - self.agent = Agent( - llm_name=settings.LLM_NAME, - gpt_model=self.gpt_model, - api_key=settings.API_KEY, - user_api_key=self.user_api_key, + self.llm_name = llm_name + self.api_key = api_key + self.llm = 
LLMCreator.create_llm( + self.llm_name, api_key=self.api_key, user_api_key=self.user_api_key ) + self.question = self._rephrase_query() + self.vectorstore = source["active_docs"] if "active_docs" in source else None + + def _rephrase_query(self): + if not self.chat_history or self.chat_history == []: + return self.original_question + + prompt = f"""Given the following conversation history: + {self.chat_history} + + Rephrase the following user question to be a standalone search query + that captures all relevant context from the conversation: + {self.original_question} + """ + + messages = [{"role": "system", "content": prompt}] + + try: + rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages) + print(f"Rephrased query: {rephrased_query}") + return rephrased_query if rephrased_query else self.original_question + except Exception as e: + print(f"Error rephrasing query: {e}") + return self.original_question def _get_data(self): if self.chunks == 0: @@ -69,68 +89,20 @@ class ClassicRAG(BaseRetriever): return docs - def gen(self): - docs = self._get_data() + def gen(): + pass - # join all page_content together with a newline - docs_together = "\n".join([doc["text"] for doc in docs]) - p_chat_combine = self.prompt.replace("{summaries}", docs_together) - messages_combine = [{"role": "system", "content": p_chat_combine}] - for doc in docs: - yield {"source": doc} - - if len(self.chat_history) > 0: - for i in self.chat_history: - if "prompt" in i and "response" in i: - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append( - {"role": "assistant", "content": i["response"]} - ) - if "tool_calls" in i: - for tool_call in i["tool_calls"]: - call_id = tool_call.get("call_id") - if call_id is None or call_id == "None": - call_id = str(uuid.uuid4()) - - function_call_dict = { - "function_call": { - "name": tool_call.get("action_name"), - "args": tool_call.get("arguments"), - "call_id": call_id, - } - } - function_response_dict 
= { - "function_response": { - "name": tool_call.get("action_name"), - "response": {"result": tool_call.get("result")}, - "call_id": call_id, - } - } - - messages_combine.append( - {"role": "assistant", "content": [function_call_dict]} - ) - messages_combine.append( - {"role": "tool", "content": [function_response_dict]} - ) - - messages_combine.append({"role": "user", "content": self.question}) - completion = self.agent.gen(messages_combine) - - for line in completion: - yield {"answer": str(line)} - - yield {"tool_calls": self.agent.tool_calls.copy()} - - def search(self): + def search(self, query: str = ""): + if query: + self.original_question = query + self.question = self._rephrase_query() return self._get_data() def get_params(self): return { - "question": self.question, + "question": self.original_question, + "rephrased_question": self.question, "source": self.vectorstore, - "chat_history": self.chat_history, - "prompt": self.prompt, "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, diff --git a/application/tools/agent.py b/application/tools/agent.py index 10798862..95895e70 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -1,184 +1,102 @@ -from application.core.mongo_db import MongoDB -from application.llm.llm_creator import LLMCreator -from application.tools.llm_handler import get_llm_handler -from application.tools.tool_action_parser import ToolActionParser -from application.tools.tool_manager import ToolManager +import uuid +from typing import Dict, Generator + +from application.retriever.base import BaseRetriever +from application.tools.base_agent import BaseAgent -class Agent: - def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): - # Initialize the LLM with the provided parameters - self.llm = LLMCreator.create_llm( - llm_name, api_key=api_key, user_api_key=user_api_key - ) - self.llm_handler = get_llm_handler(llm_name) - self.gpt_model = gpt_model - # Static tool 
configuration (to be replaced later) - self.tools = [] - self.tool_config = {} - self.tool_calls = [] +class ClassicAgent(BaseAgent): + def __init__( + self, + llm_name, + gpt_model, + api_key, + user_api_key=None, + prompt="", + chat_history=None, + ): + super().__init__(llm_name, gpt_model, api_key, user_api_key) + self.prompt = prompt + self.chat_history = chat_history if chat_history is not None else [] - def _get_user_tools(self, user="local"): - mongo = MongoDB.get_client() - db = mongo["docsgpt"] - user_tools_collection = db["user_tools"] - user_tools = user_tools_collection.find({"user": user, "status": True}) - user_tools = list(user_tools) - tools_by_id = {str(tool["_id"]): tool for tool in user_tools} - return tools_by_id + def gen(self, query: str, retriever: BaseRetriever) -> Generator[Dict, None, None]: - def _build_tool_parameters(self, action): - params = {"type": "object", "properties": {}, "required": []} - for param_type in ["query_params", "headers", "body", "parameters"]: - if param_type in action and action[param_type].get("properties"): - for k, v in action[param_type]["properties"].items(): - if v.get("filled_by_llm", True): - params["properties"][k] = { - key: value - for key, value in v.items() - if key != "filled_by_llm" and key != "value" + retrieved_data = retriever.search(query) + docs_together = "\n".join([doc["text"] for doc in retrieved_data]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + + if len(self.chat_history) > 0: + for i in self.chat_history: + if "prompt" in i and "response" in i: + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append( + {"role": "assistant", "content": i["response"]} + ) + if "tool_calls" in i: + for tool_call in i["tool_calls"]: + call_id = tool_call.get("call_id") + if call_id is None or call_id == "None": + call_id = str(uuid.uuid4()) + + function_call_dict = { + 
"function_call": { + "name": tool_call.get("action_name"), + "args": tool_call.get("arguments"), + "call_id": call_id, + } + } + function_response_dict = { + "function_response": { + "name": tool_call.get("action_name"), + "response": {"result": tool_call.get("result")}, + "call_id": call_id, + } } - params["required"].append(k) - return params + messages_combine.append( + {"role": "assistant", "content": [function_call_dict]} + ) + messages_combine.append( + {"role": "tool", "content": [function_response_dict]} + ) + messages_combine.append({"role": "user", "content": query}) - def _prepare_tools(self, tools_dict): - self.tools = [ - { - "type": "function", - "function": { - "name": f"{action['name']}_{tool_id}", - "description": action["description"], - "parameters": self._build_tool_parameters(action), - }, - } - for tool_id, tool in tools_dict.items() - if ( - (tool["name"] == "api_tool" and "actions" in tool.get("config", {})) - or (tool["name"] != "api_tool" and "actions" in tool) - ) - for action in ( - tool["config"]["actions"].values() - if tool["name"] == "api_tool" - else tool["actions"] - ) - if action.get("active", True) - ] - - def _execute_tool_action(self, tools_dict, call): - parser = ToolActionParser(self.llm.__class__.__name__) - tool_id, action_name, call_args = parser.parse_args(call) - - tool_data = tools_dict[tool_id] - action_data = ( - tool_data["config"]["actions"][action_name] - if tool_data["name"] == "api_tool" - else next( - action - for action in tool_data["actions"] - if action["name"] == action_name - ) - ) - - query_params, headers, body, parameters = {}, {}, {}, {} - param_types = { - "query_params": query_params, - "headers": headers, - "body": body, - "parameters": parameters, - } - - for param_type, target_dict in param_types.items(): - if param_type in action_data and action_data[param_type].get("properties"): - for param, details in action_data[param_type]["properties"].items(): - if param not in call_args and "value" in 
details: - target_dict[param] = details["value"] - - for param, value in call_args.items(): - for param_type, target_dict in param_types.items(): - if param_type in action_data and param in action_data[param_type].get( - "properties", {} - ): - target_dict[param] = value - - tm = ToolManager(config={}) - tool = tm.load_tool( - tool_data["name"], - tool_config=( - { - "url": tool_data["config"]["actions"][action_name]["url"], - "method": tool_data["config"]["actions"][action_name]["method"], - "headers": headers, - "query_params": query_params, - } - if tool_data["name"] == "api_tool" - else tool_data["config"] - ), - ) - if tool_data["name"] == "api_tool": - print( - f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}" - ) - result = tool.execute_action(action_name, **body) - else: - print(f"Executing tool: {action_name} with args: {call_args}") - result = tool.execute_action(action_name, **parameters) - call_id = getattr(call, "id", None) - - tool_call_data = { - "tool_name": tool_data["name"], - "call_id": call_id if call_id is not None else "None", - "action_name": f"{action_name}_{tool_id}", - "arguments": call_args, - "result": result, - } - self.tool_calls.append(tool_call_data) - - return result, call_id - - def _simple_tool_agent(self, messages): tools_dict = self._get_user_tools() self._prepare_tools(tools_dict) - resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) + resp = self.llm.gen( + model=self.gpt_model, messages=messages_combine, tools=self.tools + ) if isinstance(resp, str): - yield resp + yield {"answer": resp} return if ( hasattr(resp, "message") and hasattr(resp.message, "content") and resp.message.content is not None ): - yield resp.message.content + yield {"answer": resp.message.content} return - resp = self.llm_handler.handle_response(self, resp, tools_dict, messages) + resp = self.llm_handler.handle_response( + self, resp, tools_dict, messages_combine + ) if 
isinstance(resp, str): - yield resp + yield {"answer": resp} elif ( hasattr(resp, "message") and hasattr(resp.message, "content") and resp.message.content is not None ): - yield resp.message.content + yield {"answer": resp.message.content} else: completion = self.llm.gen_stream( - model=self.gpt_model, messages=messages, tools=self.tools + model=self.gpt_model, messages=messages_combine, tools=self.tools ) for line in completion: - yield line + yield {"answer": line} - return - - def gen(self, messages): - self.tool_calls = [] - if self.llm.supports_tools(): - resp = self._simple_tool_agent(messages) - for line in resp: - yield line - else: - resp = self.llm.gen_stream(model=self.gpt_model, messages=messages) - for line in resp: - yield line + yield {"tool_calls": self.tool_calls.copy()} diff --git a/application/tools/base_agent.py b/application/tools/base_agent.py new file mode 100644 index 00000000..bc8f61a4 --- /dev/null +++ b/application/tools/base_agent.py @@ -0,0 +1,140 @@ +from typing import Dict, Generator + +from application.core.mongo_db import MongoDB +from application.llm.llm_creator import LLMCreator +from application.tools.llm_handler import get_llm_handler +from application.tools.tool_action_parser import ToolActionParser +from application.tools.tool_manager import ToolManager + + +class BaseAgent: + def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): + self.llm = LLMCreator.create_llm( + llm_name, api_key=api_key, user_api_key=user_api_key + ) + self.llm_handler = get_llm_handler(llm_name) + self.gpt_model = gpt_model + self.tools = [] + self.tool_config = {} + self.tool_calls = [] + + def gen(self, query: str) -> Generator[Dict, None, None]: + raise NotImplementedError('Method "gen" must be implemented in the child class') + + def _get_user_tools(self, user="local"): + mongo = MongoDB.get_client() + db = mongo["docsgpt"] + user_tools_collection = db["user_tools"] + user_tools = user_tools_collection.find({"user": user, "status": 
True}) + user_tools = list(user_tools) + tools_by_id = {str(tool["_id"]): tool for tool in user_tools} + return tools_by_id + + def _build_tool_parameters(self, action): + params = {"type": "object", "properties": {}, "required": []} + for param_type in ["query_params", "headers", "body", "parameters"]: + if param_type in action and action[param_type].get("properties"): + for k, v in action[param_type]["properties"].items(): + if v.get("filled_by_llm", True): + params["properties"][k] = { + key: value + for key, value in v.items() + if key != "filled_by_llm" and key != "value" + } + + params["required"].append(k) + return params + + def _prepare_tools(self, tools_dict): + self.tools = [ + { + "type": "function", + "function": { + "name": f"{action['name']}_{tool_id}", + "description": action["description"], + "parameters": self._build_tool_parameters(action), + }, + } + for tool_id, tool in tools_dict.items() + if ( + (tool["name"] == "api_tool" and "actions" in tool.get("config", {})) + or (tool["name"] != "api_tool" and "actions" in tool) + ) + for action in ( + tool["config"]["actions"].values() + if tool["name"] == "api_tool" + else tool["actions"] + ) + if action.get("active", True) + ] + + def _execute_tool_action(self, tools_dict, call): + parser = ToolActionParser(self.llm.__class__.__name__) + tool_id, action_name, call_args = parser.parse_args(call) + + tool_data = tools_dict[tool_id] + action_data = ( + tool_data["config"]["actions"][action_name] + if tool_data["name"] == "api_tool" + else next( + action + for action in tool_data["actions"] + if action["name"] == action_name + ) + ) + + query_params, headers, body, parameters = {}, {}, {}, {} + param_types = { + "query_params": query_params, + "headers": headers, + "body": body, + "parameters": parameters, + } + + for param_type, target_dict in param_types.items(): + if param_type in action_data and action_data[param_type].get("properties"): + for param, details in 
action_data[param_type]["properties"].items(): + if param not in call_args and "value" in details: + target_dict[param] = details["value"] + + for param, value in call_args.items(): + for param_type, target_dict in param_types.items(): + if param_type in action_data and param in action_data[param_type].get( + "properties", {} + ): + target_dict[param] = value + + tm = ToolManager(config={}) + tool = tm.load_tool( + tool_data["name"], + tool_config=( + { + "url": tool_data["config"]["actions"][action_name]["url"], + "method": tool_data["config"]["actions"][action_name]["method"], + "headers": headers, + "query_params": query_params, + } + if tool_data["name"] == "api_tool" + else tool_data["config"] + ), + ) + if tool_data["name"] == "api_tool": + print( + f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}" + ) + result = tool.execute_action(action_name, **body) + else: + print(f"Executing tool: {action_name} with args: {call_args}") + result = tool.execute_action(action_name, **parameters) + call_id = getattr(call, "id", None) + + tool_call_data = { + "tool_name": tool_data["name"], + "call_id": call_id if call_id is not None else "None", + "action_name": f"{action_name}_{tool_id}", + "arguments": call_args, + "result": result, + } + self.tool_calls.append(tool_call_data) + + return result, call_id diff --git a/application/tools/implementations/api_tool.py b/application/tools/implementations/api_tool.py index 5d0fec70..69026dc1 100644 --- a/application/tools/implementations/api_tool.py +++ b/application/tools/implementations/api_tool.py @@ -31,10 +31,27 @@ class APITool(Tool): print(f"Making API call: {method} {url} with body: {body}") response = requests.request(method, url, headers=headers, data=body) response.raise_for_status() - try: - data = response.json() - except ValueError: + + content_type = response.headers.get( + "Content-Type", "application/json" + ).lower() + if "application/json" in content_type: + try: 
+ data = response.json() + except ValueError as e: + print(f"Error decoding JSON: {e}. Raw response: {response.text}") + return { + "status_code": response.status_code, + "message": f"API call returned invalid JSON. Error: {e}", + "data": response.text, + } + elif "text/" in content_type or "application/xml" in content_type: + data = response.text + elif not response.content: data = None + else: + print(f"Unsupported content type: {content_type}") + data = response.content return { "status_code": response.status_code,