mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
(feat:attach) pass attachments for generation
This commit is contained in:
@@ -17,6 +17,7 @@ class BaseAgent:
|
||||
api_key,
|
||||
user_api_key=None,
|
||||
decoded_token=None,
|
||||
attachments=None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm = LLMCreator.create_llm(
|
||||
@@ -30,6 +31,7 @@ class BaseAgent:
|
||||
self.tools = []
|
||||
self.tool_config = {}
|
||||
self.tool_calls = []
|
||||
self.attachments = attachments or []
|
||||
|
||||
def gen(self, *args, **kwargs) -> Generator[Dict, None, None]:
|
||||
raise NotImplementedError('Method "gen" must be implemented in the child class')
|
||||
|
||||
@@ -5,7 +5,8 @@ from application.agents.base import BaseAgent
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ClassicAgent(BaseAgent):
|
||||
def __init__(
|
||||
@@ -18,9 +19,10 @@ class ClassicAgent(BaseAgent):
|
||||
prompt="",
|
||||
chat_history=None,
|
||||
decoded_token=None,
|
||||
attachments=None,
|
||||
):
|
||||
super().__init__(
|
||||
endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token
|
||||
endpoint, llm_name, gpt_model, api_key, user_api_key, decoded_token, attachments
|
||||
)
|
||||
self.user = decoded_token.get("sub")
|
||||
self.prompt = prompt
|
||||
@@ -93,7 +95,7 @@ class ClassicAgent(BaseAgent):
|
||||
yield {"answer": resp.message.content}
|
||||
return
|
||||
|
||||
resp = self._llm_handler(resp, tools_dict, messages_combine, log_context)
|
||||
resp = self._llm_handler(resp, tools_dict, messages_combine, log_context, self.attachments)
|
||||
|
||||
if isinstance(resp, str):
|
||||
yield {"answer": resp}
|
||||
@@ -130,9 +132,10 @@ class ClassicAgent(BaseAgent):
|
||||
log_context.stacks.append({"component": "llm", "data": data})
|
||||
return resp
|
||||
|
||||
def _llm_handler(self, resp, tools_dict, messages_combine, log_context):
|
||||
def _llm_handler(self, resp, tools_dict, messages_combine, log_context, attachments=None):
|
||||
logger.info(f"Handling LLM response with {len(attachments) if attachments else 0} attachments")
|
||||
resp = self.llm_handler.handle_response(
|
||||
self, resp, tools_dict, messages_combine
|
||||
self, resp, tools_dict, messages_combine, attachments=attachments
|
||||
)
|
||||
if log_context:
|
||||
data = build_stack_data(self.llm_handler)
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from application.logging import build_stack_data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMHandler(ABC):
|
||||
def __init__(self):
|
||||
@@ -10,12 +13,51 @@ class LLMHandler(ABC):
|
||||
self.tool_calls = []
|
||||
|
||||
@abstractmethod
|
||||
def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
|
||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, **kwargs):
|
||||
pass
|
||||
|
||||
def prepare_messages_with_attachments(self, agent, messages, attachments=None):
|
||||
"""
|
||||
Prepare messages with attachment content if available.
|
||||
|
||||
Args:
|
||||
agent: The current agent instance.
|
||||
messages (list): List of message dictionaries.
|
||||
attachments (list): List of attachment dictionaries with content.
|
||||
|
||||
Returns:
|
||||
list: Messages with attachment context added.
|
||||
"""
|
||||
if not attachments:
|
||||
return messages
|
||||
|
||||
logger.info(f"Preparing messages with {len(attachments)} attachments")
|
||||
|
||||
# If the LLM has its own attachment handling, use that
|
||||
if hasattr(agent.llm, "prepare_messages_with_attachments"):
|
||||
return agent.llm.prepare_messages_with_attachments(messages, attachments)
|
||||
|
||||
# Otherwise, use a generic approach:
|
||||
# Insert attachment context after system messages, before user messages.
|
||||
attachment_context = []
|
||||
for attachment in attachments:
|
||||
logger.info(f"Adding attachment {attachment.get('id')} to context")
|
||||
attachment_context.append({
|
||||
"role": "system",
|
||||
"content": f"The user has attached a file with the following content:\n\n{attachment['content']}"
|
||||
})
|
||||
|
||||
system_messages = [msg for msg in messages if msg.get("role") == "system"]
|
||||
user_messages = [msg for msg in messages if msg.get("role") != "system"]
|
||||
|
||||
return system_messages + attachment_context + user_messages
|
||||
|
||||
|
||||
class OpenAILLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
|
||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
|
||||
|
||||
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
||||
|
||||
if not stream:
|
||||
while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
@@ -168,8 +210,10 @@ class OpenAILLMHandler(LLMHandler):
|
||||
|
||||
|
||||
class GoogleLLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
|
||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
|
||||
from google.genai import types
|
||||
|
||||
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
||||
|
||||
while True:
|
||||
if not stream:
|
||||
|
||||
Reference in New Issue
Block a user