diff --git a/application/agents/base.py b/application/agents/base.py index d0f972a9..edbf5475 100644 --- a/application/agents/base.py +++ b/application/agents/base.py @@ -17,6 +17,7 @@ class BaseAgent: api_key, user_api_key=None, decoded_token=None, + attachments=None, ): self.endpoint = endpoint self.llm = LLMCreator.create_llm( @@ -30,6 +31,7 @@ class BaseAgent: self.tools = [] self.tool_config = {} self.tool_calls = [] + self.attachments = attachments or [] def gen(self, *args, **kwargs) -> Generator[Dict, None, None]: raise NotImplementedError('Method "gen" must be implemented in the child class') diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index 2752c833..aa5a3302 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -5,7 +5,8 @@ from application.agents.base import BaseAgent from application.logging import build_stack_data, log_activity, LogContext from application.retriever.base import BaseRetriever - +import logging +logger = logging.getLogger(__name__) class ClassicAgent(BaseAgent): def __init__( @@ -18,9 +19,10 @@ class ClassicAgent(BaseAgent): prompt="", chat_history=None, decoded_token=None, + attachments=None, ): super().__init__( - endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token + endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token, attachments ) self.user = decoded_token.get("sub") self.prompt = prompt @@ -93,7 +95,7 @@ class ClassicAgent(BaseAgent): yield {"answer": resp.message.content} return - resp = self._llm_handler(resp, tools_dict, messages_combine, log_context) + resp = self._llm_handler(resp, tools_dict, messages_combine, log_context, self.attachments) if isinstance(resp, str): yield {"answer": resp} @@ -130,9 +132,10 @@ class ClassicAgent(BaseAgent): log_context.stacks.append({"component": "llm", "data": data}) return resp - def _llm_handler(self, resp, tools_dict, messages_combine, log_context): + def _llm_handler(self, resp, tools_dict, messages_combine, log_context, attachments=None): + logger.info(f"Handling LLM response with {len(attachments) if attachments else 0} attachments") resp = self.llm_handler.handle_response( - self, resp, tools_dict, messages_combine + self, resp, tools_dict, messages_combine, attachments=attachments ) if log_context: data = build_stack_data(self.llm_handler) diff --git a/application/agents/llm_handler.py b/application/agents/llm_handler.py index a70357f8..229db1ef 100644 --- a/application/agents/llm_handler.py +++ b/application/agents/llm_handler.py @@ -1,8 +1,11 @@ import json +import logging from abc import ABC, abstractmethod from application.logging import build_stack_data +logger = logging.getLogger(__name__) + class LLMHandler(ABC): def __init__(self): @@ -10,12 +13,51 @@ class LLMHandler(ABC): self.tool_calls = [] @abstractmethod - def handle_response(self, agent, resp, tools_dict, messages, **kwargs): + def handle_response(self, agent, resp, tools_dict, messages, attachments=None, **kwargs): pass + + def prepare_messages_with_attachments(self, agent, messages, attachments=None): + """ + Prepare messages with attachment content if available. + + Args: + agent: The current agent instance. + messages (list): List of message dictionaries. + attachments (list): List of attachment dictionaries with content. + + Returns: + list: Messages with attachment context added. + """ + if not attachments: + return messages + + logger.info(f"Preparing messages with {len(attachments)} attachments") + + # If the LLM has its own attachment handling, use that + if hasattr(agent.llm, "prepare_messages_with_attachments"): + return agent.llm.prepare_messages_with_attachments(messages, attachments) + + # Otherwise, use a generic approach: + # Insert attachment context after system messages, before user messages. + attachment_context = [] + for attachment in attachments: + logger.info(f"Adding attachment {attachment.get('id')} to context") + attachment_context.append({ + "role": "system", + "content": f"The user has attached a file with the following content:\n\n{attachment['content']}" + }) + + system_messages = [msg for msg in messages if msg.get("role") == "system"] + user_messages = [msg for msg in messages if msg.get("role") != "system"] + + return system_messages + attachment_context + user_messages class OpenAILLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True): + def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True): + + messages = self.prepare_messages_with_attachments(agent, messages, attachments) + if not stream: while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls": message = json.loads(resp.model_dump_json())["message"] @@ -168,8 +210,10 @@ class OpenAILLMHandler(LLMHandler): class GoogleLLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True): + def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True): from google.genai import types + + messages = self.prepare_messages_with_attachments(agent, messages, attachments) while True: if not stream: diff --git a/application/worker.py b/application/worker.py index 4ee61421..03180d5d 100755 --- a/application/worker.py +++ b/application/worker.py @@ -345,6 +345,7 @@ def attachment_worker(self, directory, file_info, user): base_dir = os.path.join(directory, user, "attachments", folder_name) file_path = os.path.join(base_dir, filename) + logging.info(f"Processing file: {file_path}", extra={"user": user, "job": job_name}) if not os.path.exists(file_path):