mirror of https://github.com/arc53/DocsGPT.git
synced 2025-11-30 17:13:15 +00:00
fix: GoogleLLM, agent and handler according to the new genai SDK
@@ -1,86 +1,61 @@
-import google.generativeai as genai
+from google import genai
+from google.genai import types
 
 from application.core.settings import settings
 from application.llm.base import BaseLLM
 
 
 class GoogleLLM(BaseLLM):
 
     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.api_key = settings.API_KEY
-        genai.configure(api_key=self.api_key)
+        self.client = genai.Client(api_key=self.api_key)
 
     def _clean_messages_google(self, messages):
         cleaned_messages = []
-        for message in messages[1:]:
-            cleaned_messages.append(
-                {
-                    "role": "model" if message["role"] == "system" else message["role"],
-                    "parts": [message["content"]],
-                }
-            )
+        for message in messages:
+            role = message.get("role")
+            content = message.get("content")
+
+            if role and content is not None:
+                if isinstance(content, str):
+                    parts = [types.Part.from_text(content)]
+                elif isinstance(content, list):
+                    parts = content
+                else:
+                    raise ValueError(f"Unexpected content type: {type(content)}")
+
+                cleaned_messages.append(types.Content(role=role, parts=parts))
+
         return cleaned_messages
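
The new `_clean_messages_google` maps OpenAI-style role/content dicts onto the new SDK's `types.Content` objects instead of the old `{"role", "parts"}` dicts. A minimal sketch of that conversion; the chat history here is made up for illustration, and note that recent google-genai releases make the `from_text` argument keyword-only:

    from google.genai import types

    # Hypothetical OpenAI-style history; in the new SDK the system prompt is
    # usually passed separately, so only user/model turns are converted here.
    history = [
        {"role": "user", "content": "What is DocsGPT?"},
        {"role": "model", "content": "An open-source documentation assistant."},
    ]

    contents = [
        types.Content(role=m["role"], parts=[types.Part.from_text(text=m["content"])])
        for m in history
    ]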
 
-    def _clean_tools_format(self, tools_data):
-        if isinstance(tools_data, list):
-            return [self._clean_tools_format(item) for item in tools_data]
-        elif isinstance(tools_data, dict):
-            if (
-                "function" in tools_data
-                and "type" in tools_data
-                and tools_data["type"] == "function"
-            ):
-                # Handle the case where tools are nested under 'function'
-                cleaned_function = self._clean_tools_format(tools_data["function"])
-                return {"function_declarations": [cleaned_function]}
-            elif (
-                "function" in tools_data
-                and "type_" in tools_data
-                and tools_data["type_"] == "function"
-            ):
-                # Handle the case where tools are nested under 'function' and type is already 'type_'
-                cleaned_function = self._clean_tools_format(tools_data["function"])
-                return {"function_declarations": [cleaned_function]}
-            else:
-                new_tools_data = {}
-                for key, value in tools_data.items():
-                    if key == "type":
-                        if value == "string":
-                            new_tools_data["type_"] = "STRING"
-                        elif value == "object":
-                            new_tools_data["type_"] = "OBJECT"
-                    elif key == "additionalProperties":
-                        continue
-                    elif key == "properties":
-                        if isinstance(value, dict):
-                            new_properties = {}
-                            for prop_name, prop_value in value.items():
-                                if (
-                                    isinstance(prop_value, dict)
-                                    and "type" in prop_value
-                                ):
-                                    if prop_value["type"] == "string":
-                                        new_properties[prop_name] = {
-                                            "type_": "STRING",
-                                            "description": prop_value.get(
-                                                "description"
-                                            ),
-                                        }
-                                    # Add more type mappings as needed
-                                else:
-                                    new_properties[prop_name] = (
-                                        self._clean_tools_format(prop_value)
-                                    )
-                            new_tools_data[key] = new_properties
-                        else:
-                            new_tools_data[key] = self._clean_tools_format(value)
-                    else:
-                        new_tools_data[key] = self._clean_tools_format(value)
-                return new_tools_data
-        else:
-            return tools_data
+    def _clean_tools_format(self, tools_list):
+        genai_tools = []
+        for tool_data in tools_list:
+            if tool_data["type"] == "function":
+                function = tool_data["function"]
+                genai_function = dict(
+                    name=function["name"],
+                    description=function["description"],
+                    parameters={
+                        "type": "OBJECT",
+                        "properties": {
+                            k: {
+                                **v,
+                                "type": v["type"].upper() if v["type"] else None,
+                            }
+                            for k, v in function["parameters"]["properties"].items()
+                        },
+                        "required": (
+                            function["parameters"]["required"]
+                            if "required" in function["parameters"]
+                            else []
+                        ),
+                    },
+                )
+                genai_tool = types.Tool(function_declarations=[genai_function])
+                genai_tools.append(genai_tool)
+        return genai_tools
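
The rewritten `_clean_tools_format` drops the recursive `type_`/`STRING` schema rewriting and instead maps each OpenAI-style tool entry directly to a `types.Tool` with a single function declaration. A rough sketch of the shapes involved, with a made-up `get_weather` tool:

    from google.genai import types

    # Hypothetical OpenAI-style tool definition
    openai_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string", "description": "City name"}},
                "required": ["city"],
            },
        },
    }

    # What the method builds from it: schema types uppercased, then wrapped
    genai_function = {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "OBJECT",
            "properties": {"city": {"type": "STRING", "description": "City name"}},
            "required": ["city"],
        },
    }
    genai_tool = types.Tool(function_declarations=[genai_function])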
 
     def _raw_gen(
         self,
@@ -90,61 +65,51 @@ class GoogleLLM(BaseLLM):
         stream=False,
         tools=None,
         formatting="openai",
-        **kwargs
+        **kwargs,
     ):
         config = {}
-        model_name = "gemini-2.0-flash-exp"
+        client = self.client
         if formatting == "openai":
             messages = self._clean_messages_google(messages)
+        config = types.GenerateContentConfig()
 
-        if formatting == "raw":
-            client = genai.GenerativeModel(model_name=model_name)
-            response = client.generate_content(contents=messages)
-            return response.text
+        if tools:
+            cleaned_tools = self._clean_tools_format(tools)
+            config.tools = cleaned_tools
+            response = client.models.generate_content(
+                model=model,
+                contents=messages,
+                config=config,
+            )
+            return response
         else:
-            if tools:
-                client = genai.GenerativeModel(
-                    model_name=model_name,
-                    generation_config=config,
-                    system_instruction=messages[0]["content"],
-                    tools=self._clean_tools_format(tools),
-                )
-                chat_session = gen_model.start_chat(
-                    history=self._clean_messages_google(messages)[:-1]
-                )
-                response = chat_session.send_message(
-                    self._clean_messages_google(messages)[-1]
-                )
-                return response
-            else:
-                gen_model = genai.GenerativeModel(
-                    model_name=model_name,
-                    generation_config=config,
-                    system_instruction=messages[0]["content"],
-                )
-                chat_session = gen_model.start_chat(
-                    history=self._clean_messages_google(messages)[:-1]
-                )
-                response = chat_session.send_message(
-                    self._clean_messages_google(messages)[-1]
-                )
-                return response.text
+            response = client.models.generate_content(
+                model=model, contents=messages, config=config
+            )
+            return response.text
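
With the client stored on the instance, a non-streaming call now goes through `client.models.generate_content`. A minimal standalone sketch of that call path, assuming the API key is supplied via the `GOOGLE_API_KEY` environment variable; model name and prompt are placeholders:

    from google import genai
    from google.genai import types

    client = genai.Client()  # reads GOOGLE_API_KEY from the environment
    config = types.GenerateContentConfig()

    response = client.models.generate_content(
        model="gemini-2.0-flash-exp",
        contents="Summarize what DocsGPT does in one sentence.",
        config=config,
    )
    print(response.text)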
 
     def _raw_gen_stream(
-        self, baseself, model, messages, stream=True, tools=None, **kwargs
+        self,
+        baseself,
+        model,
+        messages,
+        stream=True,
+        tools=None,
+        formatting="openai",
+        **kwargs,
     ):
         config = {}
-        model_name = "gemini-2.0-flash-exp"
+        client = self.client
+        if formatting == "openai":
+            cleaned_messages = self._clean_messages_google(messages)
+        config = types.GenerateContentConfig()
 
-        gen_model = genai.GenerativeModel(
-            model_name=model_name,
-            generation_config=config,
-            system_instruction=messages[0]["content"],
-            tools=self._clean_tools_format(tools),
-        )
-        chat_session = gen_model.start_chat(
-            history=self._clean_messages_google(messages)[:-1],
-        )
-        response = chat_session.send_message(
-            self._clean_messages_google(messages)[-1], stream=stream
-        )
+        if tools:
+            cleaned_tools = self._clean_tools_format(tools)
+            config.tools = cleaned_tools
+
+        response = client.models.generate_content_stream(
+            model=model,
+            contents=cleaned_messages,
+            config=config,
+        )
         for chunk in response:
             if chunk.text is not None:
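
The streaming path mirrors the non-streaming one, except `client.models.generate_content_stream` returns an iterator of chunks whose `.text` is yielded as it arrives. A sketch under the same assumptions as above:

    from google import genai
    from google.genai import types

    client = genai.Client()  # assumes GOOGLE_API_KEY is set
    stream = client.models.generate_content_stream(
        model="gemini-2.0-flash-exp",
        contents="Explain retrieval-augmented generation briefly.",
        config=types.GenerateContentConfig(),
    )
    for chunk in stream:
        if chunk.text is not None:
            print(chunk.text, end="", flush=True)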
@@ -95,7 +95,6 @@ class Agent:
 
         resp = self.llm_handler.handle_response(self, resp, tools_dict, messages)
 
-        # If no tool calls are needed, generate the final response
         if isinstance(resp, str):
             yield resp
         elif hasattr(resp, "message") and hasattr(resp.message, "content"):
@@ -110,7 +109,6 @@ class Agent:
             return
 
     def gen(self, messages):
-        # Generate initial response from the LLM
         if self.llm.supports_tools():
            resp = self._simple_tool_agent(messages)
            for line in resp:
@@ -47,43 +47,41 @@ class OpenAILLMHandler(LLMHandler):
 
 class GoogleLLMHandler(LLMHandler):
     def handle_response(self, agent, resp, tools_dict, messages):
-        import google.generativeai as genai
+        from google.genai import types
 
-        while (
-            hasattr(resp.candidates[0].content.parts[0], "function_call")
-            and resp.candidates[0].content.parts[0].function_call
-        ):
-            responses = {}
-            for part in resp.candidates[0].content.parts:
-                if hasattr(part, "function_call") and part.function_call:
-                    function_call_part = part
-                    messages.append(
-                        genai.protos.Part(
-                            function_call=genai.protos.FunctionCall(
-                                name=function_call_part.function_call.name,
-                                args=function_call_part.function_call.args,
-                            )
-                        )
-                    )
-                    tool_response, call_id = agent._execute_tool_action(
-                        tools_dict, function_call_part.function_call
-                    )
-                    responses[function_call_part.function_call.name] = tool_response
-            response_parts = [
-                genai.protos.Part(
-                    function_response=genai.protos.FunctionResponse(
-                        name=tool_name, response={"result": response}
-                    )
-                )
-                for tool_name, response in responses.items()
-            ]
-            if response_parts:
-                messages.append(response_parts)
-            resp = agent.llm.gen(
+        while True:
+            response = agent.llm.gen(
                 model=agent.gpt_model, messages=messages, tools=agent.tools
             )
-
-        return resp.text
+            if response.candidates and response.candidates[0].content.parts:
+                tool_call_found = False
+                for part in response.candidates[0].content.parts:
+                    if part.function_call:
+                        tool_call_found = True
+                        tool_response, call_id = agent._execute_tool_action(
+                            tools_dict, part.function_call
+                        )
+                        function_response_part = types.Part.from_function_response(
+                            name=part.function_call.name,
+                            response={"result": tool_response},
+                        )
+                        messages.append({"role": "model", "content": [part]})
+                        messages.append(
+                            {"role": "tool", "content": [function_response_part]}
+                        )
+
+                if (
+                    not tool_call_found
+                    and response.candidates[0].content.parts
+                    and response.candidates[0].content.parts[0].text
+                ):
+                    return response.candidates[0].content.parts[0].text
+                elif not tool_call_found:
+                    return response.candidates[0].content.parts
+
+            else:
+                return response
 
 
 def get_llm_handler(llm_type):
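
The handler's `while True` loop is the standard Gemini function-calling round trip: generate, execute any requested tool, append the model's call and a `function_response` part to the history, and regenerate until the model answers in plain text. A condensed sketch of the same round trip, using `types.Content` directly rather than the handler's role/content dicts; `run_tool_loop` and `execute_tool` are hypothetical helpers, not part of the commit:

    from google.genai import types

    def run_tool_loop(client, model, contents, config, execute_tool):
        while True:
            response = client.models.generate_content(
                model=model, contents=contents, config=config
            )
            if not response.candidates:
                return response
            parts = response.candidates[0].content.parts
            calls = [p.function_call for p in parts if p.function_call]
            if not calls:
                return response.text
            # Record the model turn that requested the tool call(s)
            contents.append(response.candidates[0].content)
            for call in calls:
                result = execute_tool(call.name, call.args)  # hypothetical dispatcher
                contents.append(
                    types.Content(
                        role="tool",
                        parts=[
                            types.Part.from_function_response(
                                name=call.name, response={"result": result}
                            )
                        ],
                    )
                )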
@@ -1,7 +1,5 @@
 import json
 
-from google.protobuf.json_format import MessageToDict
-
 
 class ToolActionParser:
     def __init__(self, llm_type):
@@ -22,8 +20,7 @@ class ToolActionParser:
         return tool_id, action_name, call_args
 
     def _parse_google_llm(self, call):
-        call = MessageToDict(call._pb)
-        call_args = call["args"]
-        tool_id = call["name"].split("_")[-1]
-        action_name = call["name"].rsplit("_", 1)[0]
+        call_args = call.args
+        tool_id = call.name.split("_")[-1]
+        action_name = call.name.rsplit("_", 1)[0]
         return tool_id, action_name, call_args
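
`_parse_google_llm` now reads `args` and `name` straight off the SDK's function-call object instead of round-tripping through protobuf's `MessageToDict`. The naming convention packs the tool id into the last underscore-separated segment, e.g.:

    # Hypothetical call name following the "<action_name>_<tool_id>" convention
    name = "get_weather_a1b2"
    tool_id = name.split("_")[-1]         # "a1b2"
    action_name = name.rsplit("_", 1)[0]  # "get_weather"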