diff --git a/application/agents/agent_creator.py b/application/agents/agent_creator.py new file mode 100644 index 00000000..a76d9faf --- /dev/null +++ b/application/agents/agent_creator.py @@ -0,0 +1,14 @@ +from application.agents.classic_agent import ClassicAgent + + +class AgentCreator: + agents = { + "classic": ClassicAgent, + } + + @classmethod + def create_agent(cls, type, *args, **kwargs): + agent_class = cls.agents.get(type.lower()) + if not agent_class: + raise ValueError(f"No agent class found for type {type}") + return agent_class(*args, **kwargs) diff --git a/application/tools/agent.py b/application/agents/base.py similarity index 73% rename from application/tools/agent.py rename to application/agents/base.py index 10798862..7e36c991 100644 --- a/application/tools/agent.py +++ b/application/agents/base.py @@ -1,23 +1,28 @@ +from typing import Dict, Generator + +from application.agents.llm_handler import get_llm_handler +from application.agents.tools.tool_action_parser import ToolActionParser +from application.agents.tools.tool_manager import ToolManager + from application.core.mongo_db import MongoDB from application.llm.llm_creator import LLMCreator -from application.tools.llm_handler import get_llm_handler -from application.tools.tool_action_parser import ToolActionParser -from application.tools.tool_manager import ToolManager -class Agent: - def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): - # Initialize the LLM with the provided parameters +class BaseAgent: + def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None): + self.endpoint = endpoint self.llm = LLMCreator.create_llm( llm_name, api_key=api_key, user_api_key=user_api_key ) self.llm_handler = get_llm_handler(llm_name) self.gpt_model = gpt_model - # Static tool configuration (to be replaced later) self.tools = [] self.tool_config = {} self.tool_calls = [] + def gen(self, *args, **kwargs) -> Generator[Dict, None, None]: + raise NotImplementedError('Method 
"gen" must be implemented in the child class') + def _get_user_tools(self, user="local"): mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -135,50 +140,3 @@ class Agent: self.tool_calls.append(tool_call_data) return result, call_id - - def _simple_tool_agent(self, messages): - tools_dict = self._get_user_tools() - self._prepare_tools(tools_dict) - - resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) - - if isinstance(resp, str): - yield resp - return - if ( - hasattr(resp, "message") - and hasattr(resp.message, "content") - and resp.message.content is not None - ): - yield resp.message.content - return - - resp = self.llm_handler.handle_response(self, resp, tools_dict, messages) - - if isinstance(resp, str): - yield resp - elif ( - hasattr(resp, "message") - and hasattr(resp.message, "content") - and resp.message.content is not None - ): - yield resp.message.content - else: - completion = self.llm.gen_stream( - model=self.gpt_model, messages=messages, tools=self.tools - ) - for line in completion: - yield line - - return - - def gen(self, messages): - self.tool_calls = [] - if self.llm.supports_tools(): - resp = self._simple_tool_agent(messages) - for line in resp: - yield line - else: - resp = self.llm.gen_stream(model=self.gpt_model, messages=messages) - for line in resp: - yield line diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py new file mode 100644 index 00000000..79d9e37f --- /dev/null +++ b/application/agents/classic_agent.py @@ -0,0 +1,135 @@ +import uuid +from typing import Dict, Generator + +from application.agents.base import BaseAgent +from application.logging import build_stack_data, log_activity, LogContext + +from application.retriever.base import BaseRetriever + + +class ClassicAgent(BaseAgent): + def __init__( + self, + endpoint, + llm_name, + gpt_model, + api_key, + user_api_key=None, + prompt="", + chat_history=None, + ): + super().__init__(endpoint, llm_name, gpt_model, 
api_key, user_api_key) + self.prompt = prompt + self.chat_history = chat_history if chat_history is not None else [] + + @log_activity() + def gen( + self, query: str, retriever: BaseRetriever, log_context: LogContext = None + ) -> Generator[Dict, None, None]: + yield from self._gen_inner(query, retriever, log_context) + + def _gen_inner( + self, query: str, retriever: BaseRetriever, log_context: LogContext + ) -> Generator[Dict, None, None]: + retrieved_data = self._retriever_search(retriever, query, log_context) + + docs_together = "\n".join([doc["text"] for doc in retrieved_data]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + + if len(self.chat_history) > 0: + for i in self.chat_history: + if "prompt" in i and "response" in i: + messages_combine.append({"role": "user", "content": i["prompt"]}) + messages_combine.append( + {"role": "assistant", "content": i["response"]} + ) + if "tool_calls" in i: + for tool_call in i["tool_calls"]: + call_id = tool_call.get("call_id") + if call_id is None or call_id == "None": + call_id = str(uuid.uuid4()) + + function_call_dict = { + "function_call": { + "name": tool_call.get("action_name"), + "args": tool_call.get("arguments"), + "call_id": call_id, + } + } + function_response_dict = { + "function_response": { + "name": tool_call.get("action_name"), + "response": {"result": tool_call.get("result")}, + "call_id": call_id, + } + } + + messages_combine.append( + {"role": "assistant", "content": [function_call_dict]} + ) + messages_combine.append( + {"role": "tool", "content": [function_response_dict]} + ) + messages_combine.append({"role": "user", "content": query}) + + tools_dict = self._get_user_tools() + self._prepare_tools(tools_dict) + + resp = self._llm_gen(messages_combine, log_context) + + if isinstance(resp, str): + yield {"answer": resp} + return + if ( + hasattr(resp, "message") + and hasattr(resp.message, "content") + and 
resp.message.content is not None + ): + yield {"answer": resp.message.content} + return + + resp = self._llm_handler(resp, tools_dict, messages_combine, log_context) + + if isinstance(resp, str): + yield {"answer": resp} + elif ( + hasattr(resp, "message") + and hasattr(resp.message, "content") + and resp.message.content is not None + ): + yield {"answer": resp.message.content} + else: + completion = self.llm.gen_stream( + model=self.gpt_model, messages=messages_combine, tools=self.tools + ) + for line in completion: + if isinstance(line, str): + yield {"answer": line} + + yield {"tool_calls": self.tool_calls.copy()} + + def _retriever_search(self, retriever, query, log_context): + retrieved_data = retriever.search(query) + if log_context: + data = build_stack_data(retriever, exclude_attributes=["llm"]) + log_context.stacks.append({"component": "retriever", "data": data}) + return retrieved_data + + def _llm_gen(self, messages_combine, log_context): + resp = self.llm.gen_stream( + model=self.gpt_model, messages=messages_combine, tools=self.tools + ) + if log_context: + data = build_stack_data(self.llm) + log_context.stacks.append({"component": "llm", "data": data}) + return resp + + def _llm_handler(self, resp, tools_dict, messages_combine, log_context): + resp = self.llm_handler.handle_response( + self, resp, tools_dict, messages_combine + ) + if log_context: + data = build_stack_data(self.llm_handler) + log_context.stacks.append({"component": "llm_handler", "data": data}) + return resp diff --git a/application/agents/llm_handler.py b/application/agents/llm_handler.py new file mode 100644 index 00000000..9267dc53 --- /dev/null +++ b/application/agents/llm_handler.py @@ -0,0 +1,250 @@ +import json +from abc import ABC, abstractmethod + +from application.logging import build_stack_data + + +class LLMHandler(ABC): + def __init__(self): + self.llm_calls = [] + self.tool_calls = [] + + @abstractmethod + def handle_response(self, agent, resp, tools_dict, messages, 
**kwargs): + pass + + +class OpenAILLMHandler(LLMHandler): + def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True): + if not stream: + while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls": + message = json.loads(resp.model_dump_json())["message"] + keys_to_remove = {"audio", "function_call", "refusal"} + filtered_data = { + k: v for k, v in message.items() if k not in keys_to_remove + } + messages.append(filtered_data) + + tool_calls = resp.message.tool_calls + for call in tool_calls: + try: + self.tool_calls.append(call) + tool_response, call_id = agent._execute_tool_action( + tools_dict, call + ) + function_call_dict = { + "function_call": { + "name": call.function.name, + "args": call.function.arguments, + "call_id": call_id, + } + } + function_response_dict = { + "function_response": { + "name": call.function.name, + "response": {"result": tool_response}, + "call_id": call_id, + } + } + + messages.append( + {"role": "assistant", "content": [function_call_dict]} + ) + messages.append( + {"role": "tool", "content": [function_response_dict]} + ) + + except Exception as e: + messages.append( + { + "role": "tool", + "content": f"Error executing tool: {str(e)}", + "tool_call_id": call_id, + } + ) + resp = agent.llm.gen_stream( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + return resp + + else: + while True: + tool_calls = {} + for chunk in resp: + if isinstance(chunk, str): + return + else: + chunk_delta = chunk.delta + + if ( + hasattr(chunk_delta, "tool_calls") + and chunk_delta.tool_calls is not None + ): + for tool_call in chunk_delta.tool_calls: + index = tool_call.index + if index not in tool_calls: + tool_calls[index] = { + "id": "", + "function": {"name": "", "arguments": ""}, + } + + current = tool_calls[index] + if tool_call.id: + current["id"] = tool_call.id + if tool_call.function.name: + current["function"][ + "name" + ] = 
tool_call.function.name + if tool_call.function.arguments: + current["function"][ + "arguments" + ] += tool_call.function.arguments + tool_calls[index] = current + + if ( + hasattr(chunk, "finish_reason") + and chunk.finish_reason == "tool_calls" + ): + for index in sorted(tool_calls.keys()): + call = tool_calls[index] + try: + self.tool_calls.append(call) + tool_response, call_id = agent._execute_tool_action( + tools_dict, call + ) + + function_call_dict = { + "function_call": { + "name": call["function"]["name"], + "args": call["function"]["arguments"], + "call_id": call["id"], + } + } + function_response_dict = { + "function_response": { + "name": call["function"]["name"], + "response": {"result": tool_response}, + "call_id": call["id"], + } + } + + messages.append( + { + "role": "assistant", + "content": [function_call_dict], + } + ) + messages.append( + { + "role": "tool", + "content": [function_response_dict], + } + ) + + except Exception as e: + messages.append( + { + "role": "assistant", + "content": f"Error executing tool: {str(e)}", + } + ) + tool_calls = {} + + if ( + hasattr(chunk, "finish_reason") + and chunk.finish_reason == "stop" + ): + return + + resp = agent.llm.gen_stream( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + + +class GoogleLLMHandler(LLMHandler): + def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True): + from google.genai import types + + while True: + if not stream: + response = agent.llm.gen( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + if response.candidates and response.candidates[0].content.parts: + tool_call_found = False + for part in response.candidates[0].content.parts: + if part.function_call: + tool_call_found = True + self.tool_calls.append(part.function_call) + tool_response, call_id = agent._execute_tool_action( + tools_dict, part.function_call + 
) + function_response_part = types.Part.from_function_response( + name=part.function_call.name, + response={"result": tool_response}, + ) + + messages.append( + {"role": "model", "content": [part.to_json_dict()]} + ) + messages.append( + { + "role": "tool", + "content": [function_response_part.to_json_dict()], + } + ) + + if ( + not tool_call_found + and response.candidates[0].content.parts + and response.candidates[0].content.parts[0].text + ): + return response.candidates[0].content.parts[0].text + elif not tool_call_found: + return response.candidates[0].content.parts + + else: + return response + + else: + response = agent.llm.gen_stream( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + + tool_call_found = False + for result in response: + if hasattr(result, "function_call"): + tool_call_found = True + self.tool_calls.append(result.function_call) + tool_response, call_id = agent._execute_tool_action( + tools_dict, result.function_call + ) + function_response_part = types.Part.from_function_response( + name=result.function_call.name, + response={"result": tool_response}, + ) + + messages.append( + {"role": "model", "content": [result.to_json_dict()]} + ) + messages.append( + { + "role": "tool", + "content": [function_response_part.to_json_dict()], + } + ) + + if not tool_call_found: + return response + + +def get_llm_handler(llm_type): + handlers = { + "openai": OpenAILLMHandler(), + "google": GoogleLLMHandler(), + } + return handlers.get(llm_type, OpenAILLMHandler()) diff --git a/application/tools/implementations/api_tool.py b/application/agents/tools/api_tool.py similarity index 63% rename from application/tools/implementations/api_tool.py rename to application/agents/tools/api_tool.py index 5d0fec70..06d5fb7a 100644 --- a/application/tools/implementations/api_tool.py +++ b/application/agents/tools/api_tool.py @@ -1,7 +1,7 @@ import json import requests -from application.tools.base import 
Tool +from application.agents.tools.base import Tool class APITool(Tool): @@ -31,10 +31,27 @@ class APITool(Tool): print(f"Making API call: {method} {url} with body: {body}") response = requests.request(method, url, headers=headers, data=body) response.raise_for_status() - try: - data = response.json() - except ValueError: + + content_type = response.headers.get( + "Content-Type", "application/json" + ).lower() + if "application/json" in content_type: + try: + data = response.json() + except json.JSONDecodeError as e: + print(f"Error decoding JSON: {e}. Raw response: {response.text}") + return { + "status_code": response.status_code, + "message": f"API call returned invalid JSON. Error: {e}", + "data": response.text, + } + elif "text/" in content_type or "application/xml" in content_type: + data = response.text + elif not response.content: data = None + else: + print(f"Unsupported content type: {content_type}") + data = response.content return { "status_code": response.status_code, diff --git a/application/tools/base.py b/application/agents/tools/base.py similarity index 100% rename from application/tools/base.py rename to application/agents/tools/base.py diff --git a/application/agents/tools/brave.py b/application/agents/tools/brave.py new file mode 100644 index 00000000..1428bea4 --- /dev/null +++ b/application/agents/tools/brave.py @@ -0,0 +1,217 @@ +import requests +from application.agents.tools.base import Tool + + +class BraveSearchTool(Tool): + """ + Brave Search + A tool for performing web and image searches using the Brave Search API. + Requires an API key for authentication. 
+ """ + + def __init__(self, config): + self.config = config + self.token = config.get("token", "") + self.base_url = "https://api.search.brave.com/res/v1" + + def execute_action(self, action_name, **kwargs): + actions = { + "brave_web_search": self._web_search, + "brave_image_search": self._image_search, + } + + if action_name in actions: + return actions[action_name](**kwargs) + else: + raise ValueError(f"Unknown action: {action_name}") + + def _web_search(self, query, country="ALL", search_lang="en", count=10, + offset=0, safesearch="off", freshness=None, + result_filter=None, extra_snippets=False, summary=False): + """ + Performs a web search using the Brave Search API. + """ + print(f"Performing Brave web search for: {query}") + + url = f"{self.base_url}/web/search" + + # Build query parameters + params = { + "q": query, + "country": country, + "search_lang": search_lang, + "count": min(count, 20), + "offset": min(offset, 9), + "safesearch": safesearch + } + + # Add optional parameters only if they have values + if freshness: + params["freshness"] = freshness + if result_filter: + params["result_filter"] = result_filter + if extra_snippets: + params["extra_snippets"] = 1 + if summary: + params["summary"] = 1 + + # Set up headers + headers = { + "Accept": "application/json", + "Accept-Encoding": "gzip", + "X-Subscription-Token": self.token + } + + # Make the request + response = requests.get(url, params=params, headers=headers) + + if response.status_code == 200: + return { + "status_code": response.status_code, + "results": response.json(), + "message": "Search completed successfully." + } + else: + return { + "status_code": response.status_code, + "message": f"Search failed with status code: {response.status_code}." + } + + def _image_search(self, query, country="ALL", search_lang="en", count=5, + safesearch="off", spellcheck=False): + """ + Performs an image search using the Brave Search API. 
+ """ + print(f"Performing Brave image search for: {query}") + + url = f"{self.base_url}/images/search" + + # Build query parameters + params = { + "q": query, + "country": country, + "search_lang": search_lang, + "count": min(count, 100), # API max is 100 + "safesearch": safesearch, + "spellcheck": 1 if spellcheck else 0 + } + + # Set up headers + headers = { + "Accept": "application/json", + "Accept-Encoding": "gzip", + "X-Subscription-Token": self.token + } + + # Make the request + response = requests.get(url, params=params, headers=headers) + + if response.status_code == 200: + return { + "status_code": response.status_code, + "results": response.json(), + "message": "Image search completed successfully." + } + else: + return { + "status_code": response.status_code, + "message": f"Image search failed with status code: {response.status_code}." + } + + def get_actions_metadata(self): + return [ + { + "name": "brave_web_search", + "description": "Perform a web search using Brave Search", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query (max 400 characters, 50 words)", + }, + # "country": { + # "type": "string", + # "description": "The 2-character country code (default: US)", + # }, + "search_lang": { + "type": "string", + "description": "The search language preference (default: en)", + }, + # "count": { + # "type": "integer", + # "description": "Number of results to return (max 20, default: 10)", + # }, + # "offset": { + # "type": "integer", + # "description": "Pagination offset (max 9, default: 0)", + # }, + # "safesearch": { + # "type": "string", + # "description": "Filter level for adult content (off, moderate, strict)", + # }, + "freshness": { + "type": "string", + "description": "Time filter for results (pd: last 24h, pw: last week, pm: last month, py: last year)", + }, + # "result_filter": { + # "type": "string", + # "description": "Comma-delimited list of result types to include", + 
# }, + # "extra_snippets": { + # "type": "boolean", + # "description": "Get additional excerpts from result pages", + # }, + # "summary": { + # "type": "boolean", + # "description": "Enable summary generation in search results", + # } + }, + "required": ["query"], + "additionalProperties": False, + }, + }, + { + "name": "brave_image_search", + "description": "Perform an image search using Brave Search", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query (max 400 characters, 50 words)", + }, + # "country": { + # "type": "string", + # "description": "The 2-character country code (default: US)", + # }, + # "search_lang": { + # "type": "string", + # "description": "The search language preference (default: en)", + # }, + "count": { + "type": "integer", + "description": "Number of results to return (max 100, default: 5)", + }, + # "safesearch": { + # "type": "string", + # "description": "Filter level for adult content (off, strict). 
Default: strict", + # }, + # "spellcheck": { + # "type": "boolean", + # "description": "Whether to spellcheck provided query (default: true)", + # } + }, + "required": ["query"], + "additionalProperties": False, + }, + } + ] + + def get_config_requirements(self): + return { + "token": { + "type": "string", + "description": "Brave Search API key for authentication" + }, + } \ No newline at end of file diff --git a/application/tools/implementations/cryptoprice.py b/application/agents/tools/cryptoprice.py similarity index 95% rename from application/tools/implementations/cryptoprice.py rename to application/agents/tools/cryptoprice.py index 7b88c866..c25c3d43 100644 --- a/application/tools/implementations/cryptoprice.py +++ b/application/agents/tools/cryptoprice.py @@ -1,5 +1,5 @@ import requests -from application.tools.base import Tool +from application.agents.tools.base import Tool class CryptoPriceTool(Tool): @@ -31,7 +31,6 @@ class CryptoPriceTool(Tool): response = requests.get(url) if response.status_code == 200: data = response.json() - # data will be like {"USD": } if the call is successful if currency.upper() in data: return { "status_code": response.status_code, diff --git a/application/tools/implementations/postgres.py b/application/agents/tools/postgres.py similarity index 99% rename from application/tools/implementations/postgres.py rename to application/agents/tools/postgres.py index a83db9aa..2877ebad 100644 --- a/application/tools/implementations/postgres.py +++ b/application/agents/tools/postgres.py @@ -1,5 +1,5 @@ import psycopg2 -from application.tools.base import Tool +from application.agents.tools.base import Tool class PostgresTool(Tool): """ diff --git a/application/tools/implementations/telegram.py b/application/agents/tools/telegram.py similarity index 98% rename from application/tools/implementations/telegram.py rename to application/agents/tools/telegram.py index a32bbe88..06350ae9 100644 --- a/application/tools/implementations/telegram.py 
+++ b/application/agents/tools/telegram.py @@ -1,5 +1,5 @@ import requests -from application.tools.base import Tool +from application.agents.tools.base import Tool class TelegramTool(Tool): diff --git a/application/agents/tools/tool_action_parser.py b/application/agents/tools/tool_action_parser.py new file mode 100644 index 00000000..c7da5a4c --- /dev/null +++ b/application/agents/tools/tool_action_parser.py @@ -0,0 +1,42 @@ +import json +import logging + +logger = logging.getLogger(__name__) + + +class ToolActionParser: + def __init__(self, llm_type): + self.llm_type = llm_type + self.parsers = { + "OpenAILLM": self._parse_openai_llm, + "GoogleLLM": self._parse_google_llm, + } + + def parse_args(self, call): + parser = self.parsers.get(self.llm_type, self._parse_openai_llm) + return parser(call) + + def _parse_openai_llm(self, call): + if isinstance(call, dict): + try: + call_args = json.loads(call["function"]["arguments"]) + tool_id = call["function"]["name"].split("_")[-1] + action_name = call["function"]["name"].rsplit("_", 1)[0] + except (KeyError, TypeError) as e: + logger.error(f"Error parsing OpenAI LLM call: {e}") + return None, None, None + else: + try: + call_args = json.loads(call.function.arguments) + tool_id = call.function.name.split("_")[-1] + action_name = call.function.name.rsplit("_", 1)[0] + except (AttributeError, TypeError) as e: + logger.error(f"Error parsing OpenAI LLM call: {e}") + return None, None, None + return tool_id, action_name, call_args + + def _parse_google_llm(self, call): + call_args = call.args + tool_id = call.name.split("_")[-1] + action_name = call.name.rsplit("_", 1)[0] + return tool_id, action_name, call_args diff --git a/application/tools/tool_manager.py b/application/agents/tools/tool_manager.py similarity index 79% rename from application/tools/tool_manager.py rename to application/agents/tools/tool_manager.py index 3e0766cf..ad71db28 100644 --- a/application/tools/tool_manager.py +++ 
b/application/agents/tools/tool_manager.py @@ -3,7 +3,7 @@ import inspect import os import pkgutil -from application.tools.base import Tool +from application.agents.tools.base import Tool class ToolManager: @@ -13,13 +13,11 @@ class ToolManager: self.load_tools() def load_tools(self): - tools_dir = os.path.join(os.path.dirname(__file__), "implementations") + tools_dir = os.path.join(os.path.dirname(__file__)) for finder, name, ispkg in pkgutil.iter_modules([tools_dir]): if name == "base" or name.startswith("__"): continue - module = importlib.import_module( - f"application.tools.implementations.{name}" - ) + module = importlib.import_module(f"application.agents.tools.{name}") for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: tool_config = self.config.get(name, {}) @@ -27,9 +25,7 @@ class ToolManager: def load_tool(self, tool_name, tool_config): self.config[tool_name] = tool_config - module = importlib.import_module( - f"application.tools.implementations.{tool_name}" - ) + module = importlib.import_module(f"application.agents.tools.{tool_name}") for member_name, obj in inspect.getmembers(module, inspect.isclass): if issubclass(obj, Tool) and obj is not Tool: return obj(tool_config) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index a2e2d1af..c8c9708f 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -1,15 +1,16 @@ import asyncio import datetime import json +import logging import os import traceback -import logging from bson.dbref import DBRef from bson.objectid import ObjectId from flask import Blueprint, make_response, request, Response from flask_restx import fields, Namespace, Resource +from application.agents.agent_creator import AgentCreator from application.core.mongo_db import MongoDB from application.core.settings import settings @@ -206,6 +207,7 @@ def get_prompt(prompt_id): def complete_stream( question, + agent, 
retriever, conversation_id, user_api_key, @@ -217,8 +219,8 @@ def complete_stream( response_full = "" source_log_docs = [] tool_calls = [] - answer = retriever.gen() - sources = retriever.search() + answer = agent.gen(query=question, retriever=retriever) + sources = retriever.search(question) for source in sources: if "text" in source: source["text"] = source["text"][:100].strip() + "..." @@ -384,9 +386,20 @@ class Stream(Resource): prompt = get_prompt(prompt_id) if "isNoneDoc" in data and data["isNoneDoc"] is True: chunks = 0 + + agent = AgentCreator.create_agent( + settings.AGENT_NAME, + endpoint="stream", + llm_name=settings.LLM_NAME, + gpt_model=gpt_model, + api_key=settings.API_KEY, + user_api_key=user_api_key, + prompt=prompt, + chat_history=history, + ) + retriever = RetrieverCreator.create_retriever( retriever_name, - question=question, source=source, chat_history=history, prompt=prompt, @@ -399,6 +412,7 @@ class Stream(Resource): return Response( complete_stream( question=question, + agent=agent, retriever=retriever, conversation_id=conversation_id, user_api_key=user_api_key, diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 6204ada4..d7fb4d89 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1,9 +1,9 @@ import datetime +import json import math import os import shutil import uuid -import json from bson.binary import Binary, UuidRepresentation from bson.dbref import DBRef @@ -12,12 +12,13 @@ from flask import Blueprint, current_app, jsonify, make_response, redirect, requ from flask_restx import fields, inputs, Namespace, Resource from werkzeug.utils import secure_filename +from application.agents.tools.tool_manager import ToolManager + from application.api.user.tasks import ingest, ingest_remote from application.core.mongo_db import MongoDB from application.core.settings import settings from application.extensions import api -from application.tools.tool_manager import ToolManager from 
application.tts.google_tts import GoogleTTS from application.utils import check_required_fields, validate_function_name from application.vectorstore.vector_creator import VectorCreator @@ -449,22 +450,21 @@ class UploadRemote(Resource): return missing_fields try: - config = json.loads(data["data"]) - source_data = None + config = json.loads(data["data"]) + source_data = None - if data["source"] == "github": + if data["source"] == "github": source_data = config.get("repo_url") - elif data["source"] in ["crawler", "url"]: + elif data["source"] in ["crawler", "url"]: source_data = config.get("url") - elif data["source"] == "reddit": - source_data = config + elif data["source"] == "reddit": + source_data = config - - task = ingest_remote.delay( + task = ingest_remote.delay( source_data=source_data, job_name=data["name"], user=data["user"], - loader=data["source"] + loader=data["source"], ) except Exception as err: current_app.logger.error(f"Error uploading remote source: {err}") @@ -1932,11 +1932,14 @@ class UpdateTool(Resource): for action_name in list(data["config"]["actions"].keys()): if not validate_function_name(action_name): return make_response( - jsonify({ - "success": False, - "message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.", - "param": "tools[].function.name" - }), 400 + jsonify( + { + "success": False, + "message": f"Invalid function name '{action_name}'. 
Function names must match pattern '^[a-zA-Z0-9_-]+$'.", + "param": "tools[].function.name", + } + ), + 400, ) update_data["config"] = data["config"] if "status" in data: diff --git a/application/core/settings.py b/application/core/settings.py index 5842da33..04d7bbea 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -32,6 +32,7 @@ class Settings(BaseSettings): "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" ) RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search + AGENT_NAME: str = "classic" # LLM Cache CACHE_REDIS_URL: str = "redis://localhost:6379/2" diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index 48254349..5e33550c 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -152,7 +152,15 @@ class GoogleLLM(BaseLLM): config=config, ) for chunk in response: - if chunk.text is not None: + if hasattr(chunk, "candidates") and chunk.candidates: + for candidate in chunk.candidates: + if candidate.content and candidate.content.parts: + for part in candidate.content.parts: + if part.function_call: + yield part + elif part.text: + yield part.text + elif hasattr(chunk, "text"): yield chunk.text def _supports_tools(self): diff --git a/application/llm/openai.py b/application/llm/openai.py index b8f311b0..938de523 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -111,13 +111,24 @@ class OpenAILLM(BaseLLM): **kwargs, ): messages = self._clean_messages_openai(messages) - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) + if tools: + response = self.client.chat.completions.create( + model=model, + messages=messages, + stream=stream, + tools=tools, + **kwargs, + ) + else: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) for line in response: if line.choices[0].delta.content is not None: yield 
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class LogContext:
    """Mutable record describing one logged activity.

    An instance is created per decorated call by ``log_activity`` and handed
    to the wrapped generator as the ``log_context`` keyword argument, so the
    wrapped code can append entries to ``stacks`` while it runs.
    """

    def __init__(self, endpoint, activity_id, user, api_key, query):
        self.endpoint = endpoint
        self.activity_id = activity_id
        self.user = user
        self.api_key = api_key
        self.query = query
        # Entries collected during the activity; persisted when it finishes.
        self.stacks = []


def build_stack_data(
    obj: Any,
    include_attributes: List[str] = None,
    exclude_attributes: List[str] = None,
    custom_data: Dict = None,
) -> Dict:
    """Snapshot selected attributes of ``obj`` into a plain dictionary.

    Args:
        obj: Object to inspect.
        include_attributes: Attribute names to capture. When ``None``, every
            member whose name does not start with ``_`` and which is neither
            a method nor a function is captured.
        exclude_attributes: Attribute names to leave out.
        custom_data: Extra key/value pairs merged into the result last, so
            they override captured attributes with the same name.

    Returns:
        A dict of attribute name to value. ``None`` values are skipped;
        scalars are kept as-is; lists of dicts are kept as-is; lists of
        objects with ``__dict__`` become lists of those dicts; other lists
        become lists of strings; dict values are stringified; anything else
        is stringified.
    """
    data = {}
    if include_attributes is None:
        include_attributes = [
            name
            for name, value in inspect.getmembers(obj)
            if not name.startswith("_")
            and not inspect.ismethod(value)
            and not inspect.isfunction(value)
        ]

    for attr_name in include_attributes:
        if exclude_attributes and attr_name in exclude_attributes:
            continue
        try:
            attr_value = getattr(obj, attr_name)
        except AttributeError:
            # Requested attribute does not exist on this object: skip it.
            continue
        if attr_value is None:
            continue

        if isinstance(attr_value, (int, float, str, bool)):
            data[attr_name] = attr_value
        elif isinstance(attr_value, list):
            if all(isinstance(item, dict) for item in attr_value):
                data[attr_name] = attr_value
            elif all(hasattr(item, "__dict__") for item in attr_value):
                data[attr_name] = [item.__dict__ for item in attr_value]
            else:
                data[attr_name] = [str(item) for item in attr_value]
        elif isinstance(attr_value, dict):
            data[attr_name] = {k: str(v) for k, v in attr_value.items()}
        else:
            data[attr_name] = str(attr_value)

    if custom_data:
        data.update(custom_data)
    return data


def log_activity() -> Callable:
    """Build a decorator that logs one activity per call of a generator method.

    The decorated callable must be invoked with its owner object as the first
    positional argument (a normal bound-method call) and must return a
    generator. It receives an extra ``log_context`` keyword argument holding
    the ``LogContext`` for the call.

    The wrapper is itself a generator function, so nothing (including the
    "Starting activity" log line) runs until the caller starts iterating.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            activity_id = str(uuid.uuid4())
            owner = args[0]
            data = build_stack_data(owner)
            endpoint = data.get("endpoint", "")
            user = data.get("user", "local")
            api_key = data.get("user_api_key", "")
            query = kwargs.get("query", getattr(owner, "query", ""))

            context = LogContext(endpoint, activity_id, user, api_key, query)
            kwargs["log_context"] = context

            logging.info(
                f"Starting activity: {endpoint} - {activity_id} - User: {user}"
            )

            generator = func(*args, **kwargs)
            yield from _consume_and_log(generator, context)

        return wrapper

    return decorator


def _consume_and_log(generator: Generator, context: "LogContext"):
    """Re-yield every item of ``generator`` and persist one log entry at the end.

    Exactly one entry is written per activity: level ``"error"`` (with an
    extra ``error`` stack entry) when the generator raises, ``"info"``
    otherwise, including when the consumer stops iterating early. The
    original exception is re-raised unchanged.
    """
    level = "info"
    try:
        for item in generator:
            yield item
    except Exception as e:
        level = "error"
        logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
        context.stacks.append({"component": "error", "data": {"message": str(e)}})
        raise
    finally:
        # Single write point: previously the error path logged here *and* in
        # the except branch, producing a duplicate "info" entry for failures.
        _log_to_mongodb(
            endpoint=context.endpoint,
            activity_id=context.activity_id,
            user=context.user,
            api_key=context.api_key,
            query=context.query,
            stacks=context.stacks,
            level=level,
        )


def _log_to_mongodb(
    endpoint: str,
    activity_id: str,
    user: str,
    api_key: str,
    query: str,
    stacks: List[Dict],
    level: str,
) -> None:
    """Insert one activity record into the ``stack_logs`` collection.

    Best-effort: any failure is reported through ``logging.error`` and never
    propagated, so a logging problem cannot break the activity being logged.
    """
    try:
        mongo = MongoDB.get_client()
        db = mongo["docsgpt"]
        user_logs_collection = db["stack_logs"]

        log_entry = {
            "endpoint": endpoint,
            "id": activity_id,
            "level": level,
            "user": user,
            "api_key": api_key,
            "query": query,
            "stacks": stacks,
            "timestamp": datetime.datetime.now(datetime.timezone.utc),
        }
        user_logs_collection.insert_one(log_entry)
        logging.debug(f"Logged activity to MongoDB: {activity_id}")

    except Exception as e:
        logging.error(f"Failed to log to MongoDB: {e}")
def gen(self):
    """Do nothing; this retriever no longer generates answers itself.

    The previous implementation, which built the chat prompt and streamed a
    completion, was removed from this class; only retrieval (``search``)
    remains. The method is kept as a no-op and returns ``None``.

    NOTE(review): presumably still required by the retriever base class;
    confirm against ``BaseRetriever``.
    """
    # The stub was declared without ``self``, so calling it on an instance
    # raised TypeError instead of being a harmless no-op.
    return None
class LLMHandler(ABC):
    """Interface for provider-specific handling of tool-calling responses."""

    @abstractmethod
    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Resolve tool calls in ``resp`` and return the final response."""
        pass


class OpenAILLMHandler(LLMHandler):
    """Tool-call loop for OpenAI-style chat completion responses."""

    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Execute requested tools until the model stops asking for them.

        Args:
            agent: Object providing ``_execute_tool_action``, ``llm``,
                ``gpt_model`` and ``tools``.
            resp: Current response choice; must expose ``finish_reason``,
                ``message.tool_calls`` and ``model_dump_json()``.
            tools_dict: Passed through to ``agent._execute_tool_action``.
            messages: Conversation list, extended in place with the
                assistant message and the tool call/result messages.
            **kwargs: Accepted for compatibility with the abstract
                signature; unused.

        Returns:
            The first response whose ``finish_reason`` is not
            ``"tool_calls"``.
        """
        keys_to_remove = {"audio", "function_call", "refusal"}
        while resp.finish_reason == "tool_calls":
            message = json.loads(resp.model_dump_json())["message"]
            messages.append(
                {k: v for k, v in message.items() if k not in keys_to_remove}
            )

            for call in resp.message.tool_calls:
                # Seed the id from the call itself so the error branch below
                # always has a value: previously it read a name that was only
                # bound by the statement that had just failed.
                call_id = getattr(call, "id", None)
                try:
                    tool_response, call_id = agent._execute_tool_action(
                        tools_dict, call
                    )
                    function_call_dict = {
                        "function_call": {
                            "name": call.function.name,
                            "args": call.function.arguments,
                            "call_id": call_id,
                        }
                    }
                    function_response_dict = {
                        "function_response": {
                            "name": call.function.name,
                            "response": {"result": tool_response},
                            "call_id": call_id,
                        }
                    }

                    messages.append(
                        {"role": "assistant", "content": [function_call_dict]}
                    )
                    messages.append(
                        {"role": "tool", "content": [function_response_dict]}
                    )

                except Exception as e:
                    # Best-effort: report the failure to the model as the
                    # tool's output instead of aborting the conversation.
                    messages.append(
                        {
                            "role": "tool",
                            "content": f"Error executing tool: {str(e)}",
                            "tool_call_id": call_id,
                        }
                    )
            resp = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
        return resp


class GoogleLLMHandler(LLMHandler):
    """Tool-call loop for Google GenAI responses."""

    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Query the model repeatedly, executing function calls it requests.

        ``resp`` is not inspected: each iteration issues a fresh request via
        ``agent.llm.gen``. ``messages`` is extended in place with the model's
        function-call part and the corresponding function response.

        Returns:
            The text of the first part when the response has no function
            call and that text is non-empty; otherwise the list of parts;
            or the raw response when it has no candidates/parts.
        """
        from google.genai import types

        while True:
            response = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
            if response.candidates and response.candidates[0].content.parts:
                tool_call_found = False
                for part in response.candidates[0].content.parts:
                    if part.function_call:
                        tool_call_found = True
                        tool_response, call_id = agent._execute_tool_action(
                            tools_dict, part.function_call
                        )
                        function_response_part = types.Part.from_function_response(
                            name=part.function_call.name,
                            response={"result": tool_response},
                        )

                        messages.append(
                            {"role": "model", "content": [part.to_json_dict()]}
                        )
                        messages.append(
                            {
                                "role": "tool",
                                "content": [function_response_part.to_json_dict()],
                            }
                        )

                if (
                    not tool_call_found
                    and response.candidates[0].content.parts
                    and response.candidates[0].content.parts[0].text
                ):
                    return response.candidates[0].content.parts[0].text
                elif not tool_call_found:
                    return response.candidates[0].content.parts

            else:
                return response


def get_llm_handler(llm_type):
    """Return a new handler for ``llm_type``.

    Recognised values are ``"openai"`` and ``"google"``; any other value
    falls back to the OpenAI handler.
    """
    handlers = {
        "openai": OpenAILLMHandler,
        "google": GoogleLLMHandler,
    }
    # Instantiate only the handler that is actually requested.
    return handlers.get(llm_type, OpenAILLMHandler)()
"OpenAILLM": self._parse_openai_llm, - "GoogleLLM": self._parse_google_llm, - } - - def parse_args(self, call): - parser = self.parsers.get(self.llm_type, self._parse_openai_llm) - return parser(call) - - def _parse_openai_llm(self, call): - call_args = json.loads(call.function.arguments) - tool_id = call.function.name.split("_")[-1] - action_name = call.function.name.rsplit("_", 1)[0] - return tool_id, action_name, call_args - - def _parse_google_llm(self, call): - call_args = call.args - tool_id = call.name.split("_")[-1] - action_name = call.name.rsplit("_", 1)[0] - return tool_id, action_name, call_args diff --git a/application/usage.py b/application/usage.py index fe4cd50e..a18a3848 100644 --- a/application/usage.py +++ b/application/usage.py @@ -1,7 +1,8 @@ import sys from datetime import datetime + from application.core.mongo_db import MongoDB -from application.utils import num_tokens_from_string, num_tokens_from_object_or_list +from application.utils import num_tokens_from_object_or_list, num_tokens_from_string mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -24,13 +25,16 @@ def gen_token_usage(func): def wrapper(self, model, messages, stream, tools, **kwargs): for message in messages: if message["content"]: - self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) + self.token_usage["prompt_tokens"] += num_tokens_from_string( + message["content"] + ) result = func(self, model, messages, stream, tools, **kwargs) - # check if result is a string if isinstance(result, str): self.token_usage["generated_tokens"] += num_tokens_from_string(result) else: - self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result) + self.token_usage["generated_tokens"] += num_tokens_from_object_or_list( + result + ) update_token_usage(self.user_api_key, self.token_usage) return result @@ -40,7 +44,9 @@ def gen_token_usage(func): def stream_token_usage(func): def wrapper(self, model, messages, stream, tools, **kwargs): for message in 
messages: - self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) + self.token_usage["prompt_tokens"] += num_tokens_from_string( + message["content"] + ) batch = [] result = func(self, model, messages, stream, tools, **kwargs) for r in result: diff --git a/frontend/public/toolIcons/tool_brave.svg b/frontend/public/toolIcons/tool_brave.svg new file mode 100644 index 00000000..380c19ed --- /dev/null +++ b/frontend/public/toolIcons/tool_brave.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index 883af5b0..811b90a5 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -140,7 +140,15 @@ const ConversationBubble = forwardRef< >