diff --git a/application/agents/agent_creator.py b/application/agents/agent_creator.py new file mode 100644 index 00000000..a76d9faf --- /dev/null +++ b/application/agents/agent_creator.py @@ -0,0 +1,14 @@ +from application.agents.classic_agent import ClassicAgent + + +class AgentCreator: + agents = { + "classic": ClassicAgent, + } + + @classmethod + def create_agent(cls, type, *args, **kwargs): + agent_class = cls.agents.get(type.lower()) + if not agent_class: + raise ValueError(f"No agent class found for type {type}") + return agent_class(*args, **kwargs) diff --git a/application/agents/base.py b/application/agents/base.py index 93dcb4e2..7e36c991 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -9,7 +9,8 @@ from application.llm.llm_creator import LLMCreator class BaseAgent: - def __init__(self, llm_name, gpt_model, api_key, user_api_key=None): + def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None): + self.endpoint = endpoint self.llm = LLMCreator.create_llm( llm_name, api_key=api_key, user_api_key=user_api_key ) @@ -19,7 +20,7 @@ class BaseAgent: self.tool_config = {} self.tool_calls = [] - def gen(self, query: str) -> Generator[Dict, None, None]: + def gen(self, *args, **kwargs) -> Generator[Dict, None, None]: raise NotImplementedError('Method "gen" must be implemented in the child class') def _get_user_tools(self, user="local"): diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index c7846a04..4e64442d 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -2,6 +2,7 @@ import uuid from typing import Dict, Generator from application.agents.base import BaseAgent +from application.logging import build_stack_data, log_activity, LogContext from application.retriever.base import BaseRetriever @@ -9,6 +10,7 @@ from application.retriever.base import BaseRetriever class ClassicAgent(BaseAgent): def __init__( self, + endpoint, llm_name, 
gpt_model, api_key, @@ -16,13 +18,21 @@ class ClassicAgent(BaseAgent): prompt="", chat_history=None, ): - super().__init__(llm_name, gpt_model, api_key, user_api_key) + super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key) self.prompt = prompt self.chat_history = chat_history if chat_history is not None else [] - def gen(self, query: str, retriever: BaseRetriever) -> Generator[Dict, None, None]: + @log_activity() + def gen( + self, query: str, retriever: BaseRetriever, log_context: LogContext = None + ) -> Generator[Dict, None, None]: + yield from self._gen_inner(query, retriever, log_context) + + def _gen_inner( + self, query: str, retriever: BaseRetriever, log_context: LogContext + ) -> Generator[Dict, None, None]: + retrieved_data = self._retriever_search(retriever, query, log_context) - retrieved_data = retriever.search(query) docs_together = "\n".join([doc["text"] for doc in retrieved_data]) p_chat_combine = self.prompt.replace("{summaries}", docs_together) messages_combine = [{"role": "system", "content": p_chat_combine}] @@ -66,9 +76,7 @@ class ClassicAgent(BaseAgent): tools_dict = self._get_user_tools() self._prepare_tools(tools_dict) - resp = self.llm.gen( - model=self.gpt_model, messages=messages_combine, tools=self.tools - ) + resp = self._llm_gen(messages_combine, log_context) if isinstance(resp, str): yield {"answer": resp} @@ -81,9 +89,7 @@ class ClassicAgent(BaseAgent): yield {"answer": resp.message.content} return - resp = self.llm_handler.handle_response( - self, resp, tools_dict, messages_combine - ) + resp = self._llm_handler(resp, tools_dict, messages_combine, log_context) if isinstance(resp, str): yield {"answer": resp} @@ -101,3 +107,29 @@ class ClassicAgent(BaseAgent): yield {"answer": line} yield {"tool_calls": self.tool_calls.copy()} + + def _retriever_search(self, retriever, query, log_context): + retrieved_data = retriever.search(query) + if log_context: + data = build_stack_data(retriever, exclude_attributes=["llm"]) + 
log_context.stacks.append({"component": "retriever", "data": data}) + return retrieved_data + + def _llm_gen(self, messages_combine, log_context): + resp = self.llm.gen( + model=self.gpt_model, messages=messages_combine, tools=self.tools + ) + if log_context: + data = build_stack_data(self.llm) + log_context.stacks.append({"component": "llm", "data": data}) + return resp + + def _llm_handler(self, resp, tools_dict, messages_combine, log_context): + resp = self.llm_handler.handle_response( + self, resp, tools_dict, messages_combine + ) + if log_context: + data = build_stack_data(self.llm_handler) + log_context.stacks.append({"component": "llm_handler", "data": data}) + + return resp diff --git a/application/agents/llm_handler.py b/application/agents/llm_handler.py index 334d2c4c..adf240c3 100644 --- a/application/agents/llm_handler.py +++ b/application/agents/llm_handler.py @@ -1,8 +1,14 @@ import json from abc import ABC, abstractmethod +from application.logging import build_stack_data + class LLMHandler(ABC): + def __init__(self): + self.llm_calls = [] + self.tool_calls = [] + @abstractmethod def handle_response(self, agent, resp, tools_dict, messages, **kwargs): pass @@ -21,6 +27,7 @@ class OpenAILLMHandler(LLMHandler): tool_calls = resp.message.tool_calls for call in tool_calls: try: + self.tool_calls.append(call) tool_response, call_id = agent._execute_tool_action( tools_dict, call ) @@ -57,6 +64,7 @@ class OpenAILLMHandler(LLMHandler): resp = agent.llm.gen( model=agent.gpt_model, messages=messages, tools=agent.tools ) + self.llm_calls.append(build_stack_data(agent.llm)) return resp @@ -68,11 +76,13 @@ class GoogleLLMHandler(LLMHandler): response = agent.llm.gen( model=agent.gpt_model, messages=messages, tools=agent.tools ) + self.llm_calls.append(build_stack_data(agent.llm)) if response.candidates and response.candidates[0].content.parts: tool_call_found = False for part in response.candidates[0].content.parts: if part.function_call: tool_call_found = True + 
self.tool_calls.append(part.function_call) tool_response, call_id = agent._execute_tool_action( tools_dict, part.function_call ) diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index d21c256e..b249f058 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -10,7 +10,7 @@ from bson.objectid import ObjectId from flask import Blueprint, make_response, request, Response from flask_restx import fields, Namespace, Resource -from application.agents.classic_agent import ClassicAgent +from application.agents.agent_creator import AgentCreator from application.core.mongo_db import MongoDB from application.core.settings import settings @@ -213,7 +213,7 @@ def complete_stream( response_full = "" source_log_docs = [] tool_calls = [] - answer = agent.gen(question, retriever) + answer = agent.gen(query=question, retriever=retriever) sources = retriever.search(question) for source in sources: if "text" in source: @@ -368,14 +368,18 @@ class Stream(Resource): prompt = get_prompt(prompt_id) if "isNoneDoc" in data and data["isNoneDoc"] is True: chunks = 0 - agent = ClassicAgent( - settings.LLM_NAME, - gpt_model, - settings.API_KEY, + + agent = AgentCreator.create_agent( + settings.AGENT_NAME, + endpoint="stream", + llm_name=settings.LLM_NAME, + gpt_model=gpt_model, + api_key=settings.API_KEY, user_api_key=user_api_key, prompt=prompt, chat_history=history, ) + retriever = RetrieverCreator.create_retriever( retriever_name, source=source, diff --git a/application/core/settings.py b/application/core/settings.py index 5842da33..04d7bbea 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -32,6 +32,7 @@ class Settings(BaseSettings): "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" ) RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search + AGENT_NAME: str = "classic" # LLM Cache CACHE_REDIS_URL: str = "redis://localhost:6379/2" diff --git 
a/application/logging.py b/application/logging.py
new file mode 100644
index 00000000..1dd0d557
--- /dev/null
+++ b/application/logging.py
@@ -0,0 +1,151 @@
+import datetime
+import functools
+import inspect
+
+import logging
+import uuid
+from typing import Any, Callable, Dict, Generator, List
+
+from application.core.mongo_db import MongoDB
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+
+
+class LogContext:
+    def __init__(self, endpoint, activity_id, user, api_key, query):
+        self.endpoint = endpoint
+        self.activity_id = activity_id
+        self.user = user
+        self.api_key = api_key
+        self.query = query
+        self.stacks = []
+
+
+def build_stack_data(
+    obj: Any,
+    include_attributes: List[str] = None,
+    exclude_attributes: List[str] = None,
+    custom_data: Dict = None,
+) -> Dict:
+    data = {}
+    if include_attributes is None:
+        include_attributes = []
+        for name, value in inspect.getmembers(obj):
+            if (
+                not name.startswith("_")
+                and not inspect.ismethod(value)
+                and not inspect.isfunction(value)
+            ):
+                include_attributes.append(name)
+    for attr_name in include_attributes:
+        if exclude_attributes and attr_name in exclude_attributes:
+            continue
+        try:
+            attr_value = getattr(obj, attr_name)
+            if attr_value is not None:
+                if isinstance(attr_value, (int, float, str, bool)):
+                    data[attr_name] = attr_value
+                elif isinstance(attr_value, list):
+                    if all(isinstance(item, dict) for item in attr_value):
+                        data[attr_name] = attr_value
+                    elif all(hasattr(item, "__dict__") for item in attr_value):
+                        data[attr_name] = [item.__dict__ for item in attr_value]
+                    else:
+                        data[attr_name] = [str(item) for item in attr_value]
+                elif isinstance(attr_value, dict):
+                    data[attr_name] = {k: str(v) for k, v in attr_value.items()}
+                else:
+                    data[attr_name] = str(attr_value)
+        except AttributeError:
+            pass
+    if custom_data:
+        data.update(custom_data)
+    return data
+
+
+def log_activity() -> Callable:
+    def decorator(func: Callable) -> Callable:
+        
@functools.wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            activity_id = str(uuid.uuid4())
+            data = build_stack_data(args[0])
+            endpoint = data.get("endpoint", "")
+            user = data.get("user", "local")
+            api_key = data.get("user_api_key", "")
+            query = kwargs.get("query", getattr(args[0], "query", ""))
+
+            context = LogContext(endpoint, activity_id, user, api_key, query)
+            kwargs["log_context"] = context
+
+            logging.info(
+                f"Starting activity: {endpoint} - {activity_id} - User: {user}"
+            )
+
+            generator = func(*args, **kwargs)
+            yield from _consume_and_log(generator, context)
+
+        return wrapper
+
+    return decorator
+
+
+def _consume_and_log(generator: Generator, context: "LogContext"):
+    try:
+        for item in generator:
+            yield item
+    except Exception as e:
+        logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
+        context.stacks.append({"component": "error", "data": {"message": str(e)}})
+        _log_to_mongodb(
+            endpoint=context.endpoint,
+            activity_id=context.activity_id,
+            user=context.user,
+            api_key=context.api_key,
+            query=context.query,
+            stacks=context.stacks,
+            level="error",
+        )
+        raise
+    else:
+        _log_to_mongodb(
+            endpoint=context.endpoint,
+            activity_id=context.activity_id,
+            user=context.user,
+            api_key=context.api_key,
+            query=context.query,
+            stacks=context.stacks,
+            level="info",
+        )
+
+
+def _log_to_mongodb(
+    endpoint: str,
+    activity_id: str,
+    user: str,
+    api_key: str,
+    query: str,
+    stacks: List[Dict],
+    level: str,
+) -> None:
+    try:
+        mongo = MongoDB.get_client()
+        db = mongo["docsgpt"]
+        user_logs_collection = db["stack_logs"]
+
+        log_entry = {
+            "endpoint": endpoint,
+            "id": activity_id,
+            "level": level,
+            "user": user,
+            "api_key": api_key,
+            "query": query,
+            "stacks": stacks,
+            "timestamp": datetime.datetime.now(datetime.timezone.utc),
+        }
+        user_logs_collection.insert_one(log_entry)
+        logging.debug(f"Logged activity to MongoDB: {activity_id}")
+
+    except Exception as e:
+        logging.error(f"Failed to log 
to MongoDB: {e}") diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 5c74878c..03f17f44 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -43,7 +43,11 @@ class ClassicRAG(BaseRetriever): self.vectorstore = source["active_docs"] if "active_docs" in source else None def _rephrase_query(self): - if not self.chat_history or self.chat_history == []: + if ( + not self.original_question + or not self.chat_history + or self.chat_history == [] + ): return self.original_question prompt = f"""Given the following conversation history: diff --git a/application/usage.py b/application/usage.py index fe4cd50e..a18a3848 100644 --- a/application/usage.py +++ b/application/usage.py @@ -1,7 +1,8 @@ import sys from datetime import datetime + from application.core.mongo_db import MongoDB -from application.utils import num_tokens_from_string, num_tokens_from_object_or_list +from application.utils import num_tokens_from_object_or_list, num_tokens_from_string mongo = MongoDB.get_client() db = mongo["docsgpt"] @@ -24,13 +25,16 @@ def gen_token_usage(func): def wrapper(self, model, messages, stream, tools, **kwargs): for message in messages: if message["content"]: - self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) + self.token_usage["prompt_tokens"] += num_tokens_from_string( + message["content"] + ) result = func(self, model, messages, stream, tools, **kwargs) - # check if result is a string if isinstance(result, str): self.token_usage["generated_tokens"] += num_tokens_from_string(result) else: - self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result) + self.token_usage["generated_tokens"] += num_tokens_from_object_or_list( + result + ) update_token_usage(self.user_api_key, self.token_usage) return result @@ -40,7 +44,9 @@ def gen_token_usage(func): def stream_token_usage(func): def wrapper(self, model, messages, stream, tools, **kwargs): for 
message in messages: - self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"]) + self.token_usage["prompt_tokens"] += num_tokens_from_string( + message["content"] + ) batch = [] result = func(self, model, messages, stream, tools, **kwargs) for r in result: