mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
feat: streaming responses with function call
This commit is contained in:
@@ -104,7 +104,8 @@ class ClassicAgent(BaseAgent):
|
|||||||
model=self.gpt_model, messages=messages_combine, tools=self.tools
|
model=self.gpt_model, messages=messages_combine, tools=self.tools
|
||||||
)
|
)
|
||||||
for line in completion:
|
for line in completion:
|
||||||
yield {"answer": line}
|
if isinstance(line, str):
|
||||||
|
yield {"answer": line}
|
||||||
|
|
||||||
yield {"tool_calls": self.tool_calls.copy()}
|
yield {"tool_calls": self.tool_calls.copy()}
|
||||||
|
|
||||||
@@ -116,7 +117,7 @@ class ClassicAgent(BaseAgent):
|
|||||||
return retrieved_data
|
return retrieved_data
|
||||||
|
|
||||||
def _llm_gen(self, messages_combine, log_context):
|
def _llm_gen(self, messages_combine, log_context):
|
||||||
resp = self.llm.gen(
|
resp = self.llm.gen_stream(
|
||||||
model=self.gpt_model, messages=messages_combine, tools=self.tools
|
model=self.gpt_model, messages=messages_combine, tools=self.tools
|
||||||
)
|
)
|
||||||
if log_context:
|
if log_context:
|
||||||
@@ -131,5 +132,4 @@ class ClassicAgent(BaseAgent):
|
|||||||
if log_context:
|
if log_context:
|
||||||
data = build_stack_data(self.llm_handler)
|
data = build_stack_data(self.llm_handler)
|
||||||
log_context.stacks.append({"component": "llm_handler", "data": data})
|
log_context.stacks.append({"component": "llm_handler", "data": data})
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|||||||
@@ -15,84 +15,221 @@ class LLMHandler(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class OpenAILLMHandler(LLMHandler):
|
class OpenAILLMHandler(LLMHandler):
|
||||||
def handle_response(self, agent, resp, tools_dict, messages):
|
def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
|
||||||
while resp.finish_reason == "tool_calls":
|
if not stream:
|
||||||
message = json.loads(resp.model_dump_json())["message"]
|
while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
|
||||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
message = json.loads(resp.model_dump_json())["message"]
|
||||||
filtered_data = {
|
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||||
k: v for k, v in message.items() if k not in keys_to_remove
|
filtered_data = {
|
||||||
}
|
k: v for k, v in message.items() if k not in keys_to_remove
|
||||||
messages.append(filtered_data)
|
}
|
||||||
|
messages.append(filtered_data)
|
||||||
|
|
||||||
tool_calls = resp.message.tool_calls
|
tool_calls = resp.message.tool_calls
|
||||||
for call in tool_calls:
|
for call in tool_calls:
|
||||||
try:
|
try:
|
||||||
self.tool_calls.append(call)
|
self.tool_calls.append(call)
|
||||||
tool_response, call_id = agent._execute_tool_action(
|
tool_response, call_id = agent._execute_tool_action(
|
||||||
tools_dict, call
|
tools_dict, call
|
||||||
)
|
)
|
||||||
function_call_dict = {
|
function_call_dict = {
|
||||||
"function_call": {
|
"function_call": {
|
||||||
"name": call.function.name,
|
"name": call.function.name,
|
||||||
"args": call.function.arguments,
|
"args": call.function.arguments,
|
||||||
"call_id": call_id,
|
"call_id": call_id,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
function_response_dict = {
|
||||||
function_response_dict = {
|
"function_response": {
|
||||||
"function_response": {
|
"name": call.function.name,
|
||||||
"name": call.function.name,
|
"response": {"result": tool_response},
|
||||||
"response": {"result": tool_response},
|
"call_id": call_id,
|
||||||
"call_id": call_id,
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
messages.append(
|
messages.append(
|
||||||
{"role": "assistant", "content": [function_call_dict]}
|
{"role": "assistant", "content": [function_call_dict]}
|
||||||
)
|
)
|
||||||
messages.append(
|
messages.append(
|
||||||
{"role": "tool", "content": [function_response_dict]}
|
{"role": "tool", "content": [function_response_dict]}
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"content": f"Error executing tool: {str(e)}",
|
"content": f"Error executing tool: {str(e)}",
|
||||||
"tool_call_id": call_id,
|
"tool_call_id": call_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
resp = agent.llm.gen(
|
resp = agent.llm.gen_stream(
|
||||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||||
)
|
)
|
||||||
self.llm_calls.append(build_stack_data(agent.llm))
|
self.llm_calls.append(build_stack_data(agent.llm))
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
else:
|
||||||
|
while True:
|
||||||
|
tool_calls = {}
|
||||||
|
for chunk in resp:
|
||||||
|
if isinstance(chunk, str):
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
chunk_delta = chunk.delta
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(chunk_delta, "tool_calls")
|
||||||
|
and chunk_delta.tool_calls is not None
|
||||||
|
):
|
||||||
|
for tool_call in chunk_delta.tool_calls:
|
||||||
|
index = tool_call.index
|
||||||
|
if index not in tool_calls:
|
||||||
|
tool_calls[index] = {
|
||||||
|
"id": "",
|
||||||
|
"function": {"name": "", "arguments": ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
current = tool_calls[index]
|
||||||
|
if tool_call.id:
|
||||||
|
current["id"] = tool_call.id
|
||||||
|
if tool_call.function.name:
|
||||||
|
current["function"][
|
||||||
|
"name"
|
||||||
|
] = tool_call.function.name
|
||||||
|
if tool_call.function.arguments:
|
||||||
|
current["function"][
|
||||||
|
"arguments"
|
||||||
|
] += tool_call.function.arguments
|
||||||
|
tool_calls[index] = current
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(chunk, "finish_reason")
|
||||||
|
and chunk.finish_reason == "tool_calls"
|
||||||
|
):
|
||||||
|
for index in sorted(tool_calls.keys()):
|
||||||
|
call = tool_calls[index]
|
||||||
|
try:
|
||||||
|
self.tool_calls.append(call)
|
||||||
|
tool_response, call_id = agent._execute_tool_action(
|
||||||
|
tools_dict, call
|
||||||
|
)
|
||||||
|
|
||||||
|
function_call_dict = {
|
||||||
|
"function_call": {
|
||||||
|
"name": call["function"]["name"],
|
||||||
|
"args": call["function"]["arguments"],
|
||||||
|
"call_id": call["id"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
function_response_dict = {
|
||||||
|
"function_response": {
|
||||||
|
"name": call["function"]["name"],
|
||||||
|
"response": {"result": tool_response},
|
||||||
|
"call_id": call["id"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [function_call_dict],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"content": [function_response_dict],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": f"Error executing tool: {str(e)}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tool_calls = {}
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(chunk, "finish_reason")
|
||||||
|
and chunk.finish_reason == "stop"
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
resp = agent.llm.gen_stream(
|
||||||
|
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||||
|
)
|
||||||
|
self.llm_calls.append(build_stack_data(agent.llm))
|
||||||
|
|
||||||
|
|
||||||
class GoogleLLMHandler(LLMHandler):
|
class GoogleLLMHandler(LLMHandler):
|
||||||
def handle_response(self, agent, resp, tools_dict, messages):
|
def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
response = agent.llm.gen(
|
if not stream:
|
||||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
response = agent.llm.gen(
|
||||||
)
|
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||||
self.llm_calls.append(build_stack_data(agent.llm))
|
)
|
||||||
if response.candidates and response.candidates[0].content.parts:
|
self.llm_calls.append(build_stack_data(agent.llm))
|
||||||
|
if response.candidates and response.candidates[0].content.parts:
|
||||||
|
tool_call_found = False
|
||||||
|
for part in response.candidates[0].content.parts:
|
||||||
|
if part.function_call:
|
||||||
|
tool_call_found = True
|
||||||
|
self.tool_calls.append(part.function_call)
|
||||||
|
tool_response, call_id = agent._execute_tool_action(
|
||||||
|
tools_dict, part.function_call
|
||||||
|
)
|
||||||
|
function_response_part = types.Part.from_function_response(
|
||||||
|
name=part.function_call.name,
|
||||||
|
response={"result": tool_response},
|
||||||
|
)
|
||||||
|
|
||||||
|
messages.append(
|
||||||
|
{"role": "model", "content": [part.to_json_dict()]}
|
||||||
|
)
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"content": [function_response_part.to_json_dict()],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not tool_call_found
|
||||||
|
and response.candidates[0].content.parts
|
||||||
|
and response.candidates[0].content.parts[0].text
|
||||||
|
):
|
||||||
|
return response.candidates[0].content.parts[0].text
|
||||||
|
elif not tool_call_found:
|
||||||
|
return response.candidates[0].content.parts
|
||||||
|
|
||||||
|
else:
|
||||||
|
return response
|
||||||
|
|
||||||
|
else:
|
||||||
|
response = agent.llm.gen_stream(
|
||||||
|
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||||
|
)
|
||||||
|
self.llm_calls.append(build_stack_data(agent.llm))
|
||||||
|
|
||||||
tool_call_found = False
|
tool_call_found = False
|
||||||
for part in response.candidates[0].content.parts:
|
for result in response:
|
||||||
if part.function_call:
|
if hasattr(result, "function_call"):
|
||||||
tool_call_found = True
|
tool_call_found = True
|
||||||
self.tool_calls.append(part.function_call)
|
self.tool_calls.append(result.function_call)
|
||||||
tool_response, call_id = agent._execute_tool_action(
|
tool_response, call_id = agent._execute_tool_action(
|
||||||
tools_dict, part.function_call
|
tools_dict, result.function_call
|
||||||
)
|
)
|
||||||
function_response_part = types.Part.from_function_response(
|
function_response_part = types.Part.from_function_response(
|
||||||
name=part.function_call.name,
|
name=result.function_call.name,
|
||||||
response={"result": tool_response},
|
response={"result": tool_response},
|
||||||
)
|
)
|
||||||
|
|
||||||
messages.append(
|
messages.append(
|
||||||
{"role": "model", "content": [part.to_json_dict()]}
|
{"role": "model", "content": [result.to_json_dict()]}
|
||||||
)
|
)
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
@@ -101,17 +238,8 @@ class GoogleLLMHandler(LLMHandler):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if not tool_call_found:
|
||||||
not tool_call_found
|
return response
|
||||||
and response.candidates[0].content.parts
|
|
||||||
and response.candidates[0].content.parts[0].text
|
|
||||||
):
|
|
||||||
return response.candidates[0].content.parts[0].text
|
|
||||||
elif not tool_call_found:
|
|
||||||
return response.candidates[0].content.parts
|
|
||||||
|
|
||||||
else:
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def get_llm_handler(llm_type):
|
def get_llm_handler(llm_type):
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ class CryptoPriceTool(Tool):
|
|||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
# data will be like {"USD": <price>} if the call is successful
|
|
||||||
if currency.upper() in data:
|
if currency.upper() in data:
|
||||||
return {
|
return {
|
||||||
"status_code": response.status_code,
|
"status_code": response.status_code,
|
||||||
|
|||||||
@@ -14,9 +14,20 @@ class ToolActionParser:
|
|||||||
return parser(call)
|
return parser(call)
|
||||||
|
|
||||||
def _parse_openai_llm(self, call):
|
def _parse_openai_llm(self, call):
|
||||||
call_args = json.loads(call.function.arguments)
|
if isinstance(call, dict):
|
||||||
tool_id = call.function.name.split("_")[-1]
|
try:
|
||||||
action_name = call.function.name.rsplit("_", 1)[0]
|
call_args = json.loads(call["function"]["arguments"])
|
||||||
|
tool_id = call["function"]["name"].split("_")[-1]
|
||||||
|
action_name = call["function"]["name"].rsplit("_", 1)[0]
|
||||||
|
except (KeyError, TypeError) as e:
|
||||||
|
return None, None, None
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
call_args = json.loads(call.function.arguments)
|
||||||
|
tool_id = call.function.name.split("_")[-1]
|
||||||
|
action_name = call.function.name.rsplit("_", 1)[0]
|
||||||
|
except (AttributeError, TypeError) as e:
|
||||||
|
return None, None, None
|
||||||
return tool_id, action_name, call_args
|
return tool_id, action_name, call_args
|
||||||
|
|
||||||
def _parse_google_llm(self, call):
|
def _parse_google_llm(self, call):
|
||||||
|
|||||||
@@ -152,7 +152,15 @@ class GoogleLLM(BaseLLM):
|
|||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
if chunk.text is not None:
|
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||||
|
for candidate in chunk.candidates:
|
||||||
|
if candidate.content and candidate.content.parts:
|
||||||
|
for part in candidate.content.parts:
|
||||||
|
if part.function_call:
|
||||||
|
yield part
|
||||||
|
elif part.text:
|
||||||
|
yield part.text
|
||||||
|
elif hasattr(chunk, "text"):
|
||||||
yield chunk.text
|
yield chunk.text
|
||||||
|
|
||||||
def _supports_tools(self):
|
def _supports_tools(self):
|
||||||
|
|||||||
@@ -111,13 +111,24 @@ class OpenAILLM(BaseLLM):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
messages = self._clean_messages_openai(messages)
|
messages = self._clean_messages_openai(messages)
|
||||||
response = self.client.chat.completions.create(
|
if tools:
|
||||||
model=model, messages=messages, stream=stream, **kwargs
|
response = self.client.chat.completions.create(
|
||||||
)
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
stream=stream,
|
||||||
|
tools=tools,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=model, messages=messages, stream=stream, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
for line in response:
|
for line in response:
|
||||||
if line.choices[0].delta.content is not None:
|
if line.choices[0].delta.content is not None:
|
||||||
yield line.choices[0].delta.content
|
yield line.choices[0].delta.content
|
||||||
|
else:
|
||||||
|
yield line.choices[0]
|
||||||
|
|
||||||
def _supports_tools(self):
|
def _supports_tools(self):
|
||||||
return True
|
return True
|
||||||
|
|||||||
Reference in New Issue
Block a user