mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
add google
This commit is contained in:
@@ -72,29 +72,38 @@ class GoogleLLM(BaseLLM):
|
|||||||
messages,
|
messages,
|
||||||
stream=False,
|
stream=False,
|
||||||
tools=None,
|
tools=None,
|
||||||
|
formatting="openai",
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
import google.generativeai as genai
|
from google import genai
|
||||||
genai.configure(api_key=self.api_key)
|
from google.genai import types
|
||||||
|
client = genai.Client(api_key=self.api_key)
|
||||||
|
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
}
|
}
|
||||||
model = 'gemini-2.0-flash-exp'
|
model = 'gemini-2.0-flash-exp'
|
||||||
|
if formatting=="raw":
|
||||||
model = genai.GenerativeModel(
|
response = client.models.generate_content(
|
||||||
model_name=model,
|
model=model,
|
||||||
generation_config=config,
|
contents=messages
|
||||||
system_instruction=messages[0]["content"],
|
|
||||||
tools=self._clean_tools_format(tools)
|
|
||||||
)
|
)
|
||||||
chat_session = model.start_chat(
|
|
||||||
history=self._clean_messages_google(messages)[:-1]
|
else:
|
||||||
)
|
model = genai.GenerativeModel(
|
||||||
response = chat_session.send_message(
|
model_name=model,
|
||||||
self._clean_messages_google(messages)[-1]
|
generation_config=config,
|
||||||
)
|
system_instruction=messages[0]["content"],
|
||||||
logging.info(response)
|
tools=self._clean_tools_format(tools)
|
||||||
return response.text
|
)
|
||||||
|
chat_session = model.start_chat(
|
||||||
|
history=self._clean_messages_google(messages)[:-1]
|
||||||
|
)
|
||||||
|
response = chat_session.send_message(
|
||||||
|
self._clean_messages_google(messages)[-1]
|
||||||
|
)
|
||||||
|
logging.info(response)
|
||||||
|
return response.text
|
||||||
|
|
||||||
def _raw_gen_stream(
|
def _raw_gen_stream(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
@@ -79,6 +80,25 @@ class Agent:
|
|||||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||||
return tool.execute_action(action_name, **call_args), call_id
|
return tool.execute_action(action_name, **call_args), call_id
|
||||||
|
|
||||||
|
def _execute_tool_action_google(self, tools_dict, call):
|
||||||
|
call_args = json.loads(call.args)
|
||||||
|
tool_id = call.name.split("_")[-1]
|
||||||
|
action_name = call.name.rsplit("_", 1)[0]
|
||||||
|
|
||||||
|
tool_data = tools_dict[tool_id]
|
||||||
|
action_data = next(
|
||||||
|
action for action in tool_data["actions"] if action["name"] == action_name
|
||||||
|
)
|
||||||
|
|
||||||
|
for param, details in action_data["parameters"]["properties"].items():
|
||||||
|
if param not in call_args and "value" in details:
|
||||||
|
call_args[param] = details["value"]
|
||||||
|
|
||||||
|
tm = ToolManager(config={})
|
||||||
|
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
|
||||||
|
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||||
|
return tool.execute_action(action_name, **call_args)
|
||||||
|
|
||||||
def _simple_tool_agent(self, messages):
|
def _simple_tool_agent(self, messages):
|
||||||
tools_dict = self._get_user_tools()
|
tools_dict = self._get_user_tools()
|
||||||
self._prepare_tools(tools_dict)
|
self._prepare_tools(tools_dict)
|
||||||
@@ -91,8 +111,18 @@ class Agent:
|
|||||||
if resp.message.content:
|
if resp.message.content:
|
||||||
yield resp.message.content
|
yield resp.message.content
|
||||||
return
|
return
|
||||||
|
# check if self.llm class is GoogleLLM
|
||||||
|
while self.llm.__class__.__name__ == "GoogleLLM" and resp.content.parts[0].function_call:
|
||||||
|
from google.genai import types
|
||||||
|
|
||||||
while resp.finish_reason == "tool_calls":
|
function_call_part = resp.candidates[0].content.parts[0]
|
||||||
|
tool_response = self._execute_tool_action_google(tools_dict, function_call_part.function_call)
|
||||||
|
function_response_part = types.Part.from_function_response(
|
||||||
|
name=function_call_part.function_call.name,
|
||||||
|
response=tool_response
|
||||||
|
)
|
||||||
|
|
||||||
|
while self.llm.__class__.__name__ == "OpenAILLM" and resp.finish_reason == "tool_calls":
|
||||||
message = json.loads(resp.model_dump_json())["message"]
|
message = json.loads(resp.model_dump_json())["message"]
|
||||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||||
filtered_data = {
|
filtered_data = {
|
||||||
|
|||||||
Reference in New Issue
Block a user