diff --git a/application/agents/base.py b/application/agents/base.py
index d44244cf..a4bbd001 100644
--- a/application/agents/base.py
+++ b/application/agents/base.py
@@ -2,16 +2,18 @@ import uuid
 from abc import ABC, abstractmethod
 from typing import Dict, Generator, List, Optional
 
-from application.agents.llm_handler import get_llm_handler
+from bson.objectid import ObjectId
+
 from application.agents.tools.tool_action_parser import ToolActionParser
 from application.agents.tools.tool_manager import ToolManager
 from application.core.mongo_db import MongoDB
+from application.core.settings import settings
+
+from application.llm.handlers.handler_creator import LLMHandlerCreator
 from application.llm.llm_creator import LLMCreator
 from application.logging import build_stack_data, log_activity, LogContext
 from application.retriever.base import BaseRetriever
-from application.core.settings import settings
-from bson.objectid import ObjectId
 
 
 class BaseAgent(ABC):
@@ -45,7 +47,9 @@ class BaseAgent(ABC):
             user_api_key=user_api_key,
             decoded_token=decoded_token,
         )
-        self.llm_handler = get_llm_handler(llm_name)
+        self.llm_handler = LLMHandlerCreator.create_handler(
+            llm_name if llm_name else "default"
+        )
         self.attachments = attachments or []
 
     @log_activity()
@@ -268,8 +272,8 @@ class BaseAgent(ABC):
         log_context: Optional[LogContext] = None,
         attachments: Optional[List[Dict]] = None,
     ):
