From 6fed84958e8757cf7d49b16aaaf442a0aa594a64 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 24 Feb 2025 16:41:57 +0530 Subject: [PATCH 1/7] feat: agent-retriever workflow + query rephrase --- application/api/answer/routes.py | 23 +- application/retriever/classic_rag.py | 114 ++++----- application/tools/agent.py | 222 ++++++------------ application/tools/base_agent.py | 140 +++++++++++ application/tools/implementations/api_tool.py | 23 +- 5 files changed, 292 insertions(+), 230 deletions(-) create mode 100644 application/tools/base_agent.py diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 34e6abca..f83f0c7e 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -17,6 +17,7 @@ from application.error import bad_request from application.extensions import api from application.llm.llm_creator import LLMCreator from application.retriever.retriever_creator import RetrieverCreator +from application.tools.agent import ClassicAgent from application.utils import check_required_fields, limit_chat_history logger = logging.getLogger(__name__) @@ -199,15 +200,21 @@ def get_prompt(prompt_id): def complete_stream( - question, retriever, conversation_id, user_api_key, isNoneDoc=False, index=None + question, + agent, + retriever, + conversation_id, + user_api_key, + isNoneDoc=False, + index=None, ): try: response_full = "" source_log_docs = [] tool_calls = [] - answer = retriever.gen() - sources = retriever.search() + answer = agent.gen(question, retriever) + sources = retriever.search(question) for source in sources: if "text" in source: source["text"] = source["text"][:100].strip() + "..." @@ -361,9 +368,16 @@ class Stream(Resource): prompt = get_prompt(prompt_id) if "isNoneDoc" in data and data["isNoneDoc"] is True: chunks = 0 + agent = ClassicAgent( + settings.LLM_NAME, + gpt_model, + settings.API_KEY, + user_api_key=user_api_key, + prompt=prompt, + chat_history=history, + ) retriever = RetrieverCreator.create_retriever( retriever_name, - question=question, source=source, chat_history=history, prompt=prompt, @@ -376,6 +390,7 @@ class Stream(Resource): return Response( complete_stream( question=question, + agent=agent, retriever=retriever, conversation_id=conversation_id, user_api_key=user_api_key, diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index ca40f966..8a1406f4 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,28 +1,25 @@ -import uuid - from application.core.settings import settings +from application.llm.llm_creator import LLMCreator from application.retriever.base import BaseRetriever -from application.tools.agent import Agent from application.vectorstore.vector_creator import VectorCreator class ClassicRAG(BaseRetriever): - def __init__( self, - question, source, - chat_history, - prompt, + chat_history=None, + prompt="", chunks=2, token_limit=150, gpt_model="docsgpt", user_api_key=None, + llm_name=settings.LLM_NAME, + api_key=settings.API_KEY, ): - self.question = question - self.vectorstore = source["active_docs"] if "active_docs" in source else None - self.chat_history = chat_history + self.original_question = "" + self.chat_history = chat_history if chat_history is not None else [] self.prompt = prompt self.chunks = chunks self.gpt_model = gpt_model @@ -37,12 +34,35 @@ class ClassicRAG(BaseRetriever): ) ) self.user_api_key = user_api_key - self.agent = Agent( - llm_name=settings.LLM_NAME, - gpt_model=self.gpt_model, - api_key=settings.API_KEY, - user_api_key=self.user_api_key, + self.llm_name = llm_name + self.api_key = api_key + self.llm = LLMCreator.create_llm( + self.llm_name, api_key=self.api_key, user_api_key=self.user_api_key ) + self.question = self._rephrase_query() + self.vectorstore = source["active_docs"] if "active_docs" in source else None + + def _rephrase_query(self): + if not self.chat_history or self.chat_history == []: + return self.original_question + + prompt = f"""Given the following conversation history: + {self.chat_history} + + Rephrase the following user question to be a standalone search query + that captures all relevant context from the conversation: + {self.original_question} + """ + + messages = [{"role": "system", "content": prompt}] + + try: + rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages) + print(f"Rephrased query: {rephrased_query}") + return rephrased_query if rephrased_query else self.original_question + except Exception as e: + print(f"Error rephrasing query: {e}") + return self.original_question def _get_data(self): if self.chunks == 0: @@ -69,68 +89,20 @@ class ClassicRAG(BaseRetriever): return docs - def gen(self): - docs = self._get_data() + def gen(): + pass - # join all page_content together with a newline - docs_together = "\n".join([doc["text"] for doc in docs]) - p_chat_combine = self.prompt.replace("{summaries}", docs_together) - messages_combine = [{"role": "system", "content": p_chat_combine}] - for doc in docs: - yield {"source": doc} - - if len(self.chat_history) > 0: - for i in self.chat_history: - if "prompt" in i and "response" in i: - messages_combine.append({"role": "user", "content": i["prompt"]}) - messages_combine.append( - {"role": "assistant", "content": i["response"]} - ) - if "tool_calls" in i: - for tool_call in i["tool_calls"]: - call_id = tool_call.get("call_id") - if call_id is None or call_id == "None": - call_id = str(uuid.uuid4()) - - function_call_dict = { - "function_call": { - "name": tool_call.get("action_name"), - "args": tool_call.get("arguments"), - "call_id": call_id, - } - } - function_response_dict = { - "function_response": { - "name": tool_call.get("action_name"), - "response": {"result": tool_call.get("result")}, - "call_id": call_id, - } - } - - messages_combine.append( - {"role": "assistant", "content": [function_call_dict]} - ) - messages_combine.append( - {"role": "tool", "content": [function_response_dict]} - ) - - messages_combine.append({"role": "user", "content": self.question}) - completion = self.agent.gen(messages_combine) - - for line in completion: - yield {"answer": str(line)} - - yield {"tool_calls": self.agent.tool_calls.copy()} - - def search(self): + def search(self, query: str = ""): + if query: + self.original_question = query + self.question = self._rephrase_query() return self._get_data() def get_params(self): return { - "question": self.question, + "question": self.original_question, + "rephrased_question": self.question, "source": self.vectorstore, - "chat_history": self.chat_history, - "prompt": self.prompt, "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, diff --git a/application/tools/agent.py b/application/tools/agent.py index 10798862..95895e70 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -1,184 +1,102 @@ -from application.core.mongo_db import MongoDB -from application.llm.llm_creator import LLMCreator -from application.tools.llm_handler import get_llm_handler -from application.tools.tool_action_parser import ToolActionParser -from application.tools.tool_manager import ToolManager +import uuid +from typing import Dict, Generator + +from application.retriever.base import BaseRetriever +from application.tools.base_agent import BaseAgent -class Agent: - def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): - # Initialize the LLM with the provided parameters - self.llm = LLMCreator.create_llm( - llm_name, api_key=api_key, user_api_key=user_api_key - ) - self.llm_handler = get_llm_handler(llm_name) - self.gpt_model = gpt_model - # Static tool configuration (to be replaced later) - self.tools = [] - self.tool_config = {} - self.tool_calls = [] +class ClassicAgent(BaseAgent): + def __init__( + self, + llm_name, + gpt_model, + api_key, + user_api_key=None, + prompt="", + chat_history=None, + ): + super().__init__(llm_name, gpt_model, api_key, user_api_key) + self.prompt = prompt + self.chat_history = chat_history if chat_history is not None else [] - def _get_user_tools(self, user="local"): - mongo = MongoDB.get_client() - db = mongo["docsgpt"] - user_tools_collection = db["user_tools"] - user_tools = user_tools_collection.find({"user": user, "status": True}) - user_tools = list(user_tools) - tools_by_id = {str(tool["_id"]): tool for tool in user_tools} - return tools_by_id + def gen(self, query: str, retriever: BaseRetriever) -> Generator[Dict, None, None]: - def _build_tool_parameters(self, action): - params = {"type": "object", "properties": {}, "required": []} - for param_type in ["query_params", "headers", "body", "parameters"]: - if param_type in action and action[param_type].get("properties"): - for k, v in action[param_type]["properties"].items(): - if v.get("filled_by_llm", True): - params["properties"][k] = { - key: value - for key, value in v.items() - if key != "filled_by_llm" and key != "value" + retrieved_data = retriever.search(query) + docs_together = "\n".join([doc["text"] for doc in retrieved_data]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + + if len(self.chat_history) > 0: + for i in self.chat_history: + if "prompt" in i and "response" in i: + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append( + {"role": "assistant", "content": i["response"]} + ) + if "tool_calls" in i: + for tool_call in i["tool_calls"]: + call_id = tool_call.get("call_id") + if call_id is None or call_id == "None": + call_id = str(uuid.uuid4()) + + function_call_dict = { + "function_call": { + "name": tool_call.get("action_name"), + "args": tool_call.get("arguments"), + "call_id": call_id, + } + } + function_response_dict = { + "function_response": { + "name": tool_call.get("action_name"), + "response": {"result": tool_call.get("result")}, + "call_id": call_id, + } } - params["required"].append(k) - return params + messages_combine.append( + {"role": "assistant", "content": [function_call_dict]} + ) + messages_combine.append( + {"role": "tool", "content": [function_response_dict]} + ) + messages_combine.append({"role": "user", "content": query}) - def _prepare_tools(self, tools_dict): - self.tools = [ - { - "type": "function", - "function": { - "name": f"{action['name']}_{tool_id}", - "description": action["description"], - "parameters": self._build_tool_parameters(action), - }, - } - for tool_id, tool in tools_dict.items() - if ( - (tool["name"] == "api_tool" and "actions" in tool.get("config", {})) - or (tool["name"] != "api_tool" and "actions" in tool) - ) - for action in ( - tool["config"]["actions"].values() - if tool["name"] == "api_tool" - else tool["actions"] - ) - if action.get("active", True) - ] - - def _execute_tool_action(self, tools_dict, call): - parser = ToolActionParser(self.llm.__class__.__name__) - tool_id, action_name, call_args = parser.parse_args(call) - - tool_data = tools_dict[tool_id] - action_data = ( - tool_data["config"]["actions"][action_name] - if tool_data["name"] == "api_tool" - else next( - action - for action in tool_data["actions"] - if action["name"] == action_name - ) - ) - - query_params, headers, body, parameters = {}, {}, {}, {} - param_types = { - "query_params": query_params, - "headers": headers, - "body": body, - "parameters": parameters, - } - - for param_type, target_dict in param_types.items(): - if param_type in action_data and action_data[param_type].get("properties"): - for param, details in action_data[param_type]["properties"].items(): - if param not in call_args and "value" in details: - target_dict[param] = details["value"] - - for param, value in call_args.items(): - for param_type, target_dict in param_types.items(): - if param_type in action_data and param in action_data[param_type].get( - "properties", {} - ): - target_dict[param] = value - - tm = ToolManager(config={}) - tool = tm.load_tool( - tool_data["name"], - tool_config=( - { - "url": tool_data["config"]["actions"][action_name]["url"], - "method": tool_data["config"]["actions"][action_name]["method"], - "headers": headers, - "query_params": query_params, - } - if tool_data["name"] == "api_tool" - else tool_data["config"] - ), - ) - if tool_data["name"] == "api_tool": - print( - f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}" - ) - result = tool.execute_action(action_name, **body) - else: - print(f"Executing tool: {action_name} with args: {call_args}") - result = tool.execute_action(action_name, **parameters) - call_id = getattr(call, "id", None) - - tool_call_data = { - "tool_name": tool_data["name"], - "call_id": call_id if call_id is not None else "None", - "action_name": f"{action_name}_{tool_id}", - "arguments": call_args, - "result": result, - } - self.tool_calls.append(tool_call_data) - - return result, call_id - - def _simple_tool_agent(self, messages): tools_dict = self._get_user_tools() self._prepare_tools(tools_dict) - resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) + resp = self.llm.gen( + model=self.gpt_model, messages=messages_combine, tools=self.tools + ) if isinstance(resp, str): - yield resp + yield {"answer": resp} return if ( hasattr(resp, "message") and hasattr(resp.message, "content") and resp.message.content is not None ): - yield resp.message.content + yield {"answer": resp.message.content} return - resp = self.llm_handler.handle_response(self, resp, tools_dict, messages) + resp = self.llm_handler.handle_response( + self, resp, tools_dict, messages_combine + ) if isinstance(resp, str): - yield resp + yield {"answer": resp} elif ( hasattr(resp, "message") and hasattr(resp.message, "content") and resp.message.content is not None ): - yield resp.message.content + yield {"answer": resp.message.content} else: completion = self.llm.gen_stream( - model=self.gpt_model, messages=messages, tools=self.tools + model=self.gpt_model, messages=messages_combine, tools=self.tools ) for line in completion: - yield line + yield {"answer": line} - return - - def gen(self, messages): - self.tool_calls = [] - if self.llm.supports_tools(): - resp = self._simple_tool_agent(messages) - for line in resp: - yield line - else: - resp = self.llm.gen_stream(model=self.gpt_model, messages=messages) - for line in resp: - yield line + yield {"tool_calls": self.tool_calls.copy()} diff --git a/application/tools/base_agent.py b/application/tools/base_agent.py new file mode 100644 index 00000000..bc8f61a4 --- /dev/null +++ b/application/tools/base_agent.py @@ -0,0 +1,140 @@ +from typing import Dict, Generator + +from application.core.mongo_db import MongoDB +from application.llm.llm_creator import LLMCreator +from application.tools.llm_handler import get_llm_handler +from application.tools.tool_action_parser import ToolActionParser +from application.tools.tool_manager import ToolManager + + +class BaseAgent: + def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): + self.llm = LLMCreator.create_llm( + llm_name, api_key=api_key, user_api_key=user_api_key + ) + self.llm_handler = get_llm_handler(llm_name) + self.gpt_model = gpt_model + self.tools = [] + self.tool_config = {} + self.tool_calls = [] + + def gen(self, query: str) -> Generator[Dict, None, None]: + raise NotImplementedError('Method "gen" must be implemented in the child class') + + def _get_user_tools(self, user="local"): + mongo = MongoDB.get_client() + db = mongo["docsgpt"] + user_tools_collection = db["user_tools"] + user_tools = user_tools_collection.find({"user": user, "status": True}) + user_tools = list(user_tools) + tools_by_id = {str(tool["_id"]): tool for tool in user_tools} + return tools_by_id + + def _build_tool_parameters(self, action): + params = {"type": "object", "properties": {}, "required": []} + for param_type in ["query_params", "headers", "body", "parameters"]: + if param_type in action and action[param_type].get("properties"): + for k, v in action[param_type]["properties"].items(): + if v.get("filled_by_llm", True): + params["properties"][k] = { + key: value + for key, value in v.items() + if key != "filled_by_llm" and key != "value" + } + + params["required"].append(k) + return params + + def _prepare_tools(self, tools_dict): + self.tools = [ + { + "type": "function", + "function": { + "name": f"{action['name']}_{tool_id}", + "description": action["description"], + "parameters": self._build_tool_parameters(action), + }, + } + for tool_id, tool in tools_dict.items() + if ( + (tool["name"] == "api_tool" and "actions" in tool.get("config", {})) + or (tool["name"] != "api_tool" and "actions" in tool) + ) + for action in ( + tool["config"]["actions"].values() + if tool["name"] == "api_tool" + else tool["actions"] + ) + if action.get("active", True) + ] + + def _execute_tool_action(self, tools_dict, call): + parser = ToolActionParser(self.llm.__class__.__name__) + tool_id, action_name, call_args = parser.parse_args(call) + + tool_data = tools_dict[tool_id] + action_data = ( + tool_data["config"]["actions"][action_name] + if tool_data["name"] == "api_tool" + else next( + action + for action in tool_data["actions"] + if action["name"] == action_name + ) + ) + + query_params, headers, body, parameters = {}, {}, {}, {} + param_types = { + "query_params": query_params, + "headers": headers, + "body": body, + "parameters": parameters, + } + + for param_type, target_dict in param_types.items(): + if param_type in action_data and action_data[param_type].get("properties"): + for param, details in action_data[param_type]["properties"].items(): + if param not in call_args and "value" in details: + target_dict[param] = details["value"] + + for param, value in call_args.items(): + for param_type, target_dict in param_types.items(): + if param_type in action_data and param in action_data[param_type].get( + "properties", {} + ): + target_dict[param] = value + + tm = ToolManager(config={}) + tool = tm.load_tool( + tool_data["name"], + tool_config=( + { + "url": tool_data["config"]["actions"][action_name]["url"], + "method": tool_data["config"]["actions"][action_name]["method"], + "headers": headers, + "query_params": query_params, + } + if tool_data["name"] == "api_tool" + else tool_data["config"] + ), + ) + if tool_data["name"] == "api_tool": + print( + f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}" + ) + result = tool.execute_action(action_name, **body) + else: + print(f"Executing tool: {action_name} with args: {call_args}") + result = tool.execute_action(action_name, **parameters) + call_id = getattr(call, "id", None) + + tool_call_data = { + "tool_name": tool_data["name"], + "call_id": call_id if call_id is not None else "None", + "action_name": f"{action_name}_{tool_id}", + "arguments": call_args, + "result": result, + } + self.tool_calls.append(tool_call_data) + + return result, call_id diff --git a/application/tools/implementations/api_tool.py b/application/tools/implementations/api_tool.py index 5d0fec70..69026dc1 100644 --- a/application/tools/implementations/api_tool.py +++ b/application/tools/implementations/api_tool.py @@ -31,10 +31,27 @@ class APITool(Tool): print(f"Making API call: {method} {url} with body: {body}") response = requests.request(method, url, headers=headers, data=body) response.raise_for_status() - try: - data = response.json() - except ValueError: + + content_type = response.headers.get( + "Content-Type", "application/json" + ).lower() + if "application/json" in content_type: + try: + data = response.json() + except json.JSONDecodeError as e: + print(f"Error decoding JSON: {e}. Raw response: {response.text}") + return { + "status_code": response.status_code, + "message": f"API call returned invalid JSON. Error: {e}", + "data": response.text, + } + elif "text/" in content_type or "application/xml" in content_type: + data = response.text + elif not response.content: data = None + else: + print(f"Unsupported content type: {content_type}") + data = response.content return { "status_code": response.status_code, From 1f0b779c64cd836b0484622ffe2d2e8df1b4bc40 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Tue, 25 Feb 2025 09:03:45 +0530 Subject: [PATCH 2/7] refactor: folder restructure for agent based workflow --- .../{tools/base_agent.py => agents/base.py} | 7 ++-- .../agent.py => agents/classic_agent.py} | 3 +- application/{tools => agents}/llm_handler.py | 0 .../tools}/api_tool.py | 2 +- application/{ => agents}/tools/base.py | 0 .../tools}/cryptoprice.py | 2 +- .../tools}/telegram.py | 2 +- .../{ => agents}/tools/tool_action_parser.py | 0 .../{ => agents}/tools/tool_manager.py | 12 +++---- application/api/answer/routes.py | 4 +-- application/api/user/routes.py | 35 ++++++++++--------- application/retriever/classic_rag.py | 6 ++-- 12 files changed, 38 insertions(+), 35 deletions(-) rename application/{tools/base_agent.py => agents/base.py} (96%) rename application/{tools/agent.py => agents/classic_agent.py} (98%) rename application/{tools => agents}/llm_handler.py (100%) rename application/{tools/implementations => agents/tools}/api_tool.py (98%) rename application/{ => agents}/tools/base.py (100%) rename application/{tools/implementations => agents/tools}/cryptoprice.py (98%) rename application/{tools/implementations => agents/tools}/telegram.py (98%) rename application/{ => agents}/tools/tool_action_parser.py (100%) rename application/{ => agents}/tools/tool_manager.py (79%) diff --git a/application/tools/base_agent.py b/application/agents/base.py similarity index 96% rename from application/tools/base_agent.py rename to application/agents/base.py index bc8f61a4..93dcb4e2 100644 --- a/application/tools/base_agent.py +++ b/application/agents/base.py @@ -1,10 +1,11 @@ from typing import Dict, Generator +from application.agents.llm_handler import get_llm_handler +from application.agents.tools.tool_action_parser import ToolActionParser +from application.agents.tools.tool_manager import ToolManager + from application.core.mongo_db import MongoDB from application.llm.llm_creator import LLMCreator -from application.tools.llm_handler import get_llm_handler -from application.tools.tool_action_parser import ToolActionParser -from application.tools.tool_manager import ToolManager class BaseAgent: diff --git a/application/tools/agent.py b/application/agents/classic_agent.py similarity index 98% rename from application/tools/agent.py rename to application/agents/classic_agent.py index 95895e70..c7846a04 100644 --- a/application/tools/agent.py +++ b/application/agents/classic_agent.py @@ -1,8 +1,9 @@ import uuid from typing import Dict, Generator +from application.agents.base import BaseAgent + from application.retriever.base import BaseRetriever -from application.tools.base_agent import BaseAgent class ClassicAgent(BaseAgent): diff --git a/application/tools/llm_handler.py b/application/agents/llm_handler.py similarity index 100% rename from application/tools/llm_handler.py rename to application/agents/llm_handler.py diff --git a/application/tools/implementations/api_tool.py b/application/agents/tools/api_tool.py similarity index 98% rename from application/tools/implementations/api_tool.py rename to application/agents/tools/api_tool.py index 69026dc1..06d5fb7a 100644 --- a/application/tools/implementations/api_tool.py +++ b/application/agents/tools/api_tool.py @@ -1,7 +1,7 @@ import json import requests -from application.tools.base import Tool +from application.agents.tools.base import Tool class APITool(Tool): diff --git a/application/tools/base.py b/application/agents/tools/base.py similarity index 100% rename from application/tools/base.py rename to application/agents/tools/base.py diff --git a/application/tools/implementations/cryptoprice.py b/application/agents/tools/cryptoprice.py similarity index 98% rename from application/tools/implementations/cryptoprice.py rename to application/agents/tools/cryptoprice.py index 7b88c866..80d0c2fc 100644 --- a/application/tools/implementations/cryptoprice.py +++ b/application/agents/tools/cryptoprice.py @@ -1,5 +1,5 @@ import requests -from application.tools.base import Tool +from application.agents.tools.base import Tool class CryptoPriceTool(Tool): diff --git a/application/tools/implementations/telegram.py b/application/agents/tools/telegram.py similarity index 98% rename from application/tools/implementations/telegram.py rename to application/agents/tools/telegram.py index a32bbe88..06350ae9 100644 --- a/application/tools/implementations/telegram.py +++ b/application/agents/tools/telegram.py @@ -1,5 +1,5 @@ import requests -from application.tools.base import Tool +from application.agents.tools.base import Tool class TelegramTool(Tool): diff --git a/application/tools/tool_action_parser.py b/application/agents/tools/tool_action_parser.py similarity index 100% rename from application/tools/tool_action_parser.py rename to application/agents/tools/tool_action_parser.py diff --git a/application/tools/tool_manager.py b/application/agents/tools/tool_manager.py similarity index 79% rename from application/tools/tool_manager.py rename to application/agents/tools/tool_manager.py index 3e0766cf..ad71db28 100644 --- a/application/tools/tool_manager.py +++ b/application/agents/tools/tool_manager.py @@ -3,7 +3,7 @@ import inspect import os import pkgutil -from application.tools.base import Tool +from application.agents.tools.base import Tool class ToolManager: @@ -13,13 +13,11 @@ class ToolManager: self.load_tools() def load_tools(self): - tools_dir = os.path.join(os.path.dirname(__file__), "implementations") + tools_dir = os.path.join(os.path.dirname(__file__)) for finder, name, ispkg in pkgutil.iter_modules([tools_dir]): if name == "base" or name.startswith("__"): continue - module = importlib.import_module( - f"application.tools.implementations.{name}" - ) + module = importlib.import_module(f"application.agents.tools.{name}") for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: tool_config = self.config.get(name, {}) @@ -27,9 +25,7 @@ class ToolManager: def load_tool(self, tool_name, tool_config): self.config[tool_name] = tool_config - module = importlib.import_module( - f"application.tools.implementations.{tool_name}" - ) + module = importlib.import_module(f"application.agents.tools.{tool_name}") for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: return obj(tool_config) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index f83f0c7e..d21c256e 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -1,15 +1,16 @@ import asyncio import datetime import json +import logging import os import traceback -import logging from bson.dbref import DBRef from bson.objectid import ObjectId from flask import Blueprint, make_response, request, Response from flask_restx import fields, Namespace, Resource +from application.agents.classic_agent import ClassicAgent from application.core.mongo_db import MongoDB from application.core.settings import settings @@ -17,7 +18,6 @@ from application.error import bad_request from application.extensions import api from application.llm.llm_creator import LLMCreator from application.retriever.retriever_creator import RetrieverCreator -from application.tools.agent import ClassicAgent from application.utils import check_required_fields, limit_chat_history logger = logging.getLogger(__name__) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index f71ab3dc..dced8a82 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1,9 +1,9 @@ import datetime +import json import math import os import shutil import uuid -import json from bson.binary import Binary, UuidRepresentation from bson.dbref import DBRef @@ -12,12 +12,13 @@ from flask import Blueprint, current_app, jsonify, make_response, redirect, requ from flask_restx import fields, inputs, Namespace, Resource from werkzeug.utils import secure_filename +from application.agents.tools.tool_manager import ToolManager + from application.api.user.tasks import ingest, ingest_remote from application.core.mongo_db import MongoDB from application.core.settings import settings from application.extensions import api -from application.tools.tool_manager import ToolManager from application.tts.google_tts import GoogleTTS from application.utils import check_required_fields, validate_function_name from application.vectorstore.vector_creator import VectorCreator @@ -429,22 +430,21 @@ class UploadRemote(Resource): return missing_fields try: - config = json.loads(data["data"]) - source_data = None + config = json.loads(data["data"]) + source_data = None - if data["source"] == "github": + if data["source"] == "github": source_data = config.get("repo_url") - elif data["source"] in ["crawler", "url"]: + elif data["source"] in ["crawler", "url"]: source_data = config.get("url") - elif data["source"] == "reddit": - source_data = config + elif data["source"] == "reddit": + source_data = config - - task = ingest_remote.delay( + task = ingest_remote.delay( source_data=source_data, job_name=data["name"], user=data["user"], - loader=data["source"] + loader=data["source"], ) except Exception as err: current_app.logger.error(f"Error uploading remote source: {err}") @@ -1936,11 +1936,14 @@ class UpdateTool(Resource): for action_name in list(data["config"]["actions"].keys()): if not validate_function_name(action_name): return make_response( - jsonify({ - "success": False, - "message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.", - "param": "tools[].function.name" - }), 400 + jsonify( + { + "success": False, + "message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.", + "param": "tools[].function.name", + } + ), + 400, ) update_data["config"] = data["config"] if "status" in data: diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 8a1406f4..5c74878c 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -51,10 +51,12 @@ class ClassicRAG(BaseRetriever): Rephrase the following user question to be a standalone search query that captures all relevant context from the conversation: - {self.original_question} """ - messages = [{"role": "system", "content": prompt}] + messages = [ + {"role": "system", "content": prompt}, + {"role": "user", "content": self.original_question}, + ] try: rephrased_query = self.llm.gen(model=self.gpt_model, messages=messages) From c6ce4d9374f6d81091f8ca5ea47d244e8f4f8490 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Thu, 27 Feb 2025 19:14:10 +0530 Subject: [PATCH 3/7] feat: logging stacks --- application/agents/agent_creator.py | 14 +++ application/agents/base.py | 5 +- application/agents/classic_agent.py | 50 +++++++-- application/agents/llm_handler.py | 10 ++ application/api/answer/routes.py | 16 +-- application/core/settings.py | 1 + application/logging.py | 151 +++++++++++++++++++++++++++ application/retriever/classic_rag.py | 6 +- application/usage.py | 16 ++- 9 files changed, 246 insertions(+), 23 deletions(-) create mode 100644 application/agents/agent_creator.py create mode 100644 application/logging.py diff --git a/application/agents/agent_creator.py b/application/agents/agent_creator.py new file mode 100644 index 00000000..a76d9faf --- /dev/null +++ b/application/agents/agent_creator.py @@ -0,0 +1,14 @@ +from application.agents.classic_agent import ClassicAgent + + +class AgentCreator: + agents = { + "classic": ClassicAgent, + } + + @classmethod + def create_agent(cls, type, *args, **kwargs): + agent_class = cls.agents.get(type.lower()) + if not agent_class: + raise ValueError(f"No agent class found for type {type}") + return agent_class(*args, **kwargs) diff --git a/application/agents/base.py b/application/agents/base.py index 93dcb4e2..7e36c991 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -9,7 +9,8 @@ from application.llm.llm_creator import LLMCreator class BaseAgent: - def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): + def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None): + self.endpoint = endpoint self.llm = LLMCreator.create_llm( llm_name, api_key=api_key, user_api_key=user_api_key ) @@ -19,7 +20,7 @@ class BaseAgent: self.tool_config = {} self.tool_calls = [] - def gen(self, query: str) -> Generator[Dict, None, None]: + def gen(self, *args, **kwargs) -> Generator[Dict, None, None]: raise NotImplementedError('Method "gen" must be implemented in the child class') def _get_user_tools(self, user="local"): diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index c7846a04..4e64442d 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -2,6 +2,7 @@ import uuid from typing import Dict, Generator from application.agents.base import BaseAgent +from application.logging import build_stack_data, log_activity, LogContext from application.retriever.base import BaseRetriever @@ -9,6 +10,7 @@ from application.retriever.base import BaseRetriever class ClassicAgent(BaseAgent): def __init__( self, + endpoint, llm_name, gpt_model, api_key, @@ -16,13 +18,21 @@ class ClassicAgent(BaseAgent): prompt="", chat_history=None, ): - super().__init__(llm_name, gpt_model, api_key, user_api_key) + super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key) self.prompt = prompt self.chat_history = chat_history if chat_history is not None else [] - def gen(self, query: str, retriever: BaseRetriever) -> Generator[Dict, None, None]: + @log_activity() + def gen( + self, query: str, retriever: BaseRetriever, log_context: LogContext = None + ) -> Generator[Dict, None, None]: + yield from self._gen_inner(query, retriever, log_context) + + def _gen_inner( + self, query: str, retriever: BaseRetriever, log_context: LogContext + ) -> Generator[Dict, None, None]: + retrieved_data = self._retriever_search(retriever, query, log_context) - retrieved_data = retriever.search(query) docs_together = "\n".join([doc["text"] for doc in retrieved_data]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) messages_combine = [{"role": "system", "content": p_chat_combine}] @@ -66,9 +76,7 @@ class ClassicAgent(BaseAgent): tools_dict = self._get_user_tools() self._prepare_tools(tools_dict) - resp = self.llm.gen( - model=self.gpt_model, messages=messages_combine, tools=self.tools - ) + resp = self._llm_gen(messages_combine, log_context) if isinstance(resp, str): yield {"answer": resp} @@ -81,9 +89,7 @@ class ClassicAgent(BaseAgent): yield {"answer": resp.message.content} return - resp = self.llm_handler.handle_response( - self, resp, tools_dict, messages_combine - ) + resp = self._llm_handler(resp, tools_dict, messages_combine, log_context) if isinstance(resp, str): yield {"answer": resp} @@ -101,3 +107,29 @@ class ClassicAgent(BaseAgent): yield {"answer": line} yield {"tool_calls": self.tool_calls.copy()} + + def _retriever_search(self, retriever, query, log_context): + retrieved_data = retriever.search(query) + if log_context: + data = build_stack_data(retriever, exclude_attributes=["llm"]) + log_context.stacks.append({"component": "retriever", "data": data}) + return retrieved_data + + def _llm_gen(self, messages_combine, log_context): + resp = self.llm.gen( + model=self.gpt_model, messages=messages_combine, tools=self.tools + ) + if log_context: + data = build_stack_data(self.llm) + log_context.stacks.append({"component": "llm", "data": data}) + return resp + + def _llm_handler(self, resp, tools_dict, messages_combine, log_context): + resp = self.llm_handler.handle_response( + self, resp, tools_dict, messages_combine + ) + if log_context: + data = build_stack_data(self.llm_handler) + log_context.stacks.append({"component": "llm_handler", "data": data}) + + return resp diff --git a/application/agents/llm_handler.py b/application/agents/llm_handler.py index 334d2c4c..adf240c3 100644 --- a/application/agents/llm_handler.py +++ b/application/agents/llm_handler.py @@ -1,8 +1,14 @@ import json from abc import ABC, abstractmethod +from application.logging import build_stack_data + class LLMHandler(ABC): + def __init__(self): + self.llm_calls = [] + self.tool_calls = [] + @abstractmethod def handle_response(self, agent, resp, tools_dict, messages, **kwargs): pass @@ -21,6 +27,7 @@ class OpenAILLMHandler(LLMHandler): tool_calls = resp.message.tool_calls for call in tool_calls: try: + self.tool_calls.append(call) tool_response, call_id = agent._execute_tool_action( tools_dict, call ) @@ -57,6 +64,7 @@ class OpenAILLMHandler(LLMHandler): resp = agent.llm.gen( model=agent.gpt_model, messages=messages, tools=agent.tools ) + self.llm_calls.append(build_stack_data(agent.llm)) return resp @@ -68,11 +76,13 @@ class GoogleLLMHandler(LLMHandler): response = agent.llm.gen( model=agent.gpt_model, messages=messages, tools=agent.tools ) + self.llm_calls.append(build_stack_data(agent.llm)) if response.candidates and response.candidates[0].content.parts: tool_call_found = False for part in response.candidates[0].content.parts: if part.function_call: tool_call_found = True + self.tool_calls.append(part.function_call) tool_response, call_id = agent._execute_tool_action( tools_dict, part.function_call ) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index d21c256e..b249f058 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -10,7 +10,7 @@ from bson.objectid import ObjectId from flask import Blueprint, make_response, request, Response from flask_restx import fields, Namespace, Resource -from application.agents.classic_agent import ClassicAgent +from application.agents.agent_creator import AgentCreator from application.core.mongo_db import MongoDB from application.core.settings import settings @@ -213,7 +213,7 @@ def complete_stream( response_full = "" source_log_docs = [] tool_calls = [] - answer = agent.gen(question, retriever) + answer = agent.gen(query=question, retriever=retriever) sources = retriever.search(question) for source in sources: if "text" in source: @@ -368,14 +368,18 @@ class Stream(Resource): prompt = get_prompt(prompt_id) if "isNoneDoc" in data and data["isNoneDoc"] is True: chunks = 0 - agent = ClassicAgent( - settings.LLM_NAME, - gpt_model, - settings.API_KEY, + + agent = AgentCreator.create_agent( + settings.AGENT_NAME, + endpoint="stream", + llm_name=settings.LLM_NAME, + gpt_model=gpt_model, + api_key=settings.API_KEY, user_api_key=user_api_key, prompt=prompt, chat_history=history, ) + retriever = RetrieverCreator.create_retriever( retriever_name, source=source, diff --git a/application/core/settings.py b/application/core/settings.py index 5842da33..04d7bbea 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -32,6 +32,7 @@ class Settings(BaseSettings): "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" ) RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search + AGENT_NAME: str = "classic" # LLM Cache CACHE_REDIS_URL: str = "redis://localhost:6379/2" diff --git a/application/logging.py b/application/logging.py new file mode 100644 index 00000000..1dd0d557 --- /dev/null +++ b/application/logging.py @@ -0,0 +1,151 @@ +import datetime +import functools +import inspect + +import logging +import uuid +from typing import Any, Callable, Dict, Generator, List + +from application.core.mongo_db import MongoDB + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +class LogContext: + def __init__(self, endpoint, activity_id, user, api_key, query): + self.endpoint = endpoint + self.activity_id = activity_id + self.user = user + self.api_key = api_key + self.query = query + self.stacks = [] + + +def build_stack_data( + obj: Any, + include_attributes: List[str] = None, + exclude_attributes: List[str] = None, + custom_data: Dict = None, +) -> Dict: + data = {} + if include_attributes is None: + include_attributes = [] + for name, value in inspect.getmembers(obj): + if ( + not name.startswith("_") + and not inspect.ismethod(value) + and not inspect.isfunction(value) + ): + include_attributes.append(name) + for attr_name in include_attributes: + if exclude_attributes and attr_name in exclude_attributes: + continue + try: + attr_value = getattr(obj, attr_name) + if attr_value is not None: + if isinstance(attr_value, (int, float, str, bool)): + data[attr_name] = attr_value + elif isinstance(attr_value, list): + if all(isinstance(item, dict) for item in attr_value): + data[attr_name] = attr_value + elif all(hasattr(item, "__dict__") for item in attr_value): + data[attr_name] = [item.__dict__ for item in attr_value] + else: + data[attr_name] = [str(item) for item in attr_value] + elif isinstance(attr_value, dict): + data[attr_name] = {k: str(v) for k, v in attr_value.items()} + else: + data[attr_name] = str(attr_value) + except AttributeError: + pass + if custom_data: + data.update(custom_data) + return data + + +def log_activity() -> Callable: + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + activity_id = str(uuid.uuid4()) + data = build_stack_data(args[0]) + endpoint = data.get("endpoint", "") + user = data.get("user", "local") + api_key = data.get("user_api_key", "") + query = kwargs.get("query", getattr(args[0], "query", "")) + + context = LogContext(endpoint, activity_id, user, api_key, query) + kwargs["log_context"] = context + + logging.info( + f"Starting activity: {endpoint} - {activity_id} - User: {user}" + ) + + generator = func(*args, **kwargs) + yield from _consume_and_log(generator, context) + + return wrapper + + return decorator + + +def _consume_and_log(generator: Generator, context: "LogContext"): + try: + for item in generator: + yield item + except Exception as e: + logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}") + context.stacks.append({"component": "error", "data": {"message": str(e)}}) + _log_to_mongodb( + endpoint=context.endpoint, + activity_id=context.activity_id, + user=context.user, + api_key=context.api_key, + query=context.query, + stacks=context.stacks, + level="error", + ) + raise + finally: + _log_to_mongodb( + endpoint=context.endpoint, + activity_id=context.activity_id, + user=context.user, + api_key=context.api_key, + query=context.query, + stacks=context.stacks, + level="info", + ) + + +def _log_to_mongodb( + endpoint: str, + activity_id: str, + user: str, + api_key: str, + query: str, + stacks: List[Dict], + level: str, +) -> None: + try: + mongo = MongoDB.get_client() + db = mongo["docsgpt"] + user_logs_collection = db["stack_logs"] + + log_entry = { + "endpoint": endpoint, + "id": activity_id, + "level": level, + "user": user, + "api_key": api_key, + "query": query, + "stacks": stacks, + "timestamp": datetime.datetime.now(datetime.timezone.utc), + } + user_logs_collection.insert_one(log_entry) + logging.debug(f"Logged activity to MongoDB: {activity_id}") + + except Exception as e: + logging.error(f"Failed to log to MongoDB: {e}") diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 5c74878c..03f17f44 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -43,7 +43,11 @@ class ClassicRAG(BaseRetriever): self.vectorstore = source["active_docs"] if "active_docs" in source else None def _rephrase_query(self): - if not self.chat_history or self.chat_history == []: + if ( + not self.original_question + or not self.chat_history + or self.chat_history == [] + ): return self.original_question prompt = f"""Given the following conversation history: diff --git a/application/usage.py b/application/usage.py index fe4cd50e..a18a3848 100644 --- a/application/usage.py +++ b/application/usage.py @@ -1,7 +1,8 @@ import sys from datetime import datetime + from application.core.mongo_db import MongoDB -from application.utils import num_tokens_from_string, num_tokens_from_object_or_list +from application.utils import num_tokens_from_object_or_list, num_tokens_from_string mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -24,13 +25,16 @@ def gen_token_usage(func): def wrapper(self, model, messages, stream, tools, **kwargs): for message in messages: if message["content"]: - self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) + self.token_usage["prompt_tokens"] += num_tokens_from_string( + message["content"] + ) result = func(self, model, messages, stream, tools, **kwargs) - # check if result is a string if isinstance(result, str): self.token_usage["generated_tokens"] += num_tokens_from_string(result) else: - self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result) + self.token_usage["generated_tokens"] += num_tokens_from_object_or_list( + result + ) update_token_usage(self.user_api_key, self.token_usage) return result @@ -40,7 +44,9 @@ def gen_token_usage(func): def stream_token_usage(func): def wrapper(self, model, messages, stream, tools, **kwargs): for message in messages: - self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) + self.token_usage["prompt_tokens"] += num_tokens_from_string( + message["content"] + ) batch = [] result = func(self, model, messages, stream, tools, **kwargs) for r in result: From f88c34a0bef07f1c16c405c23168c84271dcdfaa Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Wed, 5 Mar 2025 09:02:55 +0530 Subject: [PATCH 4/7] feat: streaming responses with function call --- application/agents/classic_agent.py | 6 +- application/agents/llm_handler.py | 266 +++++++++++++----- application/agents/tools/cryptoprice.py | 1 - .../agents/tools/tool_action_parser.py | 17 +- application/llm/google_ai.py | 10 +- application/llm/openai.py | 17 +- 6 files changed, 237 insertions(+), 80 deletions(-) diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index 4e64442d..79d9e37f 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -104,7 +104,8 @@ class ClassicAgent(BaseAgent): model=self.gpt_model, messages=messages_combine, tools=self.tools ) for line in completion: - yield {"answer": line} + if isinstance(line, str): + yield {"answer": line} yield {"tool_calls": self.tool_calls.copy()} @@ -116,7 +117,7 @@ class ClassicAgent(BaseAgent): return retrieved_data def _llm_gen(self, messages_combine, log_context): - resp = self.llm.gen( + resp = self.llm.gen_stream( model=self.gpt_model, messages=messages_combine, tools=self.tools ) if log_context: @@ -131,5 +132,4 @@ class ClassicAgent(BaseAgent): if log_context: data = build_stack_data(self.llm_handler) log_context.stacks.append({"component": "llm_handler", "data": data}) - return resp diff --git a/application/agents/llm_handler.py b/application/agents/llm_handler.py index adf240c3..9267dc53 100644 --- a/application/agents/llm_handler.py +++ b/application/agents/llm_handler.py @@ -15,84 +15,221 @@ class LLMHandler(ABC): class OpenAILLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages): - while resp.finish_reason == "tool_calls": - message = json.loads(resp.model_dump_json())["message"] - keys_to_remove = {"audio", "function_call", "refusal"} - filtered_data = { - k: v for k, v in message.items() if k not in keys_to_remove - } - messages.append(filtered_data) + def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True): + if not stream: + while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls": + message = json.loads(resp.model_dump_json())["message"] + keys_to_remove = {"audio", "function_call", "refusal"} + filtered_data = { + k: v for k, v in message.items() if k not in keys_to_remove + } + messages.append(filtered_data) - tool_calls = resp.message.tool_calls - for call in tool_calls: - try: - self.tool_calls.append(call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, call - ) - function_call_dict = { - "function_call": { - "name": call.function.name, - "args": call.function.arguments, - "call_id": call_id, + tool_calls = resp.message.tool_calls + for call in tool_calls: + try: + self.tool_calls.append(call) + tool_response, call_id = agent._execute_tool_action( + tools_dict, call + ) + function_call_dict = { + "function_call": { + "name": call.function.name, + "args": call.function.arguments, + "call_id": call_id, + } } - } - function_response_dict = { - "function_response": { - "name": call.function.name, - "response": {"result": tool_response}, - "call_id": call_id, + function_response_dict = { + "function_response": { + "name": call.function.name, + "response": {"result": tool_response}, + "call_id": call_id, + } } - } - messages.append( - {"role": "assistant", "content": [function_call_dict]} - ) - messages.append( - {"role": "tool", "content": [function_response_dict]} - ) + messages.append( + {"role": "assistant", "content": [function_call_dict]} + ) + messages.append( + {"role": "tool", "content": [function_response_dict]} + ) - except Exception as e: - messages.append( - { - "role": "tool", - "content": f"Error executing tool: {str(e)}", - "tool_call_id": call_id, - } - ) - resp = agent.llm.gen( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - return resp + except Exception as e: + messages.append( + { + "role": "tool", + "content": f"Error executing tool: {str(e)}", + "tool_call_id": call_id, + } + ) + resp = agent.llm.gen_stream( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + return resp + + else: + while True: + tool_calls = {} + for chunk in resp: + if isinstance(chunk, str): + return + else: + chunk_delta = chunk.delta + + if ( + hasattr(chunk_delta, "tool_calls") + and chunk_delta.tool_calls is not None + ): + for tool_call in chunk_delta.tool_calls: + index = tool_call.index + if index not in tool_calls: + tool_calls[index] = { + "id": "", + "function": {"name": "", "arguments": ""}, + } + + current = tool_calls[index] + if tool_call.id: + current["id"] = tool_call.id + if tool_call.function.name: + current["function"][ + "name" + ] = tool_call.function.name + if tool_call.function.arguments: + current["function"][ + "arguments" + ] += tool_call.function.arguments + tool_calls[index] = current + + if ( + hasattr(chunk, "finish_reason") + and chunk.finish_reason == "tool_calls" + ): + for index in sorted(tool_calls.keys()): + call = tool_calls[index] + try: + self.tool_calls.append(call) + tool_response, call_id = agent._execute_tool_action( + tools_dict, call + ) + + function_call_dict = { + "function_call": { + "name": call["function"]["name"], + "args": call["function"]["arguments"], + "call_id": call["id"], + } + } + function_response_dict = { + "function_response": { + "name": call["function"]["name"], + "response": {"result": tool_response}, + "call_id": call["id"], + } + } + + messages.append( + { + "role": "assistant", + "content": [function_call_dict], + } + ) + messages.append( + { + "role": "tool", + "content": [function_response_dict], + } + ) + + except Exception as e: + messages.append( + { + "role": "assistant", + "content": f"Error executing tool: {str(e)}", + } + ) + tool_calls = {} + + if ( + hasattr(chunk, "finish_reason") + and chunk.finish_reason == "stop" + ): + return + + resp = agent.llm.gen_stream( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) class GoogleLLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages): + def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True): from google.genai import types while True: - response = agent.llm.gen( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - if response.candidates and response.candidates[0].content.parts: + if not stream: + response = agent.llm.gen( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + if response.candidates and response.candidates[0].content.parts: + tool_call_found = False + for part in response.candidates[0].content.parts: + if part.function_call: + tool_call_found = True + self.tool_calls.append(part.function_call) + tool_response, call_id = agent._execute_tool_action( + tools_dict, part.function_call + ) + function_response_part = types.Part.from_function_response( + name=part.function_call.name, + response={"result": tool_response}, + ) + + messages.append( + {"role": "model", "content": [part.to_json_dict()]} + ) + messages.append( + { + "role": "tool", + "content": [function_response_part.to_json_dict()], + } + ) + + if ( + not tool_call_found + and response.candidates[0].content.parts + and response.candidates[0].content.parts[0].text + ): + return response.candidates[0].content.parts[0].text + elif not tool_call_found: + return response.candidates[0].content.parts + + else: + return response + + else: + response = agent.llm.gen_stream( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + tool_call_found = False - for part in response.candidates[0].content.parts: - if part.function_call: + for result in response: + if hasattr(result, "function_call"): tool_call_found = True - self.tool_calls.append(part.function_call) + self.tool_calls.append(result.function_call) tool_response, call_id = agent._execute_tool_action( - tools_dict, part.function_call + tools_dict, result.function_call ) function_response_part = types.Part.from_function_response( - name=part.function_call.name, + name=result.function_call.name, response={"result": tool_response}, ) messages.append( - {"role": "model", "content": [part.to_json_dict()]} + {"role": "model", "content": [result.to_json_dict()]} ) messages.append( { @@ -101,17 +238,8 @@ class GoogleLLMHandler(LLMHandler): } ) - if ( - not tool_call_found - and response.candidates[0].content.parts - and response.candidates[0].content.parts[0].text - ): - return response.candidates[0].content.parts[0].text - elif not tool_call_found: - return response.candidates[0].content.parts - - else: - return response + if not tool_call_found: + return response def get_llm_handler(llm_type): diff --git a/application/agents/tools/cryptoprice.py b/application/agents/tools/cryptoprice.py index 80d0c2fc..c25c3d43 100644 --- a/application/agents/tools/cryptoprice.py +++ b/application/agents/tools/cryptoprice.py @@ -31,7 +31,6 @@ class CryptoPriceTool(Tool): response = requests.get(url) if response.status_code == 200: data = response.json() - # data will be like {"USD": } if the call is successful if currency.upper() in data: return { "status_code": response.status_code, diff --git a/application/agents/tools/tool_action_parser.py b/application/agents/tools/tool_action_parser.py index ac0a70c1..4d894d1a 100644 --- a/application/agents/tools/tool_action_parser.py +++ b/application/agents/tools/tool_action_parser.py @@ -14,9 +14,20 @@ class ToolActionParser: return parser(call) def _parse_openai_llm(self, call): - call_args = json.loads(call.function.arguments) - tool_id = call.function.name.split("_")[-1] - action_name = call.function.name.rsplit("_", 1)[0] + if isinstance(call, dict): + try: + call_args = json.loads(call["function"]["arguments"]) + tool_id = call["function"]["name"].split("_")[-1] + action_name = call["function"]["name"].rsplit("_", 1)[0] + except (KeyError, TypeError) as e: + return None, None, None + else: + try: + call_args = json.loads(call.function.arguments) + tool_id = call.function.name.split("_")[-1] + action_name = call.function.name.rsplit("_", 1)[0] + except (AttributeError, TypeError) as e: + return None, None, None return tool_id, action_name, call_args def _parse_google_llm(self, call): diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index 31943601..d52e26c8 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -152,7 +152,15 @@ class GoogleLLM(BaseLLM): config=config, ) for chunk in response: - if chunk.text is not None: + if hasattr(chunk, "candidates") and chunk.candidates: + for candidate in chunk.candidates: + if candidate.content and candidate.content.parts: + for part in candidate.content.parts: + if part.function_call: + yield part + elif part.text: + yield part.text + elif hasattr(chunk, "text"): yield chunk.text def _supports_tools(self): diff --git a/application/llm/openai.py b/application/llm/openai.py index b8f311b0..938de523 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -111,13 +111,24 @@ class OpenAILLM(BaseLLM): **kwargs, ): messages = self._clean_messages_openai(messages) - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) + if tools: + response = self.client.chat.completions.create( + model=model, + messages=messages, + stream=stream, + tools=tools, + **kwargs, + ) + else: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) for line in response: if line.choices[0].delta.content is not None: yield line.choices[0].delta.content + else: + yield line.choices[0] def _supports_tools(self): return True From 06edc261c0504b7defecaa466f139ac1e8278984 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 5 Mar 2025 16:09:13 -0500 Subject: [PATCH 5/7] fix: duplicates... --- application/api/answer/routes.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index fc00d47e..43511158 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -206,13 +206,6 @@ def get_prompt(prompt_id): return prompt def complete_stream( - question, - agent, - retriever, - conversation_id, - user_api_key, - isNoneDoc=False, - index=None, question, retriever, conversation_id, From 49a2b2ce6dd7b52db34a811c6c9e2a9eaf988733 Mon Sep 17 00:00:00 2001 From: Alex Date: Wed, 5 Mar 2025 16:11:06 -0500 Subject: [PATCH 6/7] fix: agent not forgotten --- application/api/answer/routes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 43511158..c8c9708f 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -207,6 +207,7 @@ def get_prompt(prompt_id): def complete_stream( question, + agent, retriever, conversation_id, user_api_key, From d4f53bf6bb036247484a07e4f1fe546ad2c28014 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 6 Mar 2025 14:31:46 +0000 Subject: [PATCH 7/7] fix: ruff check --- application/agents/tools/tool_action_parser.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/application/agents/tools/tool_action_parser.py b/application/agents/tools/tool_action_parser.py index 4d894d1a..c7da5a4c 100644 --- a/application/agents/tools/tool_action_parser.py +++ b/application/agents/tools/tool_action_parser.py @@ -1,4 +1,7 @@ import json +import logging + +logger = logging.getLogger(__name__) class ToolActionParser: @@ -20,6 +23,7 @@ class ToolActionParser: tool_id = call["function"]["name"].split("_")[-1] action_name = call["function"]["name"].rsplit("_", 1)[0] except (KeyError, TypeError) as e: + logger.error(f"Error parsing OpenAI LLM call: {e}") return None, None, None else: try: @@ -27,6 +31,7 @@ class ToolActionParser: tool_id = call.function.name.split("_")[-1] action_name = call.function.name.rsplit("_", 1)[0] except (AttributeError, TypeError) as e: + logger.error(f"Error parsing OpenAI LLM call: {e}") return None, None, None return tool_id, action_name, call_args