refactor: update env variable names

This commit is contained in:
Siddhant Rai
2025-06-06 15:29:53 +05:30
parent 143f4aa886
commit e9530d5ec5
9 changed files with 121 additions and 86 deletions

View File

@@ -1,8 +1,6 @@
from typing import Dict, Generator
from application.agents.base import BaseAgent
from application.logging import LogContext
from application.retriever.base import BaseRetriever
import logging
@@ -10,55 +8,90 @@ logger = logging.getLogger(__name__)
class ClassicAgent(BaseAgent):
    """A simplified classic agent with clear execution flow.

    Usage:
        1. Processes a query through retrieval
        2. Sets up available tools
        3. Generates responses using LLM
        4. Handles tool interactions if needed
        5. Returns standardized outputs

    Easy to extend by overriding specific steps.
    """

    def _gen_inner(
        self, query: str, retriever: BaseRetriever, log_context: LogContext
    ) -> Generator[Dict, None, None]:
        """Main execution flow for the agent.

        Yields dicts of the shape {"answer": ...}, then {"sources": ...},
        then {"tool_calls": ...} so callers can stream each part.
        """
        # Step 1: Retrieve relevant data
        retrieved_data = self._retriever_search(retriever, query, log_context)

        # Step 2: Prepare tools.  An explicit API key scopes the tool set;
        # otherwise fall back to the tools configured for this user.
        tools_dict = (
            self._get_tools(self.user_api_key)
            if self.user_api_key
            else self._get_user_tools(self.user)
        )
        self._prepare_tools(tools_dict)

        # Step 3: Build and process messages
        messages = self._build_messages(self.prompt, query, retrieved_data)
        llm_response = self._llm_gen(messages, log_context)

        # Step 4: Handle the response (exactly once — the helper covers
        # plain strings, message objects, and tool-using responses)
        yield from self._handle_response(
            llm_response, tools_dict, messages, log_context
        )

        # Step 5: Return metadata.  Truncation happens in a *new* list so
        # the full results remain available for the debug log below —
        # truncating self.tool_calls in place would also mutate the dicts
        # shared with the shallow copy appended to log_context.
        yield {"sources": retrieved_data}
        yield {"tool_calls": self._get_truncated_tool_calls()}

        # Log full (untruncated) tool calls for debugging
        log_context.stacks.append(
            {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
        )

    def _handle_response(self, response, tools_dict, messages, log_context):
        """Handle different types of LLM responses consistently.

        Yields {"answer": ...} dicts.  Simple string and message-object
        responses short-circuit; anything else is routed through
        _llm_handler, which may invoke tools before producing output.
        """
        # Handle simple string responses
        if isinstance(response, str):
            yield {"answer": response}
            return

        # Handle content from message objects
        if hasattr(response, "message") and getattr(response.message, "content", None):
            yield {"answer": response.message.content}
            return

        # Handle complex responses that may require tool use
        processed_response = self._llm_handler(
            response, tools_dict, messages, log_context, self.attachments
        )

        # Yield the final processed response, which may itself be a
        # string, a message object, or a stream of string chunks.
        if isinstance(processed_response, str):
            yield {"answer": processed_response}
        elif hasattr(processed_response, "message") and getattr(
            processed_response.message, "content", None
        ):
            yield {"answer": processed_response.message.content}
        else:
            for line in processed_response:
                if isinstance(line, str):
                    yield {"answer": line}

    def _get_truncated_tool_calls(self):
        """Return tool calls with results truncated to 50 chars for cleaner output.

        Builds new dicts rather than mutating self.tool_calls, so callers
        holding references to the originals still see full results.
        """
        return [
            {
                **tool_call,
                "result": (
                    f"{str(tool_call['result'])[:50]}..."
                    if len(str(tool_call["result"])) > 50
                    else tool_call["result"]
                ),
            }
            for tool_call in self.tool_calls
        ]