feat: logging stacks

This commit is contained in:
Siddhant Rai
2025-02-27 19:14:10 +05:30
parent 1f0b779c64
commit c6ce4d9374
9 changed files with 246 additions and 23 deletions

View File

@@ -0,0 +1,14 @@
from application.agents.classic_agent import ClassicAgent
class AgentCreator:
    """Factory that resolves an agent type name to a concrete agent instance."""

    # Registry of available agent implementations; lookup is case-insensitive.
    agents = {
        "classic": ClassicAgent,
    }

    @classmethod
    def create_agent(cls, type, *args, **kwargs):
        """Instantiate the agent registered under ``type``.

        Raises:
            ValueError: if ``type`` does not match any registered agent.
        """
        agent_class = cls.agents.get(type.lower())
        if agent_class:
            return agent_class(*args, **kwargs)
        raise ValueError(f"No agent class found for type {type}")

View File

@@ -9,7 +9,8 @@ from application.llm.llm_creator import LLMCreator
class BaseAgent:
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None):
self.endpoint = endpoint
self.llm = LLMCreator.create_llm(
llm_name, api_key=api_key, user_api_key=user_api_key
)
@@ -19,7 +20,7 @@ class BaseAgent:
self.tool_config = {}
self.tool_calls = []
def gen(self, query: str) -> Generator[Dict, None, None]:
def gen(self, *args, **kwargs) -> Generator[Dict, None, None]:
raise NotImplementedError('Method "gen" must be implemented in the child class')
def _get_user_tools(self, user="local"):

View File

@@ -2,6 +2,7 @@ import uuid
from typing import Dict, Generator
from application.agents.base import BaseAgent
from application.logging import build_stack_data, log_activity, LogContext
from application.retriever.base import BaseRetriever
@@ -9,6 +10,7 @@ from application.retriever.base import BaseRetriever
class ClassicAgent(BaseAgent):
def __init__(
self,
endpoint,
llm_name,
gpt_model,
api_key,
@@ -16,13 +18,21 @@ class ClassicAgent(BaseAgent):
prompt="",
chat_history=None,
):
super().__init__(llm_name, gpt_model, api_key, user_api_key)
super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key)
self.prompt = prompt
self.chat_history = chat_history if chat_history is not None else []
def gen(self, query: str, retriever: BaseRetriever) -> Generator[Dict, None, None]:
@log_activity()
def gen(
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
) -> Generator[Dict, None, None]:
yield from self._gen_inner(query, retriever, log_context)
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
retrieved_data = self._retriever_search(retriever, query, log_context)
retrieved_data = retriever.search(query)
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
@@ -66,9 +76,7 @@ class ClassicAgent(BaseAgent):
tools_dict = self._get_user_tools()
self._prepare_tools(tools_dict)
resp = self.llm.gen(
model=self.gpt_model, messages=messages_combine, tools=self.tools
)
resp = self._llm_gen(messages_combine, log_context)
if isinstance(resp, str):
yield {"answer": resp}
@@ -81,9 +89,7 @@ class ClassicAgent(BaseAgent):
yield {"answer": resp.message.content}
return
resp = self.llm_handler.handle_response(
self, resp, tools_dict, messages_combine
)
resp = self._llm_handler(resp, tools_dict, messages_combine, log_context)
if isinstance(resp, str):
yield {"answer": resp}
@@ -101,3 +107,29 @@ class ClassicAgent(BaseAgent):
yield {"answer": line}
yield {"tool_calls": self.tool_calls.copy()}
def _retriever_search(self, retriever, query, log_context):
retrieved_data = retriever.search(query)
if log_context:
data = build_stack_data(retriever, exclude_attributes=["llm"])
log_context.stacks.append({"component": "retriever", "data": data})
return retrieved_data
def _llm_gen(self, messages_combine, log_context):
resp = self.llm.gen(
model=self.gpt_model, messages=messages_combine, tools=self.tools
)
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "llm", "data": data})
return resp
def _llm_handler(self, resp, tools_dict, messages_combine, log_context):
resp = self.llm_handler.handle_response(
self, resp, tools_dict, messages_combine
)
if log_context:
data = build_stack_data(self.llm_handler)
log_context.stacks.append({"component": "llm_handler", "data": data})
return resp

View File

