refactor: reorganize LLM handler structure and improve tool call parsing

Author: Siddhant Rai
Date: 2025-06-02 12:17:29 +05:30
parent 9c6352dd5b
commit 92c3c707e1
8 changed files with 494 additions and 376 deletions

View File

@@ -2,16 +2,18 @@ import uuid
 from abc import ABC, abstractmethod
 from typing import Dict, Generator, List, Optional

-from application.agents.llm_handler import get_llm_handler
-from bson.objectid import ObjectId
 from application.agents.tools.tool_action_parser import ToolActionParser
 from application.agents.tools.tool_manager import ToolManager
 from application.core.mongo_db import MongoDB
+from application.core.settings import settings
+from application.llm.handlers.handler_creator import LLMHandlerCreator
 from application.llm.llm_creator import LLMCreator
 from application.logging import build_stack_data, log_activity, LogContext
 from application.retriever.base import BaseRetriever
-from application.core.settings import settings
+from bson.objectid import ObjectId


 class BaseAgent(ABC):
@@ -45,7 +47,9 @@ class BaseAgent(ABC):
             user_api_key=user_api_key,
             decoded_token=decoded_token,
         )
-        self.llm_handler = get_llm_handler(llm_name)
+        self.llm_handler = LLMHandlerCreator.create_handler(
+            llm_name if llm_name else "default"
+        )
         self.attachments = attachments or []

     @log_activity()
@@ -268,8 +272,8 @@ class BaseAgent(ABC):
         log_context: Optional[LogContext] = None,
         attachments: Optional[List[Dict]] = None,
     ):
-        resp = self.llm_handler.handle_response(
-            self, resp, tools_dict, messages, attachments
+        resp = self.llm_handler.process_message_flow(
+            self, resp, tools_dict, messages, attachments, True
         )
         if log_context:
             data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"])
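For context, a minimal sketch of how the refactored entry point is driven; `agent`, `resp`, `tools_dict`, and `messages` are hypothetical stand-ins for the objects used in the diff above, and `LLMHandlerCreator` is the class imported there:

handler = LLMHandlerCreator.create_handler("openai")

# stream=True returns a generator of text chunks
for chunk in handler.process_message_flow(
    agent, resp, tools_dict, messages, attachments=None, stream=True
):
    print(chunk, end="")

# stream=False runs the tool-call loop to completion and returns the final text
final_text = handler.process_message_flow(
    agent, resp, tools_dict, messages, stream=False
)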

View File

@@ -1,351 +0,0 @@
import json
import logging
from abc import ABC, abstractmethod

from application.logging import build_stack_data

logger = logging.getLogger(__name__)


class LLMHandler(ABC):
    def __init__(self):
        self.llm_calls = []
        self.tool_calls = []

    @abstractmethod
    def handle_response(self, agent, resp, tools_dict, messages, attachments=None, **kwargs):
        pass

    def prepare_messages_with_attachments(self, agent, messages, attachments=None):
        """
        Prepare messages with attachment content if available.

        Args:
            agent: The current agent instance.
            messages (list): List of message dictionaries.
            attachments (list): List of attachment dictionaries with content.

        Returns:
            list: Messages with attachment context added to the system prompt.
        """
        if not attachments:
            return messages

        logger.info(f"Preparing messages with {len(attachments)} attachments")
        supported_types = agent.llm.get_supported_attachment_types()

        supported_attachments = []
        unsupported_attachments = []
        for attachment in attachments:
            mime_type = attachment.get('mime_type')
            if mime_type in supported_types:
                supported_attachments.append(attachment)
            else:
                unsupported_attachments.append(attachment)

        # Process supported attachments with the LLM's custom method
        prepared_messages = messages
        if supported_attachments:
            logger.info(f"Processing {len(supported_attachments)} supported attachments with {agent.llm.__class__.__name__}'s method")
            prepared_messages = agent.llm.prepare_messages_with_attachments(messages, supported_attachments)

        # Process unsupported attachments with the default method
        if unsupported_attachments:
            logger.info(f"Processing {len(unsupported_attachments)} unsupported attachments with default method")
            prepared_messages = self._append_attachment_content_to_system(prepared_messages, unsupported_attachments)

        return prepared_messages

    def _append_attachment_content_to_system(self, messages, attachments):
        """
        Default method to append attachment content to the system prompt.

        Args:
            messages (list): List of message dictionaries.
            attachments (list): List of attachment dictionaries with content.

        Returns:
            list: Messages with attachment context added to the system prompt.
        """
        prepared_messages = messages.copy()
        attachment_texts = []
        for attachment in attachments:
            logger.info(f"Adding attachment {attachment.get('id')} to context")
            if 'content' in attachment:
                attachment_texts.append(f"Attached file content:\n\n{attachment['content']}")

        if attachment_texts:
            combined_attachment_text = "\n\n".join(attachment_texts)
            system_found = False
            for i in range(len(prepared_messages)):
                if prepared_messages[i].get("role") == "system":
                    prepared_messages[i]["content"] += f"\n\n{combined_attachment_text}"
                    system_found = True
                    break
            if not system_found:
                prepared_messages.insert(0, {"role": "system", "content": combined_attachment_text})

        return prepared_messages
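For illustration, a quick trace of the fallback above with hypothetical values: when no system message exists, the combined attachment text becomes a new one at position 0.

msgs = [{"role": "user", "content": "Summarize the attachment."}]
atts = [{"id": "a1", "content": "hello world"}]
# handler._append_attachment_content_to_system(msgs, atts) returns:
# [{"role": "system", "content": "Attached file content:\n\nhello world"},
#  {"role": "user", "content": "Summarize the attachment."}]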
class OpenAILLMHandler(LLMHandler):
    def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
        messages = self.prepare_messages_with_attachments(agent, messages, attachments)
        logger.info(f"Messages with attachments: {messages}")

        if not stream:
            while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
                message = json.loads(resp.model_dump_json())["message"]
                keys_to_remove = {"audio", "function_call", "refusal"}
                filtered_data = {
                    k: v for k, v in message.items() if k not in keys_to_remove
                }
                messages.append(filtered_data)

                tool_calls = resp.message.tool_calls
                for call in tool_calls:
                    try:
                        self.tool_calls.append(call)
                        tool_response, call_id = agent._execute_tool_action(
                            tools_dict, call
                        )
                        function_call_dict = {
                            "function_call": {
                                "name": call.function.name,
                                "args": call.function.arguments,
                                "call_id": call_id,
                            }
                        }
                        function_response_dict = {
                            "function_response": {
                                "name": call.function.name,
                                "response": {"result": tool_response},
                                "call_id": call_id,
                            }
                        }
                        messages.append(
                            {"role": "assistant", "content": [function_call_dict]}
                        )
                        messages.append(
                            {"role": "tool", "content": [function_response_dict]}
                        )
                        messages = self.prepare_messages_with_attachments(agent, messages, attachments)
                    except Exception as e:
                        logging.error(f"Error executing tool: {str(e)}", exc_info=True)
                        messages.append(
                            {
                                "role": "tool",
                                "content": f"Error executing tool: {str(e)}",
                                "tool_call_id": call_id,
                            }
                        )
                resp = agent.llm.gen_stream(
                    model=agent.gpt_model, messages=messages, tools=agent.tools
                )
                self.llm_calls.append(build_stack_data(agent.llm))
            return resp
        else:
            text_buffer = ""
            while True:
                tool_calls = {}
                for chunk in resp:
                    if isinstance(chunk, str) and len(chunk) > 0:
                        yield chunk
                        continue
                    elif hasattr(chunk, "delta"):
                        chunk_delta = chunk.delta
                        if (
                            hasattr(chunk_delta, "tool_calls")
                            and chunk_delta.tool_calls is not None
                        ):
                            for tool_call in chunk_delta.tool_calls:
                                index = tool_call.index
                                if index not in tool_calls:
                                    tool_calls[index] = {
                                        "id": "",
                                        "function": {"name": "", "arguments": ""},
                                    }
                                current = tool_calls[index]
                                if tool_call.id:
                                    current["id"] = tool_call.id
                                if tool_call.function.name:
                                    current["function"][
                                        "name"
                                    ] = tool_call.function.name
                                if tool_call.function.arguments:
                                    current["function"][
                                        "arguments"
                                    ] += tool_call.function.arguments
                                tool_calls[index] = current
                    if (
                        hasattr(chunk, "finish_reason")
                        and chunk.finish_reason == "tool_calls"
                    ):
                        for index in sorted(tool_calls.keys()):
                            call = tool_calls[index]
                            try:
                                self.tool_calls.append(call)
                                tool_response, call_id = agent._execute_tool_action(
                                    tools_dict, call
                                )
                                if isinstance(call["function"]["arguments"], str):
                                    call["function"]["arguments"] = json.loads(call["function"]["arguments"])
                                function_call_dict = {
                                    "function_call": {
                                        "name": call["function"]["name"],
                                        "args": call["function"]["arguments"],
                                        "call_id": call["id"],
                                    }
                                }
                                function_response_dict = {
                                    "function_response": {
                                        "name": call["function"]["name"],
                                        "response": {"result": tool_response},
                                        "call_id": call["id"],
                                    }
                                }
                                messages.append(
                                    {
                                        "role": "assistant",
                                        "content": [function_call_dict],
                                    }
                                )
                                messages.append(
                                    {
                                        "role": "tool",
                                        "content": [function_response_dict],
                                    }
                                )
                            except Exception as e:
                                logging.error(f"Error executing tool: {str(e)}", exc_info=True)
                                messages.append(
                                    {
                                        "role": "assistant",
                                        "content": f"Error executing tool: {str(e)}",
                                    }
                                )
                        tool_calls = {}
                    if hasattr(chunk_delta, "content") and chunk_delta.content:
                        # Add to buffer or yield immediately based on your preference
                        text_buffer += chunk_delta.content
                        yield text_buffer
                        text_buffer = ""
                    if (
                        hasattr(chunk, "finish_reason")
                        and chunk.finish_reason == "stop"
                    ):
                        return resp
                    elif isinstance(chunk, str) and len(chunk) == 0:
                        continue

                logger.info(f"Regenerating with messages: {messages}")
                resp = agent.llm.gen_stream(
                    model=agent.gpt_model, messages=messages, tools=agent.tools
                )
                self.llm_calls.append(build_stack_data(agent.llm))


class GoogleLLMHandler(LLMHandler):
    def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
        from google.genai import types

        messages = self.prepare_messages_with_attachments(agent, messages, attachments)

        while True:
            if not stream:
                response = agent.llm.gen(
                    model=agent.gpt_model, messages=messages, tools=agent.tools
                )
                self.llm_calls.append(build_stack_data(agent.llm))
                if response.candidates and response.candidates[0].content.parts:
                    tool_call_found = False
                    for part in response.candidates[0].content.parts:
                        if part.function_call:
                            tool_call_found = True
                            self.tool_calls.append(part.function_call)
                            tool_response, call_id = agent._execute_tool_action(
                                tools_dict, part.function_call
                            )
                            function_response_part = types.Part.from_function_response(
                                name=part.function_call.name,
                                response={"result": tool_response},
                            )
                            messages.append(
                                {"role": "model", "content": [part.to_json_dict()]}
                            )
                            messages.append(
                                {
                                    "role": "tool",
                                    "content": [function_response_part.to_json_dict()],
                                }
                            )
                    if (
                        not tool_call_found
                        and response.candidates[0].content.parts
                        and response.candidates[0].content.parts[0].text
                    ):
                        return response.candidates[0].content.parts[0].text
                    elif not tool_call_found:
                        return response.candidates[0].content.parts
                else:
                    return response
            else:
                response = agent.llm.gen_stream(
                    model=agent.gpt_model, messages=messages, tools=agent.tools
                )
                self.llm_calls.append(build_stack_data(agent.llm))
                tool_call_found = False
                for result in response:
                    if hasattr(result, "function_call"):
                        tool_call_found = True
                        self.tool_calls.append(result.function_call)
                        tool_response, call_id = agent._execute_tool_action(
                            tools_dict, result.function_call
                        )
                        function_response_part = types.Part.from_function_response(
                            name=result.function_call.name,
                            response={"result": tool_response},
                        )
                        messages.append(
                            {"role": "model", "content": [result.to_json_dict()]}
                        )
                        messages.append(
                            {
                                "role": "tool",
                                "content": [function_response_part.to_json_dict()],
                            }
                        )
                    else:
                        tool_call_found = False
                        yield result
                if not tool_call_found:
                    return response


def get_llm_handler(llm_type):
    handlers = {
        "openai": OpenAILLMHandler(),
        "google": GoogleLLMHandler(),
    }
    return handlers.get(llm_type, OpenAILLMHandler())
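A usage sketch of the old factory being removed here. Note that it eagerly instantiates both handlers on every call, and that unknown types silently fall back to OpenAI — behavior the new LLMHandlerCreator replaces with lazy construction and an explicit error:

handler = get_llm_handler("google")    # a GoogleLLMHandler instance
handler = get_llm_handler("anything")  # silently falls back to OpenAILLMHandler()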

View File

@@ -17,26 +17,21 @@ class ToolActionParser:
         return parser(call)

     def _parse_openai_llm(self, call):
-        if isinstance(call, dict):
-            try:
-                call_args = json.loads(call["function"]["arguments"])
-                tool_id = call["function"]["name"].split("_")[-1]
-                action_name = call["function"]["name"].rsplit("_", 1)[0]
-            except (KeyError, TypeError) as e:
-                logger.error(f"Error parsing OpenAI LLM call: {e}")
-                return None, None, None
-        else:
-            try:
-                call_args = json.loads(call.function.arguments)
-                tool_id = call.function.name.split("_")[-1]
-                action_name = call.function.name.rsplit("_", 1)[0]
-            except (AttributeError, TypeError) as e:
-                logger.error(f"Error parsing OpenAI LLM call: {e}")
-                return None, None, None
+        try:
+            call_args = json.loads(call.arguments)
+            tool_id = call.name.split("_")[-1]
+            action_name = call.name.rsplit("_", 1)[0]
+        except (AttributeError, TypeError) as e:
+            logger.error(f"Error parsing OpenAI LLM call: {e}")
+            return None, None, None
         return tool_id, action_name, call_args

     def _parse_google_llm(self, call):
-        call_args = call.args
-        tool_id = call.name.split("_")[-1]
-        action_name = call.name.rsplit("_", 1)[0]
+        try:
+            call_args = call.arguments
+            tool_id = call.name.split("_")[-1]
+            action_name = call.name.rsplit("_", 1)[0]
+        except (AttributeError, TypeError) as e:
+            logger.error(f"Error parsing Google LLM call: {e}")
+            return None, None, None
         return tool_id, action_name, call_args
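For illustration, how the split logic above decomposes a combined name (the value is hypothetical):

name = "read_webpage_6650a1f0"        # "<action_name>_<tool_id>"
tool_id = name.split("_")[-1]         # "6650a1f0"
action_name = name.rsplit("_", 1)[0]  # "read_webpage"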

View File

View File

@@ -0,0 +1,317 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Union

from application.logging import build_stack_data

logger = logging.getLogger(__name__)


@dataclass
class ToolCall:
    """Represents a tool/function call from the LLM."""

    id: str
    name: str
    arguments: Union[str, Dict]
    index: Optional[int] = None

    @classmethod
    def from_dict(cls, data: Dict) -> "ToolCall":
        """Create ToolCall from dictionary."""
        return cls(
            id=data.get("id", ""),
            name=data.get("name", ""),
            arguments=data.get("arguments", {}),
            index=data.get("index"),
        )


@dataclass
class LLMResponse:
    """Represents a response from the LLM."""

    content: str
    tool_calls: List[ToolCall]
    finish_reason: str
    raw_response: Any

    @property
    def requires_tool_call(self) -> bool:
        """Check if the response requires tool calls."""
        return bool(self.tool_calls) and self.finish_reason == "tool_calls"
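For illustration, the two dataclasses in use with hypothetical values:

call = ToolCall.from_dict(
    {"id": "call_1", "name": "read_webpage_6650a1f0", "arguments": "{}"}
)
resp = LLMResponse(
    content="", tool_calls=[call], finish_reason="tool_calls", raw_response=None
)
assert resp.requires_tool_call  # True: has calls and finish_reason matches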
class LLMHandler(ABC):
    """Abstract base class for LLM handlers."""

    def __init__(self):
        self.llm_calls = []
        self.tool_calls = []

    @abstractmethod
    def parse_response(self, response: Any) -> LLMResponse:
        """Parse raw LLM response into standardized format."""
        pass

    @abstractmethod
    def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
        """Create a tool result message for the conversation history."""
        pass

    @abstractmethod
    def _iterate_stream(self, response: Any) -> Generator:
        """Iterate through streaming response chunks."""
        pass

    def process_message_flow(
        self,
        agent,
        initial_response,
        tools_dict: Dict,
        messages: List[Dict],
        attachments: Optional[List] = None,
        stream: bool = False,
    ) -> Union[str, Generator]:
        """
        Main orchestration method for processing LLM message flow.

        Args:
            agent: The agent instance
            initial_response: Initial LLM response
            tools_dict: Dictionary of available tools
            messages: Conversation history
            attachments: Optional attachments
            stream: Whether to use streaming

        Returns:
            Final response or generator for streaming
        """
        messages = self.prepare_messages(agent, messages, attachments)
        if stream:
            return self.handle_streaming(agent, initial_response, tools_dict, messages)
        else:
            return self.handle_non_streaming(
                agent, initial_response, tools_dict, messages
            )

    def prepare_messages(
        self, agent, messages: List[Dict], attachments: Optional[List] = None
    ) -> List[Dict]:
        """
        Prepare messages with attachments and provider-specific formatting.

        Args:
            agent: The agent instance
            messages: Original messages
            attachments: List of attachments

        Returns:
            Prepared messages list
        """
        if not attachments:
            return messages

        logger.info(f"Preparing messages with {len(attachments)} attachments")
        supported_types = agent.llm.get_supported_attachment_types()
        supported_attachments = [
            a for a in attachments if a.get("mime_type") in supported_types
        ]
        unsupported_attachments = [
            a for a in attachments if a.get("mime_type") not in supported_types
        ]

        # Process supported attachments with the LLM's custom method
        if supported_attachments:
            logger.info(
                f"Processing {len(supported_attachments)} supported attachments"
            )
            messages = agent.llm.prepare_messages_with_attachments(
                messages, supported_attachments
            )

        # Process unsupported attachments with default method
        if unsupported_attachments:
            logger.info(
                f"Processing {len(unsupported_attachments)} unsupported attachments"
            )
            messages = self._append_unsupported_attachments(
                messages, unsupported_attachments
            )

        return messages

    def _append_unsupported_attachments(
        self, messages: List[Dict], attachments: List[Dict]
    ) -> List[Dict]:
        """
        Default method to append unsupported attachment content to system prompt.

        Args:
            messages: Current messages
            attachments: List of unsupported attachments

        Returns:
            Updated messages list
        """
        prepared_messages = messages.copy()
        attachment_texts = []
        for attachment in attachments:
            logger.info(f"Adding attachment {attachment.get('id')} to context")
            if "content" in attachment:
                attachment_texts.append(
                    f"Attached file content:\n\n{attachment['content']}"
                )

        if attachment_texts:
            combined_text = "\n\n".join(attachment_texts)
            system_msg = next(
                (msg for msg in prepared_messages if msg.get("role") == "system"),
                {"role": "system", "content": ""},
            )
            if system_msg not in prepared_messages:
                prepared_messages.insert(0, system_msg)
            system_msg["content"] += f"\n\n{combined_text}"

        return prepared_messages

    def handle_tool_calls(
        self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
    ) -> List[Dict]:
        """
        Execute tool calls and update conversation history.

        Args:
            agent: The agent instance
            tool_calls: List of tool calls to execute
            tools_dict: Available tools dictionary
            messages: Current conversation history

        Returns:
            Updated messages list
        """
        updated_messages = messages.copy()
        for call in tool_calls:
            try:
                self.tool_calls.append(call)
                tool_response, call_id = agent._execute_tool_action(tools_dict, call)
                updated_messages.append(
                    {
                        "role": "assistant",
                        "content": [
                            {
                                "function_call": {
                                    "name": call.name,
                                    "args": call.arguments,
                                    "call_id": call_id,
                                }
                            }
                        ],
                    }
                )
                updated_messages.append(self.create_tool_message(call, tool_response))
            except Exception as e:
                logger.error(f"Error executing tool: {str(e)}", exc_info=True)
                updated_messages.append(
                    {
                        "role": "tool",
                        "content": f"Error executing tool: {str(e)}",
                        "tool_call_id": call.id,
                    }
                )
        return updated_messages
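Grounded in handle_tool_calls above: each successful call appends a pair of messages — the assistant-side function_call record, plus the provider-specific result from create_tool_message. With the OpenAI handler the pair looks roughly like this (values hypothetical):

# appended per executed call:
# {"role": "assistant", "content": [{"function_call": {
#     "name": "read_webpage_6650a1f0", "args": '{"url": "https://example.com"}',
#     "call_id": "call_1"}}]}
# {"role": "tool", "content": [{"function_response": {
#     "name": "read_webpage_6650a1f0", "response": {"result": "<page text>"},
#     "call_id": "call_1"}}]}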
    def handle_non_streaming(
        self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
    ) -> Union[str, Dict]:
        """
        Handle non-streaming response flow.

        Args:
            agent: The agent instance
            response: Current LLM response
            tools_dict: Available tools dictionary
            messages: Conversation history

        Returns:
            Final response after processing all tool calls
        """
        parsed = self.parse_response(response)
        self.llm_calls.append(build_stack_data(agent.llm))

        while parsed.requires_tool_call:
            messages = self.handle_tool_calls(
                agent, parsed.tool_calls, tools_dict, messages
            )
            response = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
            parsed = self.parse_response(response)
            self.llm_calls.append(build_stack_data(agent.llm))

        return parsed.content

    def handle_streaming(
        self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
    ) -> Generator:
        """
        Handle streaming response flow.

        Args:
            agent: The agent instance
            response: Current LLM response
            tools_dict: Available tools dictionary
            messages: Conversation history

        Yields:
            Streaming response chunks
        """
        buffer = ""
        tool_calls = {}

        for chunk in self._iterate_stream(response):
            if isinstance(chunk, str):
                yield chunk
                continue

            parsed = self.parse_response(chunk)

            if parsed.tool_calls:
                for call in parsed.tool_calls:
                    if call.index not in tool_calls:
                        tool_calls[call.index] = call
                    else:
                        existing = tool_calls[call.index]
                        if call.id:
                            existing.id = call.id
                        if call.name:
                            existing.name = call.name
                        if call.arguments:
                            existing.arguments += call.arguments

            if parsed.finish_reason == "tool_calls":
                messages = self.handle_tool_calls(
                    agent, list(tool_calls.values()), tools_dict, messages
                )
                tool_calls = {}
                response = agent.llm.gen_stream(
                    model=agent.gpt_model, messages=messages, tools=agent.tools
                )
                self.llm_calls.append(build_stack_data(agent.llm))
                yield from self.handle_streaming(agent, response, tools_dict, messages)
                return

            if parsed.content:
                buffer += parsed.content
                yield buffer
                buffer = ""

            if parsed.finish_reason == "stop":
                return
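The index-keyed merge above reassembles a tool call that arrives split across stream deltas; a sketch with hypothetical chunks:

# delta 1: ToolCall(id="call_1", name="read_webpage_6650a1f0",
#                   arguments='{"ur', index=0)
# delta 2: ToolCall(id="", name="", arguments='l": "https://example.com"}', index=0)
# after merging: tool_calls[0].arguments == '{"url": "https://example.com"}'
# Once finish_reason == "tool_calls", the merged calls are executed and the
# stream is restarted recursively with the updated messages.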

View File

@@ -0,0 +1,78 @@
import uuid
from typing import Any, Dict, Generator

from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall


class GoogleLLMHandler(LLMHandler):
    """Handler for Google's GenAI API."""

    def parse_response(self, response: Any) -> LLMResponse:
        """Parse Google response into standardized format."""
        if isinstance(response, str):
            return LLMResponse(
                content=response,
                tool_calls=[],
                finish_reason="stop",
                raw_response=response,
            )

        if hasattr(response, "candidates"):
            parts = response.candidates[0].content.parts if response.candidates else []
            tool_calls = [
                ToolCall(
                    id=str(uuid.uuid4()),
                    name=part.function_call.name,
                    arguments=part.function_call.args,
                )
                for part in parts
                if hasattr(part, "function_call") and part.function_call is not None
            ]
            content = " ".join(
                part.text
                for part in parts
                if hasattr(part, "text") and part.text is not None
            )
            return LLMResponse(
                content=content,
                tool_calls=tool_calls,
                finish_reason="tool_calls" if tool_calls else "stop",
                raw_response=response,
            )
        else:
            tool_calls = []
            if hasattr(response, "function_call"):
                tool_calls.append(
                    ToolCall(
                        id=str(uuid.uuid4()),
                        name=response.function_call.name,
                        arguments=response.function_call.args,
                    )
                )
            return LLMResponse(
                content=response.text if hasattr(response, "text") else "",
                tool_calls=tool_calls,
                finish_reason="tool_calls" if tool_calls else "stop",
                raw_response=response,
            )

    def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
        """Create Google-style tool message."""
        from google.genai import types

        return {
            "role": "tool",
            "content": [
                types.Part.from_function_response(
                    name=tool_call.name, response={"result": result}
                ).to_json_dict()
            ],
        }

    def _iterate_stream(self, response: Any) -> Generator:
        """Iterate through Google streaming response."""
        for chunk in response:
            yield chunk
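A quick check of the string fast path above, grounded in the code as shown: plain text parses to a terminal response.

handler = GoogleLLMHandler()
parsed = handler.parse_response("Hello!")
assert parsed.content == "Hello!" and parsed.finish_reason == "stop"
assert not parsed.requires_tool_call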

View File

@@ -0,0 +1,18 @@
from application.llm.handlers.base import LLMHandler
from application.llm.handlers.google import GoogleLLMHandler
from application.llm.handlers.openai import OpenAILLMHandler


class LLMHandlerCreator:
    handlers = {
        "openai": OpenAILLMHandler,
        "google": GoogleLLMHandler,
        "default": OpenAILLMHandler,
    }

    @classmethod
    def create_handler(cls, llm_type: str, *args, **kwargs) -> LLMHandler:
        handler_class = cls.handlers.get(llm_type.lower())
        if not handler_class:
            raise ValueError(f"No LLM handler class found for type {llm_type}")
        return handler_class(*args, **kwargs)
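A usage sketch grounded in the mapping above: lookups are lower-cased, construction is deferred to the selected class, and unknown types now raise instead of silently falling back as the old get_llm_handler did.

handler = LLMHandlerCreator.create_handler("OpenAI")   # an OpenAILLMHandler instance
handler = LLMHandlerCreator.create_handler("default")  # an OpenAILLMHandler instance
LLMHandlerCreator.create_handler("mistral")            # raises ValueError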

View File

@@ -0,0 +1,57 @@
from typing import Any, Dict, Generator

from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall


class OpenAILLMHandler(LLMHandler):
    """Handler for OpenAI API."""

    def parse_response(self, response: Any) -> LLMResponse:
        """Parse OpenAI response into standardized format."""
        if isinstance(response, str):
            return LLMResponse(
                content=response,
                tool_calls=[],
                finish_reason="stop",
                raw_response=response,
            )

        message = getattr(response, "message", None) or getattr(response, "delta", None)
        tool_calls = []
        if hasattr(message, "tool_calls"):
            tool_calls = [
                ToolCall(
                    id=getattr(tc, "id", ""),
                    name=getattr(tc.function, "name", ""),
                    arguments=getattr(tc.function, "arguments", ""),
                    index=getattr(tc, "index", None),
                )
                for tc in message.tool_calls or []
            ]

        return LLMResponse(
            content=getattr(message, "content", ""),
            tool_calls=tool_calls,
            finish_reason=getattr(response, "finish_reason", ""),
            raw_response=response,
        )

    def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
        """Create OpenAI-style tool message."""
        return {
            "role": "tool",
            "content": [
                {
                    "function_response": {
                        "name": tool_call.name,
                        "response": {"result": result},
                        "call_id": tool_call.id,
                    }
                }
            ],
        }

    def _iterate_stream(self, response: Any) -> Generator:
        """Iterate through OpenAI streaming response."""
        for chunk in response:
            yield chunk
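A minimal sketch of parse_response on a stand-in object shaped like an OpenAI choice; the names and values are hypothetical, used only to exercise the getattr-based extraction above:

from types import SimpleNamespace as NS

chunk = NS(
    message=NS(
        content="",
        tool_calls=[NS(id="call_1", index=0,
                       function=NS(name="read_webpage_6650a1f0",
                                   arguments='{"url": "https://example.com"}'))],
    ),
    finish_reason="tool_calls",
)
parsed = OpenAILLMHandler().parse_response(chunk)
assert parsed.requires_tool_call
assert parsed.tool_calls[0].name == "read_webpage_6650a1f0"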