(feat:attach) fallback strategy to process docs

This commit is contained in:
ManishMadan2882
2025-04-03 03:26:37 +05:30
parent 19d68252cd
commit f9ad4c068a
3 changed files with 37 additions and 26 deletions

View File

@@ -23,7 +23,7 @@ class BaseAgent(ABC):
prompt: str = "", prompt: str = "",
chat_history: Optional[List[Dict]] = None, chat_history: Optional[List[Dict]] = None,
decoded_token: Optional[Dict] = None, decoded_token: Optional[Dict] = None,
attachments: Optional[str]=None, attachments: Optional[List[Dict]]=None,
): ):
self.endpoint = endpoint self.endpoint = endpoint
self.llm_name = llm_name self.llm_name = llm_name
@@ -44,7 +44,7 @@ class BaseAgent(ABC):
decoded_token=decoded_token, decoded_token=decoded_token,
) )
self.llm_handler = get_llm_handler(llm_name) self.llm_handler = get_llm_handler(llm_name)
set.attachments = attachments or [] self.attachments = attachments or []
@log_activity() @log_activity()
def gen( def gen(
@@ -243,8 +243,9 @@ class BaseAgent(ABC):
tools_dict: Dict, tools_dict: Dict,
messages: List[Dict], messages: List[Dict],
log_context: Optional[LogContext] = None, log_context: Optional[LogContext] = None,
attachments: Optional[List[Dict]] = None
): ):
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages) resp = self.llm_handler.handle_response(self, resp, tools_dict, messages, attachments)
if log_context: if log_context:
data = build_stack_data(self.llm_handler) data = build_stack_data(self.llm_handler)
log_context.stacks.append({"component": "llm_handler", "data": data}) log_context.stacks.append({"component": "llm_handler", "data": data})

View File

@@ -26,38 +26,48 @@ class LLMHandler(ABC):
attachments (list): List of attachment dictionaries with content. attachments (list): List of attachment dictionaries with content.
Returns: Returns:
list: Messages with attachment context added. list: Messages with attachment context added to the system prompt.
""" """
if not attachments: if not attachments:
return messages return messages
logger.info(f"Preparing messages with {len(attachments)} attachments") logger.info(f"Preparing messages with {len(attachments)} attachments")
# If the LLM has its own attachment handling, use that # Check if the LLM has its own custom attachment handling implementation
if hasattr(agent.llm, "prepare_messages_with_attachments"): if hasattr(agent.llm, "prepare_messages_with_attachments") and agent.llm.__class__.__name__ != "BaseLLM":
logger.info(f"Using {agent.llm.__class__.__name__}'s own prepare_messages_with_attachments method")
return agent.llm.prepare_messages_with_attachments(messages, attachments) return agent.llm.prepare_messages_with_attachments(messages, attachments)
# Otherwise, use a generic approach: # Otherwise, append attachment content to the system prompt
# Insert attachment context after system messages, before user messages. prepared_messages = messages.copy()
attachment_context = []
# Build attachment content string
attachment_texts = []
for attachment in attachments: for attachment in attachments:
logger.info(f"Adding attachment {attachment.get('id')} to context") logger.info(f"Adding attachment {attachment.get('id')} to context")
attachment_context.append({ if 'content' in attachment:
"role": "system", attachment_texts.append(f"Attached file content:\n\n{attachment['content']}")
"content": f"The user has attached a file with the following content:\n\n{attachment['content']}"
})
system_messages = [msg for msg in messages if msg.get("role") == "system"]
user_messages = [msg for msg in messages if msg.get("role") != "system"]
return system_messages + attachment_context + user_messages if attachment_texts:
combined_attachment_text = "\n\n".join(attachment_texts)
system_found = False
for i in range(len(prepared_messages)):
if prepared_messages[i].get("role") == "system":
prepared_messages[i]["content"] += f"\n\n{combined_attachment_text}"
system_found = True
break
if not system_found:
prepared_messages.insert(0, {"role": "system", "content": combined_attachment_text})
return prepared_messages
class OpenAILLMHandler(LLMHandler): class OpenAILLMHandler(LLMHandler):
def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True): def handle_response(self, agent, resp, tools_dict, messages, attachments=None, stream: bool = True):
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
logger.info(f"Messages with attachments: {messages}")
if not stream: if not stream:
while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls": while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
message = json.loads(resp.model_dump_json())["message"] message = json.loads(resp.model_dump_json())["message"]
@@ -96,6 +106,7 @@ class OpenAILLMHandler(LLMHandler):
{"role": "tool", "content": [function_response_dict]} {"role": "tool", "content": [function_response_dict]}
) )
messages = self.prepare_messages_with_attachments(agent, messages, attachments)
except Exception as e: except Exception as e:
messages.append( messages.append(
{ {
@@ -111,6 +122,7 @@ class OpenAILLMHandler(LLMHandler):
return resp return resp
else: else:
while True: while True:
tool_calls = {} tool_calls = {}
for chunk in resp: for chunk in resp:
@@ -202,7 +214,8 @@ class OpenAILLMHandler(LLMHandler):
return return
elif isinstance(chunk, str) and len(chunk) == 0: elif isinstance(chunk, str) and len(chunk) == 0:
continue continue
logger.info(f"Regenerating with messages: {messages}")
resp = agent.llm.gen_stream( resp = agent.llm.gen_stream(
model=agent.gpt_model, messages=messages, tools=agent.tools model=agent.gpt_model, messages=messages, tools=agent.tools
) )

View File

@@ -55,6 +55,3 @@ class BaseLLM(ABC):
def _supports_tools(self): def _supports_tools(self):
raise NotImplementedError("Subclass must implement _supports_tools method") raise NotImplementedError("Subclass must implement _supports_tools method")
def prepare_messages_with_attachments(self, messages, attachments=None):
return messages