mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
(feat:attach) fallback strategy to process docs
This commit is contained in:
@@ -23,7 +23,7 @@ class BaseAgent(ABC):
|
|||||||
prompt: str = "",
|
prompt: str = "",
|
||||||
chat_history: Optional[List[Dict]] = None,
|
chat_history: Optional[List[Dict]] = None,
|
||||||
decoded_token: Optional[Dict] = None,
|
decoded_token: Optional[Dict] = None,
|
||||||
attachments: Optional[str]=None,
|
attachments: Optional[List[Dict]]=None,
|
||||||
):
|
):
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.llm_name = llm_name
|
self.llm_name = llm_name
|
||||||
@@ -44,7 +44,7 @@ class BaseAgent(ABC):
|
|||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
)
|
)
|
||||||
self.llm_handler = get_llm_handler(llm_name)
|
self.llm_handler = get_llm_handler(llm_name)
|
||||||
set.attachments = attachments or []
|
self.attachments = attachments or []
|
||||||
|
|
||||||
@log_activity()
|
@log_activity()
|
||||||
def gen(
|
def gen(
|
||||||
@@ -243,8 +243,9 @@ class BaseAgent(ABC):
|
|||||||
tools_dict: Dict,
|
tools_dict: Dict,
|
||||||
messages: List[Dict],
|
messages: List[Dict],
|
||||||
log_context: Optional[LogContext] = None,
|
log_context: Optional[LogContext] = None,
|
||||||
|
attachments: Optional[List[Dict]] = None
|
||||||
):
|
):
|
||||||
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages)
|
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages, attachments)
|
||||||
if log_context:
|
if log_context:
|
||||||
data = build_stack_data(self.llm_handler)
|
data = build_stack_data(self.llm_handler)
|
||||||
log_context.stacks.append({"component": "llm_handler", "data": data})
|
log_context.stacks.append({"component": "llm_handler", "data": data})
|
||||||
|
|||||||
@@ -26,38 +26,48 @@ class LLMHandler(ABC):
|
|||||||
attachments (list): List of attachment dictionaries with content.
|
attachments (list): List of attachment dictionaries with content.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: Messages with attachment context added.
|
list: Messages with attachment context added to the system prompt.
|
||||||
"""
|
"""
|
||||||
if not attachments:
|
if not attachments:
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
logger.info(f"Preparing messages with {len(attachments)} attachments")
|
logger.info(f"Preparing messages with {len(attachments)} attachments")
|
||||||
|
|
||||||
# If the LLM has its own attachment handling, use that
|
# Check if the LLM has its own custom attachment handling implementation
|
||||||
if hasattr(agent.llm, "prepare_messages_with_attachments"):
|
if hasattr(agent.llm, "prepare_messages_with_attachments") and agent.llm.__class__.__name__ != "BaseLLM":
|
||||||
|
logger.info(f"Using {agent.llm.__class__.__name__}'s own prepare_messages_with_attachments method")
|
||||||
return agent.llm.prepare_messages_with_attachments(messages, attachments)
|
return agent.llm.prepare_messages_with_attachments(messages, attachments)
|
||||||
|
|
||||||
# Otherwise, use a generic approach:
|
# Otherwise, append attachment content to the system prompt
|
||||||
# Insert attachment context after system messages, before user messages.
|
prepared_messages = messages.copy()
|
||||||
attachment_context = []
|
|
||||||
|
# Build attachment content string
|
||||||
|
attachment_texts = []
|
||||||
for attachment in attachments:
|
for attachment in attachments:
|
||||||
logger.info(f"Adding attachment {attachment.get('id')} to context")
|
logger.info(f"Adding attachment {attachment.get('id')} to context")
|
||||||
attachment_context.append({
|
if 'content' in attachment:
|
||||||
"role": "system",
|
attachment_texts.append(f"Attached file content:\n\n{attachment['content']}")
|
||||||
"content": f"The user has attached a file with the following content:\n\n{attachment['content']}"
|
|
||||||
})
|
|
||||||
|
|
||||||
system_messages = [msg for msg in messages if msg.get("role") == "system"]
|
|
||||||
user_messages = [msg for msg in messages if msg.get("role") != "system"]
|
|
||||||
|
|
||||||
return system_messages + attachment_context + user_messages
|
if attachment_texts:
|
||||||
|
combined_attachment_text = "\n\n".join(attachment_texts)
|
||||||
|
|
||||||
|
system_found = False
|
||||||
|
for i in range(len(prepared_messages)):
|
||||||
|
if prepared_messages[i].get("role") == "system":
|
||||||
|
prepared_messages[i]["content"] += f"\n\n{combined_attachment_text}"
|
||||||
|
system_found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not system_found:
|
||||||
|
prepared_messages.insert(0, {"role": "system", "content": combined_attachment_text})
|
||||||
|
|
||||||
|
return prepared_messages
|
||||||
|
|
||||||
class OpenAILLMHandler(LLMHandler):
|
class OpenAILLMHandler(LLMHandler):
|
||||||
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
|
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
|
||||||
|
|
||||||
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
|
||||||
|
|
||||||
|
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
||||||
|
logger.info(f"Messages with attachments: {messages}")
|
||||||
if not stream:
|
if not stream:
|
||||||
while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
|
while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
|
||||||
message = json.loads(resp.model_dump_json())["message"]
|
message = json.loads(resp.model_dump_json())["message"]
|
||||||
@@ -96,6 +106,7 @@ class OpenAILLMHandler(LLMHandler):
|
|||||||
{"role": "tool", "content": [function_response_dict]}
|
{"role": "tool", "content": [function_response_dict]}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
@@ -111,6 +122,7 @@ class OpenAILLMHandler(LLMHandler):
|
|||||||
return resp
|
return resp
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
tool_calls = {}
|
tool_calls = {}
|
||||||
for chunk in resp:
|
for chunk in resp:
|
||||||
@@ -202,7 +214,8 @@ class OpenAILLMHandler(LLMHandler):
|
|||||||
return
|
return
|
||||||
elif isinstance(chunk, str) and len(chunk) == 0:
|
elif isinstance(chunk, str) and len(chunk) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
logger.info(f"Regenerating with messages: {messages}")
|
||||||
resp = agent.llm.gen_stream(
|
resp = agent.llm.gen_stream(
|
||||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -55,6 +55,3 @@ class BaseLLM(ABC):
|
|||||||
|
|
||||||
def _supports_tools(self):
|
def _supports_tools(self):
|
||||||
raise NotImplementedError("Subclass must implement _supports_tools method")
|
raise NotImplementedError("Subclass must implement _supports_tools method")
|
||||||
|
|
||||||
def prepare_messages_with_attachments(self, messages, attachments=None):
|
|
||||||
return messages
|
|
||||||
|
|||||||
Reference in New Issue
Block a user