From 51225b18b2e520472d0f969464925dea8b5c4fc6 Mon Sep 17 00:00:00 2001
From: Alex
Date: Mon, 13 Jan 2025 10:37:53 +0000
Subject: [PATCH] add google

---
 application/llm/google_ai.py | 41 ++++++++++++++++++++++--------------
 application/tools/agent.py   | 32 +++++++++++++++++++++++++++-
 2 files changed, 56 insertions(+), 17 deletions(-)

diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py
index 33ae0855..09f0d9f7 100644
--- a/application/llm/google_ai.py
+++ b/application/llm/google_ai.py
@@ -72,29 +72,38 @@ class GoogleLLM(BaseLLM):
         messages,
         stream=False,
         tools=None,
+        formatting="openai",
         **kwargs
     ):
-        import google.generativeai as genai
-        genai.configure(api_key=self.api_key)
+        from google import genai
+        from google.genai import types
+        client = genai.Client(api_key=self.api_key)
+
         config = {
         }
 
         model = 'gemini-2.0-flash-exp'
-
-        model = genai.GenerativeModel(
-            model_name=model,
-            generation_config=config,
-            system_instruction=messages[0]["content"],
-            tools=self._clean_tools_format(tools)
+        if formatting=="raw":
+            response = client.models.generate_content(
+                model=model,
+                contents=messages
             )
-        chat_session = model.start_chat(
-            history=self._clean_messages_google(messages)[:-1]
-        )
-        response = chat_session.send_message(
-            self._clean_messages_google(messages)[-1]
-        )
-        logging.info(response)
-        return response.text
+
+        else:
+            model = genai.GenerativeModel(
+                model_name=model,
+                generation_config=config,
+                system_instruction=messages[0]["content"],
+                tools=self._clean_tools_format(tools)
+            )
+            chat_session = model.start_chat(
+                history=self._clean_messages_google(messages)[:-1]
+            )
+            response = chat_session.send_message(
+                self._clean_messages_google(messages)[-1]
+            )
+            logging.info(response)
+            return response.text
 
     def _raw_gen_stream(
         self,
diff --git a/application/tools/agent.py b/application/tools/agent.py
index d4077e45..a6aec2e9 100644
--- a/application/tools/agent.py
+++ b/application/tools/agent.py
@@ -1,4 +1,5 @@
 import json
+import logging
 
 from application.core.mongo_db import MongoDB
 from application.llm.llm_creator import LLMCreator
@@ -79,6 +80,25 @@ class Agent:
         print(f"Executing tool: {action_name} with args: {call_args}")
         return tool.execute_action(action_name, **call_args), call_id
 
+    def _execute_tool_action_google(self, tools_dict, call):
+        call_args = json.loads(call.args)
+        tool_id = call.name.split("_")[-1]
+        action_name = call.name.rsplit("_", 1)[0]
+
+        tool_data = tools_dict[tool_id]
+        action_data = next(
+            action for action in tool_data["actions"] if action["name"] == action_name
+        )
+
+        for param, details in action_data["parameters"]["properties"].items():
+            if param not in call_args and "value" in details:
+                call_args[param] = details["value"]
+
+        tm = ToolManager(config={})
+        tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
+        print(f"Executing tool: {action_name} with args: {call_args}")
+        return tool.execute_action(action_name, **call_args)
+
     def _simple_tool_agent(self, messages):
         tools_dict = self._get_user_tools()
         self._prepare_tools(tools_dict)
@@ -91,8 +111,18 @@
         if resp.message.content:
             yield resp.message.content
             return
+        # check if self.llm class is GoogleLLM
+        while self.llm.__class__.__name__ == "GoogleLLM" and resp.content.parts[0].function_call:
+            from google.genai import types
 
-        while resp.finish_reason == "tool_calls":
+            function_call_part = resp.candidates[0].content.parts[0]
+            tool_response = self._execute_tool_action_google(tools_dict, function_call_part.function_call)
+            function_response_part = types.Part.from_function_response(
+                name=function_call_part.function_call.name,
+                response=tool_response
+            )
+
+        while self.llm.__class__.__name__ == "OpenAILLM" and resp.finish_reason == "tool_calls":
             message = json.loads(resp.model_dump_json())["message"]
             keys_to_remove = {"audio", "function_call", "refusal"}
             filtered_data = {