feat: streaming responses with function call

This commit is contained in:
Siddhant Rai
2025-03-05 09:02:55 +05:30
parent c6ce4d9374
commit f88c34a0be
6 changed files with 237 additions and 80 deletions

View File

@@ -104,6 +104,7 @@ class ClassicAgent(BaseAgent):
model=self.gpt_model, messages=messages_combine, tools=self.tools model=self.gpt_model, messages=messages_combine, tools=self.tools
) )
for line in completion: for line in completion:
if isinstance(line, str):
yield {"answer": line} yield {"answer": line}
yield {"tool_calls": self.tool_calls.copy()} yield {"tool_calls": self.tool_calls.copy()}
@@ -116,7 +117,7 @@ class ClassicAgent(BaseAgent):
return retrieved_data return retrieved_data
def _llm_gen(self, messages_combine, log_context): def _llm_gen(self, messages_combine, log_context):
resp = self.llm.gen( resp = self.llm.gen_stream(
model=self.gpt_model, messages=messages_combine, tools=self.tools model=self.gpt_model, messages=messages_combine, tools=self.tools
) )
if log_context: if log_context:
@@ -131,5 +132,4 @@ class ClassicAgent(BaseAgent):
if log_context: if log_context:
data = build_stack_data(self.llm_handler) data = build_stack_data(self.llm_handler)
log_context.stacks.append({"component": "llm_handler", "data": data}) log_context.stacks.append({"component": "llm_handler", "data": data})
return resp return resp

View File