@@ -1,8 +1,14 @@
import json
from abc import ABC, abstractmethod
from application.logging import build_stack_data
class LLMHandler(ABC):
def __init__(self):
self.llm_calls = []
self.tool_calls = []
@abstractmethod
def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
pass
@@ -21,6 +27,7 @@ class OpenAILLMHandler(LLMHandler):
tool_calls = resp.message.tool_calls
for call in tool_calls:
try:
self.tool_calls.append(call)
tool_response, call_id = agent._execute_tool_action(
tools_dict, call
)
@@ -57,6 +64,7 @@ class OpenAILLMHandler(LLMHandler):
resp = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))
return resp
@@ -68,11 +76,13 @@ class GoogleLLMHandler(LLMHandler):
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))
if response.candidates and response.candidates[0].content.parts:
tool_call_found = False
for part in response.candidates[0].content.parts:
if part.function_call:
tool_call_found = True
self.tool_calls.append(part.function_call)
tool_response, call_id = agent._execute_tool_action(
tools_dict, part.function_call
)

View File

@@ -10,7 +10,7 @@ from bson.objectid import ObjectId
from flask import Blueprint, make_response, request, Response
from flask_restx import fields, Namespace, Resource
from application.agents.classic_agent import ClassicAgent
from application.agents.agent_creator import AgentCreator
from application.core.mongo_db import MongoDB
from application.core.settings import settings
@@ -213,7 +213,7 @@ def complete_stream(
response_full = ""
source_log_docs = []
tool_calls = []
answer = agent.gen(question, retriever)
answer = agent.gen(query=question, retriever=retriever)
sources = retriever.search(question)
for source in sources:
if "text" in source:
@@ -368,14 +368,18 @@ class Stream(Resource):
prompt = get_prompt(prompt_id)
if "isNoneDoc" in data and data["isNoneDoc"] is True:
chunks = 0
agent = ClassicAgent(
settings.LLM_NAME,
gpt_model,
settings.API_KEY,
agent = AgentCreator.create_agent(
settings.AGENT_NAME,
endpoint="stream",
llm_name=settings.LLM_NAME,
gpt_model=gpt_model,
api_key=settings.API_KEY,
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
source=source,

View File

@@ -32,6 +32,7 @@ class Settings(BaseSettings):
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
)
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
AGENT_NAME: str = "classic"
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"

151
application/logging.py Normal file
View File

@@ -0,0 +1,151 @@
import datetime
import functools
import inspect
import logging
import uuid
from typing import Any, Callable, Dict, Generator, List
from application.core.mongo_db import MongoDB
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class LogContext:
    """Per-request container for activity-log state.

    Holds the request metadata plus a mutable ``stacks`` list that
    instrumented components append their snapshot data to; the collected
    stacks are persisted when the request finishes.
    """

    def __init__(self, endpoint, activity_id, user, api_key, query):
        self.activity_id = activity_id
        self.endpoint = endpoint
        self.query = query
        self.user = user
        self.api_key = api_key
        # Filled in by components (retriever, llm, llm_handler, ...) as the
        # request progresses.
        self.stacks = []
def _coerce_value(value: Any):
    """Convert one attribute value to a JSON-serializable representation.

    Scalars pass through; lists of dicts pass through; lists of objects are
    flattened via ``__dict__``; other lists and dict values are stringified;
    anything else becomes ``str(value)``.
    """
    if isinstance(value, (int, float, str, bool)):
        return value
    if isinstance(value, list):
        if all(isinstance(item, dict) for item in value):
            return value
        if all(hasattr(item, "__dict__") for item in value):
            return [item.__dict__ for item in value]
        return [str(item) for item in value]
    if isinstance(value, dict):
        return {key: str(val) for key, val in value.items()}
    return str(value)


def build_stack_data(
    obj: Any,
    include_attributes: List[str] = None,
    exclude_attributes: List[str] = None,
    custom_data: Dict = None,
) -> Dict:
    """Snapshot an object's public attributes into a JSON-friendly dict.

    When ``include_attributes`` is None, every public non-callable attribute
    of ``obj`` is considered; otherwise only the named ones. ``None``-valued
    attributes and anything listed in ``exclude_attributes`` are skipped.
    ``custom_data``, if given, is merged in last and wins on key collisions.
    """
    if include_attributes is None:
        include_attributes = [
            name
            for name, member in inspect.getmembers(obj)
            if not name.startswith("_")
            and not inspect.ismethod(member)
            and not inspect.isfunction(member)
        ]

    excluded = set(exclude_attributes) if exclude_attributes else set()
    data = {}
    for attr_name in include_attributes:
        if attr_name in excluded:
            continue
        try:
            value = getattr(obj, attr_name)
            if value is not None:
                data[attr_name] = _coerce_value(value)
        except AttributeError:
            # Best-effort snapshot: skip attributes that fail to resolve.
            pass

    if custom_data:
        data.update(custom_data)
    return data
def log_activity() -> Callable:
    """Decorator for generator methods that records an activity log.

    Wraps a bound generator method (e.g. ``Agent.gen(self, query, ...)``).
    It builds a :class:`LogContext` from the instance's public attributes,
    injects it as the ``log_context`` keyword argument, and streams the
    wrapped generator through ``_consume_and_log`` so the collected stacks
    are persisted when the generator finishes or fails.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            activity_id = str(uuid.uuid4())
            # args[0] is the bound instance; snapshot its public attributes
            # to recover the logging metadata set at construction time.
            data = build_stack_data(args[0])
            endpoint = data.get("endpoint", "")
            user = data.get("user", "local")
            api_key = data.get("user_api_key", "")
            # The query may arrive as a keyword, as the first positional
            # argument after ``self`` (previously lost, producing empty-query
            # log entries), or as an attribute on the instance.
            if "query" in kwargs:
                query = kwargs["query"]
            elif len(args) > 1:
                query = args[1]
            else:
                query = getattr(args[0], "query", "")
            context = LogContext(endpoint, activity_id, user, api_key, query)
            kwargs["log_context"] = context
            logging.info(
                f"Starting activity: {endpoint} - {activity_id} - User: {user}"
            )
            generator = func(*args, **kwargs)
            yield from _consume_and_log(generator, context)

        return wrapper

    return decorator
def _consume_and_log(generator: Generator, context: "LogContext"):
    """Yield every item from ``generator``, then persist the log stacks.

    Exactly one MongoDB entry is written per run: ``level="error"`` (with an
    error stack appended) when the generator raises, ``level="info"``
    otherwise — including when the consumer abandons the generator early.
    The original exception is re-raised after being logged.
    """
    errored = False
    try:
        for item in generator:
            yield item
    except Exception as e:
        errored = True
        logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
        context.stacks.append({"component": "error", "data": {"message": str(e)}})
        _log_to_mongodb(
            endpoint=context.endpoint,
            activity_id=context.activity_id,
            user=context.user,
            api_key=context.api_key,
            query=context.query,
            stacks=context.stacks,
            level="error",
        )
        raise
    finally:
        # Skip the "info" entry when the error path already logged; the
        # previous version wrote both an "error" and an "info" document for
        # the same failed activity.
        if not errored:
            _log_to_mongodb(
                endpoint=context.endpoint,
                activity_id=context.activity_id,
                user=context.user,
                api_key=context.api_key,
                query=context.query,
                stacks=context.stacks,
                level="info",
            )
def _log_to_mongodb(
    endpoint: str,
    activity_id: str,
    user: str,
    api_key: str,
    query: str,
    stacks: List[Dict],
    level: str,
) -> None:
    """Persist a single activity-log document to the ``stack_logs`` collection.

    Best-effort: any failure is logged at ERROR and swallowed so that
    logging can never break the request path.
    """
    log_entry = {
        "endpoint": endpoint,
        "id": activity_id,
        "level": level,
        "user": user,
        "api_key": api_key,
        "query": query,
        "stacks": stacks,
        "timestamp": datetime.datetime.now(datetime.timezone.utc),
    }
    try:
        collection = MongoDB.get_client()["docsgpt"]["stack_logs"]
        collection.insert_one(log_entry)
        logging.debug(f"Logged activity to MongoDB: {activity_id}")
    except Exception as e:
        logging.error(f"Failed to log to MongoDB: {e}")

View File

@@ -43,7 +43,11 @@ class ClassicRAG(BaseRetriever):
self.vectorstore = source["active_docs"] if "active_docs" in source else None
def _rephrase_query(self):
if not self.chat_history or self.chat_history == []:
if (
not self.original_question
or not self.chat_history
or self.chat_history == []
):
return self.original_question
prompt = f"""Given the following conversation history:

View File

@@ -1,7 +1,8 @@
import sys
from datetime import datetime
from application.core.mongo_db import MongoDB
from application.utils import num_tokens_from_string, num_tokens_from_object_or_list
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -24,13 +25,16 @@ def gen_token_usage(func):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
if message["content"]:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
self.token_usage["prompt_tokens"] += num_tokens_from_string(
message["content"]
)
result = func(self, model, messages, stream, tools, **kwargs)
# check if result is a string
if isinstance(result, str):
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
else:
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result)
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
result
)
update_token_usage(self.user_api_key, self.token_usage)
return result
@@ -40,7 +44,9 @@ def gen_token_usage(func):
def stream_token_usage(func):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
self.token_usage["prompt_tokens"] += num_tokens_from_string(
message["content"]
)
batch = []
result = func(self, model, messages, stream, tools, **kwargs)
for r in result: