mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
add google
This commit is contained in:
@@ -72,29 +72,38 @@ class GoogleLLM(BaseLLM):
|
||||
messages,
|
||||
stream=False,
|
||||
tools=None,
|
||||
formatting="openai",
|
||||
**kwargs
|
||||
):
|
||||
import google.generativeai as genai
|
||||
genai.configure(api_key=self.api_key)
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
|
||||
|
||||
config = {
|
||||
}
|
||||
model = 'gemini-2.0-flash-exp'
|
||||
|
||||
model = genai.GenerativeModel(
|
||||
model_name=model,
|
||||
generation_config=config,
|
||||
system_instruction=messages[0]["content"],
|
||||
tools=self._clean_tools_format(tools)
|
||||
if formatting=="raw":
|
||||
response = client.models.generate_content(
|
||||
model=model,
|
||||
contents=messages
|
||||
)
|
||||
chat_session = model.start_chat(
|
||||
history=self._clean_messages_google(messages)[:-1]
|
||||
)
|
||||
response = chat_session.send_message(
|
||||
self._clean_messages_google(messages)[-1]
|
||||
)
|
||||
logging.info(response)
|
||||
return response.text
|
||||
|
||||
else:
|
||||
model = genai.GenerativeModel(
|
||||
model_name=model,
|
||||
generation_config=config,
|
||||
system_instruction=messages[0]["content"],
|
||||
tools=self._clean_tools_format(tools)
|
||||
)
|
||||
chat_session = model.start_chat(
|
||||
history=self._clean_messages_google(messages)[:-1]
|
||||
)
|
||||
response = chat_session.send_message(
|
||||
self._clean_messages_google(messages)[-1]
|
||||
)
|
||||
logging.info(response)
|
||||
return response.text
|
||||
|
||||
def _raw_gen_stream(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
@@ -79,6 +80,25 @@ class Agent:
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
return tool.execute_action(action_name, **call_args), call_id
|
||||
|
||||
def _execute_tool_action_google(self, tools_dict, call):
|
||||
call_args = json.loads(call.args)
|
||||
tool_id = call.name.split("_")[-1]
|
||||
action_name = call.name.rsplit("_", 1)[0]
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = next(
|
||||
action for action in tool_data["actions"] if action["name"] == action_name
|
||||
)
|
||||
|
||||
for param, details in action_data["parameters"]["properties"].items():
|
||||
if param not in call_args and "value" in details:
|
||||
call_args[param] = details["value"]
|
||||
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
return tool.execute_action(action_name, **call_args)
|
||||
|
||||
def _simple_tool_agent(self, messages):
|
||||
tools_dict = self._get_user_tools()
|
||||
self._prepare_tools(tools_dict)
|
||||
@@ -91,8 +111,18 @@ class Agent:
|
||||
if resp.message.content:
|
||||
yield resp.message.content
|
||||
return
|
||||
# check if self.llm class is GoogleLLM
|
||||
while self.llm.__class__.__name__ == "GoogleLLM" and resp.content.parts[0].function_call:
|
||||
from google.genai import types
|
||||
|
||||
while resp.finish_reason == "tool_calls":
|
||||
function_call_part = resp.candidates[0].content.parts[0]
|
||||
tool_response = self._execute_tool_action_google(tools_dict, function_call_part.function_call)
|
||||
function_response_part = types.Part.from_function_response(
|
||||
name=function_call_part.function_call.name,
|
||||
response=tool_response
|
||||
)
|
||||
|
||||
while self.llm.__class__.__name__ == "OpenAILLM" and resp.finish_reason == "tool_calls":
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||
filtered_data = {
|
||||
|
||||
Reference in New Issue
Block a user