@@ -15,8 +15,9 @@ class LLMHandler(ABC):
class OpenAILLMHandler(LLMHandler): class OpenAILLMHandler(LLMHandler):
def handle_response(self, agent, resp, tools_dict, messages): def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
while resp.finish_reason == "tool_calls": if not stream:
while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls":
message = json.loads(resp.model_dump_json())["message"] message = json.loads(resp.model_dump_json())["message"]
keys_to_remove = {"audio", "function_call", "refusal"} keys_to_remove = {"audio", "function_call", "refusal"}
filtered_data = { filtered_data = {
@@ -61,18 +62,113 @@ class OpenAILLMHandler(LLMHandler):
"tool_call_id": call_id, "tool_call_id": call_id,
} }
) )
resp = agent.llm.gen( resp = agent.llm.gen_stream(
model=agent.gpt_model, messages=messages, tools=agent.tools model=agent.gpt_model, messages=messages, tools=agent.tools
) )
self.llm_calls.append(build_stack_data(agent.llm)) self.llm_calls.append(build_stack_data(agent.llm))
return resp return resp
else:
while True:
tool_calls = {}
for chunk in resp:
if isinstance(chunk, str):
return
else:
chunk_delta = chunk.delta
if (
hasattr(chunk_delta, "tool_calls")
and chunk_delta.tool_calls is not None
):
for tool_call in chunk_delta.tool_calls:
index = tool_call.index
if index not in tool_calls:
tool_calls[index] = {
"id": "",
"function": {"name": "", "arguments": ""},
}
current = tool_calls[index]
if tool_call.id:
current["id"] = tool_call.id
if tool_call.function.name:
current["function"][
"name"
] = tool_call.function.name
if tool_call.function.arguments:
current["function"][
"arguments"
] += tool_call.function.arguments
tool_calls[index] = current
if (
hasattr(chunk, "finish_reason")
and chunk.finish_reason == "tool_calls"
):
for index in sorted(tool_calls.keys()):
call = tool_calls[index]
try:
self.tool_calls.append(call)
tool_response, call_id = agent._execute_tool_action(
tools_dict, call
)
function_call_dict = {
"function_call": {
"name": call["function"]["name"],
"args": call["function"]["arguments"],
"call_id": call["id"],
}
}
function_response_dict = {
"function_response": {
"name": call["function"]["name"],
"response": {"result": tool_response},
"call_id": call["id"],
}
}
messages.append(
{
"role": "assistant",
"content": [function_call_dict],
}
)
messages.append(
{
"role": "tool",
"content": [function_response_dict],
}
)
except Exception as e:
messages.append(
{
"role": "assistant",
"content": f"Error executing tool: {str(e)}",
}
)
tool_calls = {}
if (
hasattr(chunk, "finish_reason")
and chunk.finish_reason == "stop"
):
return
resp = agent.llm.gen_stream(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))
class GoogleLLMHandler(LLMHandler): class GoogleLLMHandler(LLMHandler):
def handle_response(self, agent, resp, tools_dict, messages): def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True):
from google.genai import types from google.genai import types
while True: while True:
if not stream:
response = agent.llm.gen( response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools model=agent.gpt_model, messages=messages, tools=agent.tools
) )
@@ -113,6 +209,38 @@ class GoogleLLMHandler(LLMHandler):
else: else:
return response return response
else:
response = agent.llm.gen_stream(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
self.llm_calls.append(build_stack_data(agent.llm))
tool_call_found = False
for result in response:
if hasattr(result, "function_call"):
tool_call_found = True
self.tool_calls.append(result.function_call)
tool_response, call_id = agent._execute_tool_action(
tools_dict, result.function_call
)
function_response_part = types.Part.from_function_response(
name=result.function_call.name,
response={"result": tool_response},
)
messages.append(
{"role": "model", "content": [result.to_json_dict()]}
)
messages.append(
{
"role": "tool",
"content": [function_response_part.to_json_dict()],
}
)
if not tool_call_found:
return response
def get_llm_handler(llm_type): def get_llm_handler(llm_type):
handlers = { handlers = {

View File

@@ -31,7 +31,6 @@ class CryptoPriceTool(Tool):
response = requests.get(url) response = requests.get(url)
if response.status_code == 200: if response.status_code == 200:
data = response.json() data = response.json()
# data will be like {"USD": <price>} if the call is successful
if currency.upper() in data: if currency.upper() in data:
return { return {
"status_code": response.status_code, "status_code": response.status_code,

View File

@@ -14,9 +14,20 @@ class ToolActionParser:
return parser(call) return parser(call)
def _parse_openai_llm(self, call): def _parse_openai_llm(self, call):
if isinstance(call, dict):
try:
call_args = json.loads(call["function"]["arguments"])
tool_id = call["function"]["name"].split("_")[-1]
action_name = call["function"]["name"].rsplit("_", 1)[0]
except (KeyError, TypeError) as e:
return None, None, None
else:
try:
call_args = json.loads(call.function.arguments) call_args = json.loads(call.function.arguments)
tool_id = call.function.name.split("_")[-1] tool_id = call.function.name.split("_")[-1]
action_name = call.function.name.rsplit("_", 1)[0] action_name = call.function.name.rsplit("_", 1)[0]
except (AttributeError, TypeError) as e:
return None, None, None
return tool_id, action_name, call_args return tool_id, action_name, call_args
def _parse_google_llm(self, call): def _parse_google_llm(self, call):

View File

@@ -152,7 +152,15 @@ class GoogleLLM(BaseLLM):
config=config, config=config,
) )
for chunk in response: for chunk in response:
if chunk.text is not None: if hasattr(chunk, "candidates") and chunk.candidates:
for candidate in chunk.candidates:
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if part.function_call:
yield part
elif part.text:
yield part.text
elif hasattr(chunk, "text"):
yield chunk.text yield chunk.text
def _supports_tools(self): def _supports_tools(self):

View File

@@ -111,6 +111,15 @@ class OpenAILLM(BaseLLM):
**kwargs, **kwargs,
): ):
messages = self._clean_messages_openai(messages) messages = self._clean_messages_openai(messages)
if tools:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
else:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs model=model, messages=messages, stream=stream, **kwargs
) )
@@ -118,6 +127,8 @@ class OpenAILLM(BaseLLM):
for line in response: for line in response:
if line.choices[0].delta.content is not None: if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content yield line.choices[0].delta.content
else:
yield line.choices[0]
def _supports_tools(self): def _supports_tools(self):
return True return True