(feat:attach) pass attachments for generation

This commit is contained in:
ManishMadan2882
2025-04-02 15:14:56 +05:30
parent 4241307990
commit f235a94986
4 changed files with 58 additions and 8 deletions

View File

@@ -1,8 +1,11 @@
import json
import logging
from abc import ABC, abstractmethod
from application.logging import build_stack_data
logger = logging.getLogger(__name__)
class LLMHandler(ABC):
def __init__(self):
@@ -10,12 +13,51 @@ class LLMHandler(ABC):
self.tool_calls = []
@abstractmethod
def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, **kwargs):
pass
def prepare_messages_with_attachments(self, agent, messages, attachments=None):
"""
Prepare messages with attachment content if available.
Args:
agent: The current agent instance.
messages (list): List of message dictionaries.
attachments (list): List of attachment dictionaries with content.
Returns:
list: Messages with attachment context added.
"""
if not attachments:
return messages
logger.info(f"Preparing messages with {len(attachments)} attachments")
# If the LLM has its own attachment handling, use that
if hasattr(agent.llm, "prepare_messages_with_attachments"):
return agent.llm.prepare_messages_with_attachments(messages, attachments)
# Otherwise, use a generic approach:
# Insert attachment context after system messages, before user messages.
attachment_context = []
for attachment in attachments:
logger.info(f"Adding attachment {attachment.get('id')} to context")
attachment_context.append({
"role": "system",
"content": f"The user has attached a file with the following content:\n\n{attachment['content']}"
})
system_messages = [msg for msg in messages if msg.get("role") == "system"]
user_messages = [msg for msg in messages if msg.get("role") != "system"]
return system_messages + attachment_context + user_messages
class OpenAILLMHandler(LLMHandler):
def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
if not stream:
while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
message = json.loads(resp.model_dump_json())["message"]
@@ -168,8 +210,10 @@ class OpenAILLMHandler(LLMHandler):
class GoogleLLMHandler(LLMHandler):
def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
from google.genai import types
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
while True:
if not stream: