Mirror of https://github.com/arc53/DocsGPT.git
fix: streaming with tools (Google and OpenAI), halfway
@@ -44,10 +44,13 @@ class ClassicAgent(BaseAgent):
             ):
                 yield {"answer": resp.message.content}
             else:
-                completion = self.llm.gen_stream(
-                    model=self.gpt_model, messages=messages, tools=self.tools
-                )
-                for line in completion:
+                # completion = self.llm.gen_stream(
+                #     model=self.gpt_model, messages=messages, tools=self.tools
+                # )
+                # log type of resp
+                logger.info(f"Response type: {type(resp)}")
+                logger.info(f"Response: {resp}")
+                for line in resp:
                     if isinstance(line, str):
                         yield {"answer": line}
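The hunk above is the agent-side half of the fix: the old code opened a second gen_stream call and looped over completion, while the new code iterates the existing resp stream and forwards only string chunks as answer events. A minimal sketch of that consumption pattern (standalone; the dict shape in the non-string branch is a hypothetical illustration, not DocsGPT's actual event protocol):

    def stream_answer(llm, model, messages, tools):
        """Forward text chunks as answer events; surface everything else for tool handling."""
        for chunk in llm.gen_stream(model=model, messages=messages, tools=tools):
            if isinstance(chunk, str):
                if chunk:  # skip empty keep-alive strings
                    yield {"answer": chunk}
            else:
                # Non-string chunks are assumed to carry tool-call state;
                # hand them to whatever executes tools in the caller.
                yield {"tool_call": chunk}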
@@ -160,12 +160,14 @@ class OpenAILLMHandler(LLMHandler):
             return resp

         else:

             text_buffer = ""
             while True:
                 tool_calls = {}
                 for chunk in resp:
+                    logger.info(f"Chunk: {chunk}")
                     if isinstance(chunk, str) and len(chunk) > 0:
-                        return
+                        yield chunk
+                        continue
                     elif hasattr(chunk, "delta"):
                         chunk_delta = chunk.delta
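In OpenAI's streaming API, a tool call arrives as a sequence of partial deltas that must be stitched together by index before the arguments JSON is parseable, which is what the handler's tool_calls dict is for. A minimal standalone sketch of that accumulation against the official openai SDK (the weather tool is a placeholder; the handler above iterates a pre-unwrapped chunk shape rather than raw chat-completion chunks):

    from openai import OpenAI

    client = OpenAI()
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",  # placeholder tool
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }]
    stream = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Weather in Paris?"}],
        tools=tools,
        stream=True,
    )

    calls = {}  # index -> {"id", "name", "arguments"}
    for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="")  # plain text streams straight through
        for tc in delta.tool_calls or []:
            call = calls.setdefault(tc.index, {"id": "", "name": "", "arguments": ""})
            call["id"] = tc.id or call["id"]
            if tc.function:
                call["name"] = tc.function.name or call["name"]
                call["arguments"] += tc.function.arguments or ""  # JSON arrives in fragments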
@@ -244,12 +246,17 @@ class OpenAILLMHandler(LLMHandler):
                                }
                            )
+                        tool_calls = {}
                     if hasattr(chunk_delta, "content") and chunk_delta.content:
+                        # Add to buffer or yield immediately based on your preference
                         text_buffer += chunk_delta.content
+                        yield text_buffer
+                        text_buffer = ""

                 if (
                     hasattr(chunk, "finish_reason")
                     and chunk.finish_reason == "stop"
                 ):
+                    return
                     return resp
                 elif isinstance(chunk, str) and len(chunk) == 0:
                     continue
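This hunk makes content deltas flow out through text_buffer (the buffer is flushed immediately, so it is effectively a pass-through) and adds a bare return on finish_reason == "stop", which leaves the older return resp unreachable, consistent with the commit being flagged as halfway. When a stream instead finishes because the model requested tools, the accumulated calls still have to be executed and fed back before looping. A hedged sketch of that step, with execute_tool as a hypothetical dispatcher:

    import json

    def apply_tool_calls(tool_calls, messages, execute_tool):
        """Run stitched-together tool calls and append their results as 'tool' messages."""
        for call in tool_calls.values():
            args = json.loads(call["arguments"] or "{}")  # arguments accumulate as a JSON string
            result = execute_tool(call["name"], args)     # hypothetical dispatcher
            messages.append({
                "role": "tool",
                "tool_call_id": call["id"],
                "content": json.dumps(result),
            })
        return messages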
@@ -265,7 +272,7 @@ class GoogleLLMHandler(LLMHandler):
         from google.genai import types

         messages = self.prepare_messages_with_attachments(agent, messages, attachments)


         while True:
             if not stream:
                 response = agent.llm.gen(
@@ -336,6 +343,9 @@ class GoogleLLMHandler(LLMHandler):
                             "content": [function_response_part.to_json_dict()],
                         }
                     )
+                else:
+                    tool_call_found = False
+                    yield result

             if not tool_call_found:
                 return response
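The Google handler's loop above appends each tool result as a function-response message and only returns once a pass produces no tool call. A minimal standalone sketch of the same round trip with the google-genai SDK (the model name, weather_tool declaration, and run_tool executor are placeholders):

    from google import genai
    from google.genai import types

    client = genai.Client()
    contents = [types.Content(role="user", parts=[types.Part.from_text(text="Weather in Paris?")])]
    config = types.GenerateContentConfig(tools=[weather_tool])  # placeholder tool declaration

    while True:
        response = client.models.generate_content(
            model="gemini-2.0-flash", contents=contents, config=config
        )
        if not response.function_calls:
            break  # no tool call in this pass: final answer reached
        contents.append(response.candidates[0].content)  # keep the model's tool-call turn
        for call in response.function_calls:
            result = run_tool(call.name, call.args)  # hypothetical executor
            contents.append(types.Content(
                role="tool",
                parts=[types.Part.from_function_response(
                    name=call.name, response={"result": result}
                )],
            ))

    print(response.text)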
@@ -323,40 +323,6 @@ class GoogleLLM(BaseLLM):
                                 yield part.text
             elif hasattr(chunk, "text"):
                 yield chunk.text
-
-        if has_attachments and tools and function_call_seen:
-            logging.info("GoogleLLM: Detected both attachments and function calls. Making additional call for final response.")
-
-            last_user_message_index = -1
-            for i, message in enumerate(messages):
-                if message.role == 'user':
-                    last_user_message_index = i
-
-            if last_user_message_index >= 0:
-                text_parts = []
-                for part in messages[last_user_message_index].parts:
-                    if hasattr(part, 'text') and part.text is not None:
-                        text_parts.append(part)
-
-                if text_parts:
-                    messages[last_user_message_index].parts = text_parts
-                follow_up_response = client.models.generate_content_stream(
-                    model=model,
-                    contents=messages,
-                    config=config,
-                )
-
-                for chunk in follow_up_response:
-                    if hasattr(chunk, "candidates") and chunk.candidates:
-                        for candidate in chunk.candidates:
-                            if candidate.content and candidate.content.parts:
-                                for part in candidate.content.parts:
-                                    if part.text:
-                                        logging.info(f"GoogleLLM: Follow-up response text: {part.text[:50]}...")
-                                        yield part.text
-                    elif hasattr(chunk, "text"):
-                        logging.info(f"GoogleLLM: Follow-up response text: {chunk.text[:50]}...")
-                        yield chunk.text

     def _supports_tools(self):
         return True
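The last hunk removes the special-cased follow-up request that re-sent a text-only copy of the last user message whenever attachments and function calls appeared together, leaving GoogleLLM with the plain streaming path. For reference, a minimal sketch of that remaining path with the google-genai SDK (model and prompt are placeholders):

    from google import genai

    client = genai.Client()

    # Plain streaming: surface text as it arrives, nothing else.
    for chunk in client.models.generate_content_stream(
        model="gemini-2.0-flash",
        contents="Summarize the attached report.",
    ):
        if chunk.text:
            print(chunk.text, end="")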