feat: logging stacks

This commit is contained in:
Siddhant Rai
2025-02-27 19:14:10 +05:30
parent 1f0b779c64
commit c6ce4d9374
9 changed files with 246 additions and 23 deletions

View File

@@ -0,0 +1,14 @@
from application.agents.classic_agent import ClassicAgent
class AgentCreator:
agents = {
"classic": ClassicAgent,
}
@classmethod
def create_agent(cls, type, *args, **kwargs):
agent_class = cls.agents.get(type.lower())
if not agent_class:
raise ValueError(f"No agent class found for type {type}")
return agent_class(*args, **kwargs)

View File

@@ -9,7 +9,8 @@ from application.llm.llm_creator import LLMCreator
class BaseAgent:
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
def __init__(self, endpoint, llm_name, gpt_model, api_key, user_api_key=None):
self.endpoint = endpoint
self.llm = LLMCreator.create_llm(
llm_name, api_key=api_key, user_api_key=user_api_key
)
@@ -19,7 +20,7 @@ class BaseAgent:
self.tool_config = {}
self.tool_calls = []
def gen(self, query: str) -> Generator[Dict, None, None]:
def gen(self, *args, **kwargs) -> Generator[Dict, None, None]:
raise NotImplementedError('Method "gen" must be implemented in the child class')
def _get_user_tools(self, user="local"):

View File

@@ -2,6 +2,7 @@ import uuid
from typing import Dict, Generator
from application.agents.base import BaseAgent
from application.logging import build_stack_data, log_activity, LogContext
from application.retriever.base import BaseRetriever
@@ -9,6 +10,7 @@ from application.retriever.base import BaseRetriever
class ClassicAgent(BaseAgent):
def __init__(
self,
endpoint,
llm_name,
gpt_model,
api_key,
@@ -16,13 +18,21 @@ class ClassicAgent(BaseAgent):
prompt="",
chat_history=None,
):
super().__init__(llm_name, gpt_model, api_key, user_api_key)
super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key)
self.prompt = prompt
self.chat_history = chat_history if chat_history is not None else []
def gen(self, query: str, retriever: BaseRetriever) -> Generator[Dict, None, None]:
@log_activity()
def gen(
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
) -> Generator[Dict, None, None]:
yield from self._gen_inner(query, retriever, log_context)
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
retrieved_data = self._retriever_search(retriever, query, log_context)
retrieved_data = retriever.search(query)
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
@@ -66,9 +76,7 @@ class ClassicAgent(BaseAgent):
tools_dict = self._get_user_tools()
self._prepare_tools(tools_dict)
resp = self.llm.gen(
model=self.gpt_model, messages=messages_combine, tools=self.tools
)
resp = self._llm_gen(messages_combine, log_context)
if isinstance(resp, str):
yield {"answer": resp}
@@ -81,9 +89,7 @@ class ClassicAgent(BaseAgent):
yield {"answer": resp.message.content}
return
resp = self.llm_handler.handle_response(
self, resp, tools_dict, messages_combine
)
resp = self._llm_handler(resp, tools_dict, messages_combine, log_context)
if isinstance(resp, str):
yield {"answer": resp}
@@ -101,3 +107,29 @@ class ClassicAgent(BaseAgent):
yield {"answer": line}
yield {"tool_calls": self.tool_calls.copy()}
def _retriever_search(self, retriever, query, log_context):
retrieved_data = retriever.search(query)
if log_context:
data = build_stack_data(retriever, exclude_attributes=["llm"])
log_context.stacks.append({"component": "retriever", "data": data})
return retrieved_data
def _llm_gen(self, messages_combine, log_context):
resp = self.llm.gen(
model=self.gpt_model, messages=messages_combine, tools=self.tools
)
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "llm", "data": data})
return resp
def _llm_handler(self, resp, tools_dict, messages_combine, log_context):
resp = self.llm_handler.handle_response(
self, resp, tools_dict, messages_combine
)
if log_context:
data = build_stack_data(self.llm_handler)
log_context.stacks.append({"component": "llm_handler", "data": data})
return resp

View File

@@ -1,8 +1,14 @@
import json
from abc import ABC, abstractmethod
from application.logging import build_stack_data
class LLMHandler(ABC):
def __init__(self):
self.llm_calls = []
self.tool_calls = []
@abstractmethod
def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
pass
@@ -21,6 +27,7 @@ class OpenAILLMHandler(LLMHandler):
tool_calls = resp.message.tool_calls
for call in tool_calls:
try:
self.tool_calls.append(call)
tool_response, call_id = agent._execute_tool_action(
tools_dict, call
)
@@ -57,6 +64,7 @@ class OpenAILLMHandler(LLMHandler):
resp = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))
return resp
@@ -68,11 +76,13 @@ class GoogleLLMHandler(LLMHandler):
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))
if response.candidates and response.candidates[0].content.parts:
tool_call_found = False
for part in response.candidates[0].content.parts:
if part.function_call:
tool_call_found = True
self.tool_calls.append(part.function_call)
tool_response, call_id = agent._execute_tool_action(
tools_dict, part.function_call
)