mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 16:43:16 +00:00
(fix:generation) attach + tools
This commit is contained in:
@@ -46,7 +46,6 @@ class GoogleLLM(BaseLLM):
|
||||
return messages
|
||||
|
||||
prepared_messages = messages.copy()
|
||||
logging.info(f"GoogleLLM: Initial messages before attachment processing: {json.dumps(prepared_messages, indent=2)}")
|
||||
|
||||
# Find the user message to attach files to the last one
|
||||
user_message_index = None
|
||||
@@ -95,7 +94,6 @@ class GoogleLLM(BaseLLM):
|
||||
"files": files
|
||||
})
|
||||
|
||||
logging.info(f"GoogleLLM: Final prepared messages: {json.dumps(prepared_messages, indent=2)}")
|
||||
return prepared_messages
|
||||
|
||||
def _upload_file_to_google(self, attachment):
|
||||
@@ -149,7 +147,6 @@ class GoogleLLM(BaseLLM):
|
||||
raise
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
logging.info(f"GoogleLLM: Starting message cleaning. Input messages: {json.dumps(messages, indent=2)}")
|
||||
cleaned_messages = []
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
@@ -292,23 +289,74 @@ class GoogleLLM(BaseLLM):
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
|
||||
|
||||
# Check if we have both tools and file attachments
|
||||
has_attachments = False
|
||||
for message in messages:
|
||||
for part in message.parts:
|
||||
if hasattr(part, 'file_data') and part.file_data is not None:
|
||||
has_attachments = True
|
||||
break
|
||||
if has_attachments:
|
||||
break
|
||||
|
||||
logging.info(f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}")
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
model=model,
|
||||
contents=messages,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Track if we've seen any function calls
|
||||
function_call_seen = False
|
||||
|
||||
for chunk in response:
|
||||
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||
for candidate in chunk.candidates:
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if part.function_call:
|
||||
function_call_seen = True
|
||||
yield part
|
||||
elif part.text:
|
||||
yield part.text
|
||||
elif hasattr(chunk, "text"):
|
||||
yield chunk.text
|
||||
|
||||
if has_attachments and tools and function_call_seen:
|
||||
logging.info("GoogleLLM: Detected both attachments and function calls. Making additional call for final response.")
|
||||
|
||||
last_user_message_index = -1
|
||||
for i, message in enumerate(messages):
|
||||
if message.role == 'user':
|
||||
last_user_message_index = i
|
||||
|
||||
if last_user_message_index >= 0:
|
||||
text_parts = []
|
||||
for part in messages[last_user_message_index].parts:
|
||||
if hasattr(part, 'text') and part.text is not None:
|
||||
text_parts.append(part)
|
||||
|
||||
if text_parts:
|
||||
messages[last_user_message_index].parts = text_parts
|
||||
follow_up_response = client.models.generate_content_stream(
|
||||
model=model,
|
||||
contents=messages,
|
||||
config=config,
|
||||
)
|
||||
|
||||
for chunk in follow_up_response:
|
||||
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||
for candidate in chunk.candidates:
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if part.text:
|
||||
logging.info(f"GoogleLLM: Follow-up response text: {part.text[:50]}...")
|
||||
yield part.text
|
||||
elif hasattr(chunk, "text"):
|
||||
logging.info(f"GoogleLLM: Follow-up response text: {chunk.text[:50]}...")
|
||||
yield chunk.text
|
||||
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user