diff --git a/application/agents/classic_agent.py b/application/agents/classic_agent.py index 4e64442d..79d9e37f 100644 --- a/application/agents/classic_agent.py +++ b/application/agents/classic_agent.py @@ -104,7 +104,8 @@ class ClassicAgent(BaseAgent): model=self.gpt_model, messages=messages_combine, tools=self.tools ) for line in completion: - yield {"answer": line} + if isinstance(line, str): + yield {"answer": line} yield {"tool_calls": self.tool_calls.copy()} @@ -116,7 +117,7 @@ class ClassicAgent(BaseAgent): return retrieved_data def _llm_gen(self, messages_combine, log_context): - resp = self.llm.gen( + resp = self.llm.gen_stream( model=self.gpt_model, messages=messages_combine, tools=self.tools ) if log_context: @@ -131,5 +132,4 @@ class ClassicAgent(BaseAgent): if log_context: data = build_stack_data(self.llm_handler) log_context.stacks.append({"component": "llm_handler", "data": data}) - return resp diff --git a/application/agents/llm_handler.py b/application/agents/llm_handler.py index adf240c3..9267dc53 100644 --- a/application/agents/llm_handler.py +++ b/application/agents/llm_handler.py @@ -15,84 +15,221 @@ class LLMHandler(ABC): class OpenAILLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages): - while resp.finish_reason == "tool_calls": - message = json.loads(resp.model_dump_json())["message"] - keys_to_remove = {"audio", "function_call", "refusal"} - filtered_data = { - k: v for k, v in message.items() if k not in keys_to_remove - } - messages.append(filtered_data) + def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True): + if not stream: + while hasattr(resp, "finish_reason") and resp.finish_reason == "tool_calls": + message = json.loads(resp.model_dump_json())["message"] + keys_to_remove = {"audio", "function_call", "refusal"} + filtered_data = { + k: v for k, v in message.items() if k not in keys_to_remove + } + messages.append(filtered_data) - tool_calls = resp.message.tool_calls - for call in tool_calls: - try: - self.tool_calls.append(call) - tool_response, call_id = agent._execute_tool_action( - tools_dict, call - ) - function_call_dict = { - "function_call": { - "name": call.function.name, - "args": call.function.arguments, - "call_id": call_id, + tool_calls = resp.message.tool_calls + for call in tool_calls: + try: + self.tool_calls.append(call) + tool_response, call_id = agent._execute_tool_action( + tools_dict, call + ) + function_call_dict = { + "function_call": { + "name": call.function.name, + "args": call.function.arguments, + "call_id": call_id, + } } - } - function_response_dict = { - "function_response": { - "name": call.function.name, - "response": {"result": tool_response}, - "call_id": call_id, + function_response_dict = { + "function_response": { + "name": call.function.name, + "response": {"result": tool_response}, + "call_id": call_id, + } } - } - messages.append( - {"role": "assistant", "content": [function_call_dict]} - ) - messages.append( - {"role": "tool", "content": [function_response_dict]} - ) + messages.append( + {"role": "assistant", "content": [function_call_dict]} + ) + messages.append( + {"role": "tool", "content": [function_response_dict]} + ) - except Exception as e: - messages.append( - { - "role": "tool", - "content": f"Error executing tool: {str(e)}", - "tool_call_id": call_id, - } - ) - resp = agent.llm.gen( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - return resp + except Exception as e: + messages.append( + { + "role": "tool", + "content": f"Error executing tool: {str(e)}", + "tool_call_id": call_id, + } + ) + resp = agent.llm.gen_stream( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + return resp + + else: + while True: + tool_calls = {} + for chunk in resp: + if isinstance(chunk, str): + return + else: + chunk_delta = chunk.delta + + if ( + hasattr(chunk_delta, "tool_calls") + and chunk_delta.tool_calls is not None + ): + for tool_call in chunk_delta.tool_calls: + index = tool_call.index + if index not in tool_calls: + tool_calls[index] = { + "id": "", + "function": {"name": "", "arguments": ""}, + } + + current = tool_calls[index] + if tool_call.id: + current["id"] = tool_call.id + if tool_call.function.name: + current["function"][ + "name" + ] = tool_call.function.name + if tool_call.function.arguments: + current["function"][ + "arguments" + ] += tool_call.function.arguments + tool_calls[index] = current + + if ( + hasattr(chunk, "finish_reason") + and chunk.finish_reason == "tool_calls" + ): + for index in sorted(tool_calls.keys()): + call = tool_calls[index] + try: + self.tool_calls.append(call) + tool_response, call_id = agent._execute_tool_action( + tools_dict, call + ) + + function_call_dict = { + "function_call": { + "name": call["function"]["name"], + "args": call["function"]["arguments"], + "call_id": call["id"], + } + } + function_response_dict = { + "function_response": { + "name": call["function"]["name"], + "response": {"result": tool_response}, + "call_id": call["id"], + } + } + + messages.append( + { + "role": "assistant", + "content": [function_call_dict], + } + ) + messages.append( + { + "role": "tool", + "content": [function_response_dict], + } + ) + + except Exception as e: + messages.append( + { + "role": "assistant", + "content": f"Error executing tool: {str(e)}", + } + ) + tool_calls = {} + + if ( + hasattr(chunk, "finish_reason") + and chunk.finish_reason == "stop" + ): + return + + resp = agent.llm.gen_stream( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) class GoogleLLMHandler(LLMHandler): - def handle_response(self, agent, resp, tools_dict, messages): + def handle_response(self, agent, resp, tools_dict, messages, stream: bool = True): from google.genai import types while True: - response = agent.llm.gen( - model=agent.gpt_model, messages=messages, tools=agent.tools - ) - self.llm_calls.append(build_stack_data(agent.llm)) - if response.candidates and response.candidates[0].content.parts: + if not stream: + response = agent.llm.gen( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + if response.candidates and response.candidates[0].content.parts: + tool_call_found = False + for part in response.candidates[0].content.parts: + if part.function_call: + tool_call_found = True + self.tool_calls.append(part.function_call) + tool_response, call_id = agent._execute_tool_action( + tools_dict, part.function_call + ) + function_response_part = types.Part.from_function_response( + name=part.function_call.name, + response={"result": tool_response}, + ) + + messages.append( + {"role": "model", "content": [part.to_json_dict()]} + ) + messages.append( + { + "role": "tool", + "content": [function_response_part.to_json_dict()], + } + ) + + if ( + not tool_call_found + and response.candidates[0].content.parts + and response.candidates[0].content.parts[0].text + ): + return response.candidates[0].content.parts[0].text + elif not tool_call_found: + return response.candidates[0].content.parts + + else: + return response + + else: + response = agent.llm.gen_stream( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + self.llm_calls.append(build_stack_data(agent.llm)) + tool_call_found = False - for part in response.candidates[0].content.parts: - if part.function_call: + for result in response: + if hasattr(result, "function_call"): tool_call_found = True - self.tool_calls.append(part.function_call) + self.tool_calls.append(result.function_call) tool_response, call_id = agent._execute_tool_action( - tools_dict, part.function_call + tools_dict, result.function_call ) function_response_part = types.Part.from_function_response( - name=part.function_call.name, + name=result.function_call.name, response={"result": tool_response}, ) messages.append( - {"role": "model", "content": [part.to_json_dict()]} + {"role": "model", "content": [result.to_json_dict()]} ) messages.append( { @@ -101,17 +238,8 @@ class GoogleLLMHandler(LLMHandler): } ) - if ( - not tool_call_found - and response.candidates[0].content.parts - and response.candidates[0].content.parts[0].text - ): - return response.candidates[0].content.parts[0].text - elif not tool_call_found: - return response.candidates[0].content.parts - - else: - return response + if not tool_call_found: + return response def get_llm_handler(llm_type): diff --git a/application/agents/tools/cryptoprice.py b/application/agents/tools/cryptoprice.py index 80d0c2fc..c25c3d43 100644 --- a/application/agents/tools/cryptoprice.py +++ b/application/agents/tools/cryptoprice.py @@ -31,7 +31,6 @@ class CryptoPriceTool(Tool): response = requests.get(url) if response.status_code == 200: data = response.json() - # data will be like {"USD": } if the call is successful if currency.upper() in data: return { "status_code": response.status_code, diff --git a/application/agents/tools/tool_action_parser.py b/application/agents/tools/tool_action_parser.py index ac0a70c1..4d894d1a 100644 --- a/application/agents/tools/tool_action_parser.py +++ b/application/agents/tools/tool_action_parser.py @@ -14,9 +14,20 @@ class ToolActionParser: return parser(call) def _parse_openai_llm(self, call): - call_args = json.loads(call.function.arguments) - tool_id = call.function.name.split("_")[-1] - action_name = call.function.name.rsplit("_", 1)[0] + if isinstance(call, dict): + try: + call_args = json.loads(call["function"]["arguments"]) + tool_id = call["function"]["name"].split("_")[-1] + action_name = call["function"]["name"].rsplit("_", 1)[0] + except (KeyError, TypeError) as e: + return None, None, None + else: + try: + call_args = json.loads(call.function.arguments) + tool_id = call.function.name.split("_")[-1] + action_name = call.function.name.rsplit("_", 1)[0] + except (AttributeError, TypeError) as e: + return None, None, None return tool_id, action_name, call_args def _parse_google_llm(self, call): diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index 31943601..d52e26c8 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -152,7 +152,15 @@ class GoogleLLM(BaseLLM): config=config, ) for chunk in response: - if chunk.text is not None: + if hasattr(chunk, "candidates") and chunk.candidates: + for candidate in chunk.candidates: + if candidate.content and candidate.content.parts: + for part in candidate.content.parts: + if part.function_call: + yield part + elif part.text: + yield part.text + elif hasattr(chunk, "text"): yield chunk.text def _supports_tools(self): diff --git a/application/llm/openai.py b/application/llm/openai.py index b8f311b0..938de523 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -111,13 +111,24 @@ class OpenAILLM(BaseLLM): **kwargs, ): messages = self._clean_messages_openai(messages) - response = self.client.chat.completions.create( - model=model, messages=messages, stream=stream, **kwargs - ) + if tools: + response = self.client.chat.completions.create( + model=model, + messages=messages, + stream=stream, + tools=tools, + **kwargs, + ) + else: + response = self.client.chat.completions.create( + model=model, messages=messages, stream=stream, **kwargs + ) for line in response: if line.choices[0].delta.content is not None: yield line.choices[0].delta.content + else: + yield line.choices[0] def _supports_tools(self): return True