From c97d1e336308c0717bc12774e663d49b403aa345 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Fri, 17 Jan 2025 09:22:41 +0530 Subject: [PATCH] fix: google parser, llm handler and other errors --- application/llm/google_ai.py | 186 +++++++++++++----------- application/tools/agent.py | 4 +- application/tools/llm_handler.py | 44 ++++-- application/tools/tool_action_parser.py | 9 +- 4 files changed, 141 insertions(+), 102 deletions(-) diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index 09f0d9f7..24043d9c 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -1,60 +1,77 @@ -from application.llm.base import BaseLLM +import google.generativeai as genai + from application.core.settings import settings -import logging +from application.llm.base import BaseLLM + class GoogleLLM(BaseLLM): def __init__(self, api_key=None, user_api_key=None, *args, **kwargs): - super().__init__(*args, **kwargs) self.api_key = settings.API_KEY - self.user_api_key = user_api_key + genai.configure(api_key=self.api_key) def _clean_messages_google(self, messages): - return [ - { - "role": "model" if message["role"] == "system" else message["role"], - "parts": [message["content"]], - } - for message in messages[1:] - ] - + cleaned_messages = [] + for message in messages[1:]: + cleaned_messages.append( + { + "role": "model" if message["role"] == "system" else message["role"], + "parts": [message["content"]], + } + ) + return cleaned_messages + def _clean_tools_format(self, tools_data): - """ - Cleans the tools data format, converting string type representations - to the expected dictionary structure for google-generativeai. - """ if isinstance(tools_data, list): return [self._clean_tools_format(item) for item in tools_data] elif isinstance(tools_data, dict): - if 'function' in tools_data and 'type' in tools_data and tools_data['type'] == 'function': + if ( + "function" in tools_data + and "type" in tools_data + and tools_data["type"] == "function" + ): # Handle the case where tools are nested under 'function' - cleaned_function = self._clean_tools_format(tools_data['function']) - return {'function_declarations': [cleaned_function]} - elif 'function' in tools_data and 'type_' in tools_data and tools_data['type_'] == 'function': + cleaned_function = self._clean_tools_format(tools_data["function"]) + return {"function_declarations": [cleaned_function]} + elif ( + "function" in tools_data + and "type_" in tools_data + and tools_data["type_"] == "function" + ): # Handle the case where tools are nested under 'function' and type is already 'type_' - cleaned_function = self._clean_tools_format(tools_data['function']) - return {'function_declarations': [cleaned_function]} + cleaned_function = self._clean_tools_format(tools_data["function"]) + return {"function_declarations": [cleaned_function]} else: new_tools_data = {} for key, value in tools_data.items(): - if key == 'type': - if value == 'string': - new_tools_data['type_'] = 'STRING' # Keep as string for now - elif value == 'object': - new_tools_data['type_'] = 'OBJECT' # Keep as string for now - elif key == 'additionalProperties': + if key == "type": + if value == "string": + new_tools_data["type_"] = "STRING" + elif value == "object": + new_tools_data["type_"] = "OBJECT" + elif key == "additionalProperties": continue - elif key == 'properties': + elif key == "properties": if isinstance(value, dict): new_properties = {} for prop_name, prop_value in value.items(): - if isinstance(prop_value, dict) and 'type' in prop_value: - if prop_value['type'] == 'string': - new_properties[prop_name] = {'type_': 'STRING', 'description': prop_value.get('description')} + if ( + isinstance(prop_value, dict) + and "type" in prop_value + ): + if prop_value["type"] == "string": + new_properties[prop_name] = { + "type_": "STRING", + "description": prop_value.get( + "description" + ), + } # Add more type mappings as needed else: - new_properties[prop_name] = self._clean_tools_format(prop_value) + new_properties[prop_name] = ( + self._clean_tools_format(prop_value) + ) new_tools_data[key] = new_properties else: new_tools_data[key] = self._clean_tools_format(value) @@ -74,65 +91,64 @@ class GoogleLLM(BaseLLM): tools=None, formatting="openai", **kwargs - ): - from google import genai - from google.genai import types - client = genai.Client(api_key=self.api_key) + ): + config = {} + model_name = "gemini-2.0-flash-exp" - - config = { - } - model = 'gemini-2.0-flash-exp' - if formatting=="raw": - response = client.models.generate_content( - model=model, - contents=messages - ) - - else: - model = genai.GenerativeModel( - model_name=model, - generation_config=config, - system_instruction=messages[0]["content"], - tools=self._clean_tools_format(tools) - ) - chat_session = model.start_chat( - history=self._clean_messages_google(messages)[:-1] - ) - response = chat_session.send_message( - self._clean_messages_google(messages)[-1] - ) - logging.info(response) + if formatting == "raw": + client = genai.GenerativeModel(model_name=model_name) + response = client.generate_content(contents=messages) return response.text + else: + if tools: + client = genai.GenerativeModel( + model_name=model_name, + generation_config=config, + system_instruction=messages[0]["content"], + tools=self._clean_tools_format(tools), + ) + chat_session = gen_model.start_chat( + history=self._clean_messages_google(messages)[:-1] + ) + response = chat_session.send_message( + self._clean_messages_google(messages)[-1] + ) + return response + else: + gen_model = genai.GenerativeModel( + model_name=model_name, + generation_config=config, + system_instruction=messages[0]["content"], + ) + chat_session = gen_model.start_chat( + history=self._clean_messages_google(messages)[:-1] + ) + response = chat_session.send_message( + self._clean_messages_google(messages)[-1] + ) + return response.text def _raw_gen_stream( - self, - baseself, - model, - messages, - stream=True, - tools=None, - **kwargs - ): - import google.generativeai as genai - genai.configure(api_key=self.api_key) - config = { - } - model = genai.GenerativeModel( - model_name=model, + self, baseself, model, messages, stream=True, tools=None, **kwargs + ): + config = {} + model_name = "gemini-2.0-flash-exp" + + gen_model = genai.GenerativeModel( + model_name=model_name, generation_config=config, - system_instruction=messages[0]["content"] - ) - chat_session = model.start_chat( + system_instruction=messages[0]["content"], + tools=self._clean_tools_format(tools), + ) + chat_session = gen_model.start_chat( history=self._clean_messages_google(messages)[:-1], ) response = chat_session.send_message( - self._clean_messages_google(messages)[-1] - , stream=stream + self._clean_messages_google(messages)[-1], stream=stream ) - for line in response: - if line.text is not None: - yield line.text - + for chunk in response: + if chunk.text is not None: + yield chunk.text + def _supports_tools(self): - return True \ No newline at end of file + return True diff --git a/application/tools/agent.py b/application/tools/agent.py index f4b37d9b..bbf6bcac 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -89,7 +89,7 @@ class Agent: if isinstance(resp, str): yield resp return - if resp.message.content: + if hasattr(resp, "message") and hasattr(resp.message, "content"): yield resp.message.content return @@ -98,7 +98,7 @@ class Agent: # If no tool calls are needed, generate the final response if isinstance(resp, str): yield resp - elif resp.message.content: + elif hasattr(resp, "message") and hasattr(resp.message, "content"): yield resp.message.content else: completion = self.llm.gen_stream( diff --git a/application/tools/llm_handler.py b/application/tools/llm_handler.py index 58fce56e..6be89ad7 100644 --- a/application/tools/llm_handler.py +++ b/application/tools/llm_handler.py @@ -47,23 +47,43 @@ class OpenAILLMHandler(LLMHandler): class GoogleLLMHandler(LLMHandler): def handle_response(self, agent, resp, tools_dict, messages): - from google.genai import types + import google.generativeai as genai - while resp.content.parts[0].function_call: - function_call_part = resp.candidates[0].content.parts[0] - tool_response, call_id = agent._execute_tool_action( - tools_dict, function_call_part.function_call - ) - function_response_part = types.Part.from_function_response( - name=function_call_part.function_call.name, response=tool_response - ) - - messages.append(function_call_part, function_response_part) + while ( + hasattr(resp.candidates[0].content.parts[0], "function_call") + and resp.candidates[0].content.parts[0].function_call + ): + responses = {} + for part in resp.candidates[0].content.parts: + if hasattr(part, "function_call") and part.function_call: + function_call_part = part + messages.append( + genai.protos.Part( + function_call=genai.protos.FunctionCall( + name=function_call_part.function_call.name, + args=function_call_part.function_call.args, + ) + ) + ) + tool_response, call_id = agent._execute_tool_action( + tools_dict, function_call_part.function_call + ) + responses[function_call_part.function_call.name] = tool_response + response_parts = [ + genai.protos.Part( + function_response=genai.protos.FunctionResponse( + name=tool_name, response={"result": response} + ) + ) + for tool_name, response in responses.items() + ] + if response_parts: + messages.append(response_parts) resp = agent.llm.gen( model=agent.gpt_model, messages=messages, tools=agent.tools ) - return resp + return resp.text def get_llm_handler(llm_type): diff --git a/application/tools/tool_action_parser.py b/application/tools/tool_action_parser.py index b708992a..254c13b4 100644 --- a/application/tools/tool_action_parser.py +++ b/application/tools/tool_action_parser.py @@ -1,5 +1,7 @@ import json +from google.protobuf.json_format import MessageToDict + class ToolActionParser: def __init__(self, llm_type): @@ -20,7 +22,8 @@ class ToolActionParser: return tool_id, action_name, call_args def _parse_google_llm(self, call): - call_args = json.loads(call.args) - tool_id = call.name.split("_")[-1] - action_name = call.name.rsplit("_", 1)[0] + call = MessageToDict(call._pb) + call_args = call["args"] + tool_id = call["name"].split("_")[-1] + action_name = call["name"].rsplit("_", 1)[0] return tool_id, action_name, call_args