feat: logging stacks

This commit is contained in:
Siddhant Rai
2025-02-27 19:14:10 +05:30
parent 1f0b779c64
commit c6ce4d9374
9 changed files with 246 additions and 23 deletions

View File

@@ -2,6 +2,7 @@ import uuid
from typing import Dict, Generator
from application.agents.base import BaseAgent
from application.logging import build_stack_data, log_activity, LogContext
from application.retriever.base import BaseRetriever
@@ -9,6 +10,7 @@ from application.retriever.base import BaseRetriever
class ClassicAgent(BaseAgent):
def __init__(
self,
endpoint,
llm_name,
gpt_model,
api_key,
@@ -16,13 +18,21 @@ class ClassicAgent(BaseAgent):
prompt="",
chat_history=None,
):
super().__init__(llm_name, gpt_model, api_key, user_api_key)
super().__init__(endpoint, llm_name, gpt_model, api_key, user_api_key)
self.prompt = prompt
self.chat_history = chat_history if chat_history is not None else []
def gen(self, query: str, retriever: BaseRetriever) -> Generator[Dict, None, None]:
@log_activity()
def gen(
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
) -> Generator[Dict, None, None]:
yield from self._gen_inner(query, retriever, log_context)
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
retrieved_data = self._retriever_search(retriever, query, log_context)
retrieved_data = retriever.search(query)
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
@@ -66,9 +76,7 @@ class ClassicAgent(BaseAgent):
tools_dict = self._get_user_tools()
self._prepare_tools(tools_dict)
resp = self.llm.gen(
model=self.gpt_model, messages=messages_combine, tools=self.tools
)
resp = self._llm_gen(messages_combine, log_context)
if isinstance(resp, str):
yield {"answer": resp}
@@ -81,9 +89,7 @@ class ClassicAgent(BaseAgent):
yield {"answer": resp.message.content}
return
resp = self.llm_handler.handle_response(
self, resp, tools_dict, messages_combine
)
resp = self._llm_handler(resp, tools_dict, messages_combine, log_context)
if isinstance(resp, str):
yield {"answer": resp}
@@ -101,3 +107,29 @@ class ClassicAgent(BaseAgent):
yield {"answer": line}
yield {"tool_calls": self.tool_calls.copy()}
def _retriever_search(self, retriever, query, log_context):
retrieved_data = retriever.search(query)
if log_context:
data = build_stack_data(retriever, exclude_attributes=["llm"])
log_context.stacks.append({"component": "retriever", "data": data})
return retrieved_data
def _llm_gen(self, messages_combine, log_context):
resp = self.llm.gen(
model=self.gpt_model, messages=messages_combine, tools=self.tools
)
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "llm", "data": data})
return resp
def _llm_handler(self, resp, tools_dict, messages_combine, log_context):
resp = self.llm_handler.handle_response(
self, resp, tools_dict, messages_combine
)
if log_context:
data = build_stack_data(self.llm_handler)
log_context.stacks.append({"component": "llm_handler", "data": data})
return resp