-        resp = self.llm_handler.handle_response(
-            self, resp, tools_dict, messages, attachments
+        resp = self.llm_handler.process_message_flow(
+            self, resp, tools_dict, messages, attachments, True
         )
         if log_context:
             data = build_stack_data(self.llm_handler, exclude_attributes=["tool_calls"])
diff --git a/application/agents/llm_handler.py b/application/agents/llm_handler.py
deleted file mode 100644
index 1b995f71..00000000
--- a/application/agents/llm_handler.py
+++ /dev/null
@@ -1,351 +0,0 @@
-import json
-import logging
-from abc import ABC, abstractmethod
-
-from application.logging import build_stack_data
-
-logger = logging.getLogger(__name__)
-
-
-class LLMHandler(ABC):
-    def __init__(self):
-        self.llm_calls = []
-        self.tool_calls = []
-
-    @abstractmethod
-    def handle_response(self, agent, resp, tools_dict, messages, attachments=None, **kwargs):
-        pass
-
-    def prepare_messages_with_attachments(self, agent, messages, attachments=None):
-        """
-        Prepare messages with attachment content if available.
-
-        Args:
-            agent: The current agent instance.
-            messages (list): List of message dictionaries.
-            attachments (list): List of attachment dictionaries with content.
-
-        Returns:
-            list: Messages with attachment context added to the system prompt.
- """ - if not attachments: - return messages - - logger.info(f"Preparing messages with {len(attachments)} attachments") - - supported_types = agent.llm.get_supported_attachment_types() - - supported_attachments = [] - unsupported_attachments = [] - - for attachment in attachments: - mime_type = attachment.get('mime_type') - if mime_type in supported_types: - supported_attachments.append(attachment) - else: - unsupported_attachments.append(attachment) - - # Process supported attachments with the LLM's custom method - prepared_messages = messages - if supported_attachments: - logger.info(f"Processing {len(supported_attachments)} supported attachments with {agent.llm.__class__.__name__}'s method") - prepared_messages = agent.llm.prepare_messages_with_attachments(messages, supported_attachments) - - # Process unsupported attachments with the default method - if unsupported_attachments: - logger.info(f"Processing {len(unsupported_attachments)} unsupported attachments with default method") - prepared_messages = self._append_attachment_content_to_system(prepared_messages, unsupported_attachments) - - return prepared_messages - - def _append_attachment_content_to_system(self, messages, attachments): - """ - Default method to append attachment content to the system prompt. - - Args: - messages (list): List of message dictionaries. - attachments (list): List of attachment dictionaries with content. - - Returns: - list: Messages with attachment context added to the system prompt. - """ - prepared_messages = messages.copy() - - attachment_texts = [] - for attachment in attachments: - logger.info(f"Adding attachment {attachment.get('id')} to context") - if 'content' in attachment: - attachment_texts.append(f"Attached file content:\n\n{attachment['content']}") - - if attachment_texts: - combined_attachment_text = "\n\n".join(attachment_texts) - - system_found = False - for i in range(len(prepared_messages)): - if prepared_messages[i].get("role") == "system": - prepared_messages[i]["content"] += f"\n\n{combined_attachment_text}" - system_found = True - break - - if not system_found: - prepared_messages.insert(0, {"role": "system", "content": combined_attachment_text}) - - return prepared_messages - -class OpenAILLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True): - - messages = self.prepare_messages_with_attachments(agent, messages, attachments) - logger.info(f"Messages with attachments: {messages}") - if not stream: - while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls": - message = json.loads(resp.model_dump_json())["message"] - keys_to_remove = {"audio", "function_call", "refusal"} - filtered_data = { - k: v for k, v in message.items() if k not in keys_to_remove - } - messages.append(filtered_data) - - tool_calls = resp.message.tool_calls - for call in tool_calls: - try: - self.tool_calls.append(call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, call - ) - function_call_dict = { - "function_call": { - "name": call.function.name, - "args": call.function.arguments, - "call_id": call_id, - } - } - function_response_dict = { - "function_response": { - "name": call.function.name, - "response": {"result": tool_response}, - "call_id": call_id, - } - } - - messages.append( - {"role": "assistant", "content": [function_call_dict]} - ) - messages.append( - {"role": "tool", "content": [function_response_dict]} - ) - - messages = self.prepare_messages_with_attachments(agent, messages, 
attachments) - except Exception as e: - logging.error(f"Error executing tool: {str(e)}", exc_info=True) - messages.append( - { - "role": "tool", - "content": f"Error executing tool: {str(e)}", - "tool_call_id": call_id, - } - ) - resp = agent.llm.gen_stream( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - return resp - - else: - text_buffer = "" - while True: - tool_calls = {} - for chunk in resp: - if isinstance(chunk, str) and len(chunk) > 0: - yield chunk - continue - elif hasattr(chunk, "delta"): - chunk_delta = chunk.delta - - if ( - hasattr(chunk_delta, "tool_calls") - and chunk_delta.tool_calls is not None - ): - for tool_call in chunk_delta.tool_calls: - index = tool_call.index - if index not in tool_calls: - tool_calls[index] = { - "id": "", - "function": {"name": "", "arguments": ""}, - } - - current = tool_calls[index] - if tool_call.id: - current["id"] = tool_call.id - if tool_call.function.name: - current["function"][ - "name" - ] = tool_call.function.name - if tool_call.function.arguments: - current["function"][ - "arguments" - ] += tool_call.function.arguments - tool_calls[index] = current - - if ( - hasattr(chunk, "finish_reason") - and chunk.finish_reason == "tool_calls" - ): - for index in sorted(tool_calls.keys()): - call = tool_calls[index] - try: - self.tool_calls.append(call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, call - ) - if isinstance(call["function"]["arguments"], str): - call["function"]["arguments"] = json.loads(call["function"]["arguments"]) - - function_call_dict = { - "function_call": { - "name": call["function"]["name"], - "args": call["function"]["arguments"], - "call_id": call["id"], - } - } - function_response_dict = { - "function_response": { - "name": call["function"]["name"], - "response": {"result": tool_response}, - "call_id": call["id"], - } - } - - messages.append( - { - "role": "assistant", - "content": [function_call_dict], - } - ) - messages.append( - { - "role": "tool", - "content": [function_response_dict], - } - ) - - except Exception as e: - logging.error(f"Error executing tool: {str(e)}", exc_info=True) - messages.append( - { - "role": "assistant", - "content": f"Error executing tool: {str(e)}", - } - ) - tool_calls = {} - if hasattr(chunk_delta, "content") and chunk_delta.content: - # Add to buffer or yield immediately based on your preference - text_buffer += chunk_delta.content - yield text_buffer - text_buffer = "" - - if ( - hasattr(chunk, "finish_reason") - and chunk.finish_reason == "stop" - ): - return resp - elif isinstance(chunk, str) and len(chunk) == 0: - continue - - logger.info(f"Regenerating with messages: {messages}") - resp = agent.llm.gen_stream( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - - -class GoogleLLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True): - from google.genai import types - - messages = self.prepare_messages_with_attachments(agent, messages, attachments) - - while True: - if not stream: - response = agent.llm.gen( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - if response.candidates and response.candidates[0].content.parts: - tool_call_found = False - for part in response.candidates[0].content.parts: - if part.function_call: - tool_call_found = True - 
self.tool_calls.append(part.function_call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, part.function_call - ) - function_response_part = types.Part.from_function_response( - name=part.function_call.name, - response={"result": tool_response}, - ) - - messages.append( - {"role": "model", "content": [part.to_json_dict()]} - ) - messages.append( - { - "role": "tool", - "content": [function_response_part.to_json_dict()], - } - ) - - if ( - not tool_call_found - and response.candidates[0].content.parts - and response.candidates[0].content.parts[0].text - ): - return response.candidates[0].content.parts[0].text - elif not tool_call_found: - return response.candidates[0].content.parts - - else: - return response - - else: - response = agent.llm.gen_stream( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - - tool_call_found = False - for result in response: - if hasattr(result, "function_call"): - tool_call_found = True - self.tool_calls.append(result.function_call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, result.function_call - ) - function_response_part = types.Part.from_function_response( - name=result.function_call.name, - response={"result": tool_response}, - ) - - messages.append( - {"role": "model", "content": [result.to_json_dict()]} - ) - messages.append( - { - "role": "tool", - "content": [function_response_part.to_json_dict()], - } - ) - else: - tool_call_found = False - yield result - - if not tool_call_found: - return response - - -def get_llm_handler(llm_type): - handlers = { - "openai": OpenAILLMHandler(), - "google": GoogleLLMHandler(), - } - return handlers.get(llm_type, OpenAILLMHandler()) diff --git a/application/agents/tools/tool_action_parser.py b/application/agents/tools/tool_action_parser.py index c7da5a4c..0589ac88 100644 --- a/application/agents/tools/tool_action_parser.py +++ b/application/agents/tools/tool_action_parser.py @@ -17,26 +17,21 @@ class ToolActionParser: return parser(call) def _parse_openai_llm(self, call): - if isinstance(call, dict): - try: - call_args = json.loads(call["function"]["arguments"]) - tool_id = call["function"]["name"].split("_")[-1] - action_name = call["function"]["name"].rsplit("_", 1)[0] - except (KeyError, TypeError) as e: - logger.error(f"Error parsing OpenAI LLM call: {e}") - return None, None, None - else: - try: - call_args = json.loads(call.function.arguments) - tool_id = call.function.name.split("_")[-1] - action_name = call.function.name.rsplit("_", 1)[0] - except (AttributeError, TypeError) as e: - logger.error(f"Error parsing OpenAI LLM call: {e}") - return None, None, None + try: + call_args = json.loads(call.arguments) + tool_id = call.name.split("_")[-1] + action_name = call.name.rsplit("_", 1)[0] + except (AttributeError, TypeError) as e: + logger.error(f"Error parsing OpenAI LLM call: {e}") + return None, None, None return tool_id, action_name, call_args def _parse_google_llm(self, call): - call_args = call.args - tool_id = call.name.split("_")[-1] - action_name = call.name.rsplit("_", 1)[0] + try: + call_args = call.arguments + tool_id = call.name.split("_")[-1] + action_name = call.name.rsplit("_", 1)[0] + except (AttributeError, TypeError) as e: + logger.error(f"Error parsing Google LLM call: {e}") + return None, None, None return tool_id, action_name, call_args diff --git a/application/llm/handlers/__init__.py b/application/llm/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff 
--git a/application/llm/handlers/base.py b/application/llm/handlers/base.py new file mode 100644 index 00000000..ede7cec3 --- /dev/null +++ b/application/llm/handlers/base.py @@ -0,0 +1,317 @@ +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, Generator, List, Optional, Union + +from application.logging import build_stack_data + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolCall: + """Represents a tool/function call from the LLM.""" + + id: str + name: str + arguments: Union[str, Dict] + index: Optional[int] = None + + @classmethod + def from_dict(cls, data: Dict) -> "ToolCall": + """Create ToolCall from dictionary.""" + return cls( + id=data.get("id", ""), + name=data.get("name", ""), + arguments=data.get("arguments", {}), + index=data.get("index"), + ) + + +@dataclass +class LLMResponse: + """Represents a response from the LLM.""" + + content: str + tool_calls: List[ToolCall] + finish_reason: str + raw_response: Any + + @property + def requires_tool_call(self) -> bool: + """Check if the response requires tool calls.""" + return bool(self.tool_calls) and self.finish_reason == "tool_calls" + + +class LLMHandler(ABC): + """Abstract base class for LLM handlers.""" + + def __init__(self): + self.llm_calls = [] + self.tool_calls = [] + + @abstractmethod + def parse_response(self, response: Any) -> LLMResponse: + """Parse raw LLM response into standardized format.""" + pass + + @abstractmethod + def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict: + """Create a tool result message for the conversation history.""" + pass + + @abstractmethod + def _iterate_stream(self, response: Any) -> Generator: + """Iterate through streaming response chunks.""" + pass + + def process_message_flow( + self, + agent, + initial_response, + tools_dict: Dict, + messages: List[Dict], + attachments: Optional[List] = None, + stream: bool = False, + ) -> Union[str, Generator]: + """ + Main orchestration method for processing LLM message flow. + + Args: + agent: The agent instance + initial_response: Initial LLM response + tools_dict: Dictionary of available tools + messages: Conversation history + attachments: Optional attachments + stream: Whether to use streaming + + Returns: + Final response or generator for streaming + """ + messages = self.prepare_messages(agent, messages, attachments) + + if stream: + return self.handle_streaming(agent, initial_response, tools_dict, messages) + else: + return self.handle_non_streaming( + agent, initial_response, tools_dict, messages + ) + + def prepare_messages( + self, agent, messages: List[Dict], attachments: Optional[List] = None + ) -> List[Dict]: + """ + Prepare messages with attachments and provider-specific formatting. 
+
+        Args:
+            agent: The agent instance
+            messages: Original messages
+            attachments: List of attachments
+
+        Returns:
+            Prepared messages list
+        """
+        if not attachments:
+            return messages
+        logger.info(f"Preparing messages with {len(attachments)} attachments")
+        supported_types = agent.llm.get_supported_attachment_types()
+
+        supported_attachments = [
+            a for a in attachments if a.get("mime_type") in supported_types
+        ]
+        unsupported_attachments = [
+            a for a in attachments if a.get("mime_type") not in supported_types
+        ]
+
+        # Process supported attachments with the LLM's custom method
+
+        if supported_attachments:
+            logger.info(
+                f"Processing {len(supported_attachments)} supported attachments"
+            )
+            messages = agent.llm.prepare_messages_with_attachments(
+                messages, supported_attachments
+            )
+        # Process unsupported attachments with default method
+
+        if unsupported_attachments:
+            logger.info(
+                f"Processing {len(unsupported_attachments)} unsupported attachments"
+            )
+            messages = self._append_unsupported_attachments(
+                messages, unsupported_attachments
+            )
+        return messages
+
+    def _append_unsupported_attachments(
+        self, messages: List[Dict], attachments: List[Dict]
+    ) -> List[Dict]:
+        """
+        Default method to append unsupported attachment content to system prompt.
+
+        Args:
+            messages: Current messages
+            attachments: List of unsupported attachments
+
+        Returns:
+            Updated messages list
+        """
+        prepared_messages = messages.copy()
+        attachment_texts = []
+
+        for attachment in attachments:
+            logger.info(f"Adding attachment {attachment.get('id')} to context")
+            if "content" in attachment:
+                attachment_texts.append(
+                    f"Attached file content:\n\n{attachment['content']}"
+                )
+        if attachment_texts:
+            combined_text = "\n\n".join(attachment_texts)
+
+            system_msg = next(
+                (msg for msg in prepared_messages if msg.get("role") == "system"),
+                {"role": "system", "content": ""},
+            )
+
+            if system_msg not in prepared_messages:
+                prepared_messages.insert(0, system_msg)
+            system_msg["content"] += f"\n\n{combined_text}"
+        return prepared_messages
+
+    def handle_tool_calls(
+        self, agent, tool_calls: List[ToolCall], tools_dict: Dict, messages: List[Dict]
+    ) -> List[Dict]:
+        """
+        Execute tool calls and update conversation history.
+
+        Args:
+            agent: The agent instance
+            tool_calls: List of tool calls to execute
+            tools_dict: Available tools dictionary
+            messages: Current conversation history
+
+        Returns:
+            Updated messages list
+        """
+        updated_messages = messages.copy()
+
+        for call in tool_calls:
+            try:
+                self.tool_calls.append(call)
+                tool_response, call_id = agent._execute_tool_action(tools_dict, call)
+
+                updated_messages.append(
+                    {
+                        "role": "assistant",
+                        "content": [
+                            {
+                                "function_call": {
+                                    "name": call.name,
+                                    "args": call.arguments,
+                                    "call_id": call_id,
+                                }
+                            }
+                        ],
+                    }
+                )
+
+                updated_messages.append(self.create_tool_message(call, tool_response))
+
+            except Exception as e:
+                logger.error(f"Error executing tool: {str(e)}", exc_info=True)
+                updated_messages.append(
+                    {
+                        "role": "tool",
+                        "content": f"Error executing tool: {str(e)}",
+                        "tool_call_id": call.id,
+                    }
+                )
+
+        return updated_messages
+
+    def handle_non_streaming(
+        self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
+    ) -> Union[str, Dict]:
+        """
+        Handle non-streaming response flow.
+
+        Args:
+            agent: The agent instance
+            response: Current LLM response
+            tools_dict: Available tools dictionary
+            messages: Conversation history
+
+        Returns:
+            Final response after processing all tool calls
+        """
+        parsed = self.parse_response(response)
+        self.llm_calls.append(build_stack_data(agent.llm))
+
+        while parsed.requires_tool_call:
+            messages = self.handle_tool_calls(
+                agent, parsed.tool_calls, tools_dict, messages
+            )
+
+            response = agent.llm.gen(
+                model=agent.gpt_model, messages=messages, tools=agent.tools
+            )
+            parsed = self.parse_response(response)
+            self.llm_calls.append(build_stack_data(agent.llm))
+
+        return parsed.content
+
+    def handle_streaming(
+        self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
+    ) -> Generator:
+        """
+        Handle streaming response flow.
+
+        Args:
+            agent: The agent instance
+            response: Current LLM response
+            tools_dict: Available tools dictionary
+            messages: Conversation history
+
+        Yields:
+            Streaming response chunks
+        """
+        buffer = ""
+        tool_calls = {}
+
+        for chunk in self._iterate_stream(response):
+            if isinstance(chunk, str):
+                yield chunk
+                continue
+            parsed = self.parse_response(chunk)
+
+            if parsed.tool_calls:
+                for call in parsed.tool_calls:
+                    if call.index not in tool_calls:
+                        tool_calls[call.index] = call
+                    else:
+                        existing = tool_calls[call.index]
+                        if call.id:
+                            existing.id = call.id
+                        if call.name:
+                            existing.name = call.name
+                        if call.arguments:
+                            existing.arguments += call.arguments
+            if parsed.finish_reason == "tool_calls":
+                messages = self.handle_tool_calls(
+                    agent, list(tool_calls.values()), tools_dict, messages
+                )
+                tool_calls = {}
+
+                response = agent.llm.gen_stream(
+                    model=agent.gpt_model, messages=messages, tools=agent.tools
+                )
+                self.llm_calls.append(build_stack_data(agent.llm))
+
+                yield from self.handle_streaming(agent, response, tools_dict, messages)
+                return
+            if parsed.content:
+                buffer += parsed.content
+                yield buffer
+                buffer = ""
+            if parsed.finish_reason == "stop":
+                return
diff --git a/application/llm/handlers/google.py b/application/llm/handlers/google.py
new file mode 100644
index 00000000..b43f2a16
--- /dev/null
+++ b/application/llm/handlers/google.py
@@ -0,0 +1,78 @@
+import uuid
+from typing import Any, Dict, Generator
+
+from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
+
+
+class GoogleLLMHandler(LLMHandler):
+    """Handler for Google's GenAI API."""
+
+    def parse_response(self, response: Any) -> LLMResponse:
+        """Parse Google response into standardized format."""
+
+        if isinstance(response, str):
+            return LLMResponse(
+                content=response,
+                tool_calls=[],
+                finish_reason="stop",
+                raw_response=response,
+            )
+
+        if hasattr(response, "candidates"):
+            parts = response.candidates[0].content.parts if response.candidates else []
+            tool_calls = [
+                ToolCall(
+                    id=str(uuid.uuid4()),
+                    name=part.function_call.name,
+                    arguments=part.function_call.args,
+                )
+                for part in parts
+                if hasattr(part, "function_call") and part.function_call is not None
+            ]
+
+            content = " ".join(
+                part.text
+                for part in parts
+                if hasattr(part, "text") and part.text is not None
+            )
+            return LLMResponse(
+                content=content,
+                tool_calls=tool_calls,
+                finish_reason="tool_calls" if tool_calls else "stop",
+                raw_response=response,
+            )
+
+        else:
+            tool_calls = []
+            if hasattr(response, "function_call"):
+                tool_calls.append(
+                    ToolCall(
+                        id=str(uuid.uuid4()),
+                        name=response.function_call.name,
+                        arguments=response.function_call.args,
+                    )
+                )
+            return LLMResponse(
+                content=response.text if hasattr(response, "text") else "",
+                tool_calls=tool_calls,
+                finish_reason="tool_calls" if tool_calls else "stop",
+                raw_response=response,
+            )
+
+    def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
+        """Create Google-style tool message."""
+        from google.genai import types
+
+        return {
+            "role": "tool",
+            "content": [
+                types.Part.from_function_response(
+                    name=tool_call.name, response={"result": result}
+                ).to_json_dict()
+            ],
+        }
+
+    def _iterate_stream(self, response: Any) -> Generator:
+        """Iterate through Google streaming response."""
+        for chunk in response:
+            yield chunk
diff --git a/application/llm/handlers/handler_creator.py b/application/llm/handlers/handler_creator.py
new file mode 100644
index 00000000..e39c000d
--- /dev/null
+++ b/application/llm/handlers/handler_creator.py
@@ -0,0 +1,18 @@
+from application.llm.handlers.base import LLMHandler
+from application.llm.handlers.google import GoogleLLMHandler
+from application.llm.handlers.openai import OpenAILLMHandler
+
+
+class LLMHandlerCreator:
+    handlers = {
+        "openai": OpenAILLMHandler,
+        "google": GoogleLLMHandler,
+        "default": OpenAILLMHandler,
+    }
+
+    @classmethod
+    def create_handler(cls, llm_type: str, *args, **kwargs) -> LLMHandler:
+        handler_class = cls.handlers.get(llm_type.lower())
+        if not handler_class:
+            raise ValueError(f"No LLM handler class found for type {llm_type}")
+        return handler_class(*args, **kwargs)
diff --git a/application/llm/handlers/openai.py b/application/llm/handlers/openai.py
new file mode 100644
index 00000000..99ddde4c
--- /dev/null
+++ b/application/llm/handlers/openai.py
@@ -0,0 +1,57 @@
+from typing import Any, Dict, Generator
+
+from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
+
+
+class OpenAILLMHandler(LLMHandler):
+    """Handler for OpenAI API."""
+
+    def parse_response(self, response: Any) -> LLMResponse:
+        """Parse OpenAI response into standardized format."""
+        if isinstance(response, str):
+            return LLMResponse(
+                content=response,
+                tool_calls=[],
+                finish_reason="stop",
+                raw_response=response,
+            )
+
+        message = getattr(response, "message", None) or getattr(response, "delta", None)
+
+        tool_calls = []
+        if hasattr(message, "tool_calls"):
+            tool_calls = [
+                ToolCall(
+                    id=getattr(tc, "id", ""),
+                    name=getattr(tc.function, "name", ""),
+                    arguments=getattr(tc.function, "arguments", ""),
+                    index=getattr(tc, "index", None),
+                )
+                for tc in message.tool_calls or []
+            ]
+        return LLMResponse(
+            content=getattr(message, "content", ""),
+            tool_calls=tool_calls,
+            finish_reason=getattr(response, "finish_reason", ""),
+            raw_response=response,
+        )
+
+    def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
+        """Create OpenAI-style tool message."""
+        return {
+            "role": "tool",
+            "content": [
+                {
+                    "function_response": {
+                        "name": tool_call.name,
+                        "response": {"result": result},
+                        "call_id": tool_call.id,
+                    }
+                }
+            ],
+        }
+
+    def _iterate_stream(self, response: Any) -> Generator:
+        """Iterate through OpenAI streaming response."""
+        for chunk in response:
+            yield chunk