mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
fix: google parser, llm handler and other errors
This commit is contained in:
@@ -1,60 +1,77 @@
|
|||||||
from application.llm.base import BaseLLM
|
import google.generativeai as genai
|
||||||
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
import logging
|
from application.llm.base import BaseLLM
|
||||||
|
|
||||||
|
|
||||||
class GoogleLLM(BaseLLM):
|
class GoogleLLM(BaseLLM):
|
||||||
|
|
||||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.api_key = settings.API_KEY
|
self.api_key = settings.API_KEY
|
||||||
self.user_api_key = user_api_key
|
genai.configure(api_key=self.api_key)
|
||||||
|
|
||||||
def _clean_messages_google(self, messages):
|
def _clean_messages_google(self, messages):
|
||||||
return [
|
cleaned_messages = []
|
||||||
{
|
for message in messages[1:]:
|
||||||
"role": "model" if message["role"] == "system" else message["role"],
|
cleaned_messages.append(
|
||||||
"parts": [message["content"]],
|
{
|
||||||
}
|
"role": "model" if message["role"] == "system" else message["role"],
|
||||||
for message in messages[1:]
|
"parts": [message["content"]],
|
||||||
]
|
}
|
||||||
|
)
|
||||||
|
return cleaned_messages
|
||||||
|
|
||||||
def _clean_tools_format(self, tools_data):
|
def _clean_tools_format(self, tools_data):
|
||||||
"""
|
|
||||||
Cleans the tools data format, converting string type representations
|
|
||||||
to the expected dictionary structure for google-generativeai.
|
|
||||||
"""
|
|
||||||
if isinstance(tools_data, list):
|
if isinstance(tools_data, list):
|
||||||
return [self._clean_tools_format(item) for item in tools_data]
|
return [self._clean_tools_format(item) for item in tools_data]
|
||||||
elif isinstance(tools_data, dict):
|
elif isinstance(tools_data, dict):
|
||||||
if 'function' in tools_data and 'type' in tools_data and tools_data['type'] == 'function':
|
if (
|
||||||
|
"function" in tools_data
|
||||||
|
and "type" in tools_data
|
||||||
|
and tools_data["type"] == "function"
|
||||||
|
):
|
||||||
# Handle the case where tools are nested under 'function'
|
# Handle the case where tools are nested under 'function'
|
||||||
cleaned_function = self._clean_tools_format(tools_data['function'])
|
cleaned_function = self._clean_tools_format(tools_data["function"])
|
||||||
return {'function_declarations': [cleaned_function]}
|
return {"function_declarations": [cleaned_function]}
|
||||||
elif 'function' in tools_data and 'type_' in tools_data and tools_data['type_'] == 'function':
|
elif (
|
||||||
|
"function" in tools_data
|
||||||
|
and "type_" in tools_data
|
||||||
|
and tools_data["type_"] == "function"
|
||||||
|
):
|
||||||
# Handle the case where tools are nested under 'function' and type is already 'type_'
|
# Handle the case where tools are nested under 'function' and type is already 'type_'
|
||||||
cleaned_function = self._clean_tools_format(tools_data['function'])
|
cleaned_function = self._clean_tools_format(tools_data["function"])
|
||||||
return {'function_declarations': [cleaned_function]}
|
return {"function_declarations": [cleaned_function]}
|
||||||
else:
|
else:
|
||||||
new_tools_data = {}
|
new_tools_data = {}
|
||||||
for key, value in tools_data.items():
|
for key, value in tools_data.items():
|
||||||
if key == 'type':
|
if key == "type":
|
||||||
if value == 'string':
|
if value == "string":
|
||||||
new_tools_data['type_'] = 'STRING' # Keep as string for now
|
new_tools_data["type_"] = "STRING"
|
||||||
elif value == 'object':
|
elif value == "object":
|
||||||
new_tools_data['type_'] = 'OBJECT' # Keep as string for now
|
new_tools_data["type_"] = "OBJECT"
|
||||||
elif key == 'additionalProperties':
|
elif key == "additionalProperties":
|
||||||
continue
|
continue
|
||||||
elif key == 'properties':
|
elif key == "properties":
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
new_properties = {}
|
new_properties = {}
|
||||||
for prop_name, prop_value in value.items():
|
for prop_name, prop_value in value.items():
|
||||||
if isinstance(prop_value, dict) and 'type' in prop_value:
|
if (
|
||||||
if prop_value['type'] == 'string':
|
isinstance(prop_value, dict)
|
||||||
new_properties[prop_name] = {'type_': 'STRING', 'description': prop_value.get('description')}
|
and "type" in prop_value
|
||||||
|
):
|
||||||
|
if prop_value["type"] == "string":
|
||||||
|
new_properties[prop_name] = {
|
||||||
|
"type_": "STRING",
|
||||||
|
"description": prop_value.get(
|
||||||
|
"description"
|
||||||
|
),
|
||||||
|
}
|
||||||
# Add more type mappings as needed
|
# Add more type mappings as needed
|
||||||
else:
|
else:
|
||||||
new_properties[prop_name] = self._clean_tools_format(prop_value)
|
new_properties[prop_name] = (
|
||||||
|
self._clean_tools_format(prop_value)
|
||||||
|
)
|
||||||
new_tools_data[key] = new_properties
|
new_tools_data[key] = new_properties
|
||||||
else:
|
else:
|
||||||
new_tools_data[key] = self._clean_tools_format(value)
|
new_tools_data[key] = self._clean_tools_format(value)
|
||||||
@@ -74,65 +91,64 @@ class GoogleLLM(BaseLLM):
|
|||||||
tools=None,
|
tools=None,
|
||||||
formatting="openai",
|
formatting="openai",
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
from google import genai
|
config = {}
|
||||||
from google.genai import types
|
model_name = "gemini-2.0-flash-exp"
|
||||||
client = genai.Client(api_key=self.api_key)
|
|
||||||
|
|
||||||
|
if formatting == "raw":
|
||||||
config = {
|
client = genai.GenerativeModel(model_name=model_name)
|
||||||
}
|
response = client.generate_content(contents=messages)
|
||||||
model = 'gemini-2.0-flash-exp'
|
|
||||||
if formatting=="raw":
|
|
||||||
response = client.models.generate_content(
|
|
||||||
model=model,
|
|
||||||
contents=messages
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
model = genai.GenerativeModel(
|
|
||||||
model_name=model,
|
|
||||||
generation_config=config,
|
|
||||||
system_instruction=messages[0]["content"],
|
|
||||||
tools=self._clean_tools_format(tools)
|
|
||||||
)
|
|
||||||
chat_session = model.start_chat(
|
|
||||||
history=self._clean_messages_google(messages)[:-1]
|
|
||||||
)
|
|
||||||
response = chat_session.send_message(
|
|
||||||
self._clean_messages_google(messages)[-1]
|
|
||||||
)
|
|
||||||
logging.info(response)
|
|
||||||
return response.text
|
return response.text
|
||||||
|
else:
|
||||||
|
if tools:
|
||||||
|
client = genai.GenerativeModel(
|
||||||
|
model_name=model_name,
|
||||||
|
generation_config=config,
|
||||||
|
system_instruction=messages[0]["content"],
|
||||||
|
tools=self._clean_tools_format(tools),
|
||||||
|
)
|
||||||
|
chat_session = gen_model.start_chat(
|
||||||
|
history=self._clean_messages_google(messages)[:-1]
|
||||||
|
)
|
||||||
|
response = chat_session.send_message(
|
||||||
|
self._clean_messages_google(messages)[-1]
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
else:
|
||||||
|
gen_model = genai.GenerativeModel(
|
||||||
|
model_name=model_name,
|
||||||
|
generation_config=config,
|
||||||
|
system_instruction=messages[0]["content"],
|
||||||
|
)
|
||||||
|
chat_session = gen_model.start_chat(
|
||||||
|
history=self._clean_messages_google(messages)[:-1]
|
||||||
|
)
|
||||||
|
response = chat_session.send_message(
|
||||||
|
self._clean_messages_google(messages)[-1]
|
||||||
|
)
|
||||||
|
return response.text
|
||||||
|
|
||||||
def _raw_gen_stream(
|
def _raw_gen_stream(
|
||||||
self,
|
self, baseself, model, messages, stream=True, tools=None, **kwargs
|
||||||
baseself,
|
):
|
||||||
model,
|
config = {}
|
||||||
messages,
|
model_name = "gemini-2.0-flash-exp"
|
||||||
stream=True,
|
|
||||||
tools=None,
|
gen_model = genai.GenerativeModel(
|
||||||
**kwargs
|
model_name=model_name,
|
||||||
):
|
|
||||||
import google.generativeai as genai
|
|
||||||
genai.configure(api_key=self.api_key)
|
|
||||||
config = {
|
|
||||||
}
|
|
||||||
model = genai.GenerativeModel(
|
|
||||||
model_name=model,
|
|
||||||
generation_config=config,
|
generation_config=config,
|
||||||
system_instruction=messages[0]["content"]
|
system_instruction=messages[0]["content"],
|
||||||
)
|
tools=self._clean_tools_format(tools),
|
||||||
chat_session = model.start_chat(
|
)
|
||||||
|
chat_session = gen_model.start_chat(
|
||||||
history=self._clean_messages_google(messages)[:-1],
|
history=self._clean_messages_google(messages)[:-1],
|
||||||
)
|
)
|
||||||
response = chat_session.send_message(
|
response = chat_session.send_message(
|
||||||
self._clean_messages_google(messages)[-1]
|
self._clean_messages_google(messages)[-1], stream=stream
|
||||||
, stream=stream
|
|
||||||
)
|
)
|
||||||
for line in response:
|
for chunk in response:
|
||||||
if line.text is not None:
|
if chunk.text is not None:
|
||||||
yield line.text
|
yield chunk.text
|
||||||
|
|
||||||
def _supports_tools(self):
|
def _supports_tools(self):
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ class Agent:
|
|||||||
if isinstance(resp, str):
|
if isinstance(resp, str):
|
||||||
yield resp
|
yield resp
|
||||||
return
|
return
|
||||||
if resp.message.content:
|
if hasattr(resp, "message") and hasattr(resp.message, "content"):
|
||||||
yield resp.message.content
|
yield resp.message.content
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -98,7 +98,7 @@ class Agent:
|
|||||||
# If no tool calls are needed, generate the final response
|
# If no tool calls are needed, generate the final response
|
||||||
if isinstance(resp, str):
|
if isinstance(resp, str):
|
||||||
yield resp
|
yield resp
|
||||||
elif resp.message.content:
|
elif hasattr(resp, "message") and hasattr(resp.message, "content"):
|
||||||
yield resp.message.content
|
yield resp.message.content
|
||||||
else:
|
else:
|
||||||
completion = self.llm.gen_stream(
|
completion = self.llm.gen_stream(
|
||||||
|
|||||||
@@ -47,23 +47,43 @@ class OpenAILLMHandler(LLMHandler):
|
|||||||
|
|
||||||
class GoogleLLMHandler(LLMHandler):
|
class GoogleLLMHandler(LLMHandler):
|
||||||
def handle_response(self, agent, resp, tools_dict, messages):
|
def handle_response(self, agent, resp, tools_dict, messages):
|
||||||
from google.genai import types
|
import google.generativeai as genai
|
||||||
|
|
||||||
while resp.content.parts[0].function_call:
|
while (
|
||||||
function_call_part = resp.candidates[0].content.parts[0]
|
hasattr(resp.candidates[0].content.parts[0], "function_call")
|
||||||
tool_response, call_id = agent._execute_tool_action(
|
and resp.candidates[0].content.parts[0].function_call
|
||||||
tools_dict, function_call_part.function_call
|
):
|
||||||
)
|
responses = {}
|
||||||
function_response_part = types.Part.from_function_response(
|
for part in resp.candidates[0].content.parts:
|
||||||
name=function_call_part.function_call.name, response=tool_response
|
if hasattr(part, "function_call") and part.function_call:
|
||||||
)
|
function_call_part = part
|
||||||
|
messages.append(
|
||||||
messages.append(function_call_part, function_response_part)
|
genai.protos.Part(
|
||||||
|
function_call=genai.protos.FunctionCall(
|
||||||
|
name=function_call_part.function_call.name,
|
||||||
|
args=function_call_part.function_call.args,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tool_response, call_id = agent._execute_tool_action(
|
||||||
|
tools_dict, function_call_part.function_call
|
||||||
|
)
|
||||||
|
responses[function_call_part.function_call.name] = tool_response
|
||||||
|
response_parts = [
|
||||||
|
genai.protos.Part(
|
||||||
|
function_response=genai.protos.FunctionResponse(
|
||||||
|
name=tool_name, response={"result": response}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for tool_name, response in responses.items()
|
||||||
|
]
|
||||||
|
if response_parts:
|
||||||
|
messages.append(response_parts)
|
||||||
resp = agent.llm.gen(
|
resp = agent.llm.gen(
|
||||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||||
)
|
)
|
||||||
|
|
||||||
return resp
|
return resp.text
|
||||||
|
|
||||||
|
|
||||||
def get_llm_handler(llm_type):
|
def get_llm_handler(llm_type):
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
|
from google.protobuf.json_format import MessageToDict
|
||||||
|
|
||||||
|
|
||||||
class ToolActionParser:
|
class ToolActionParser:
|
||||||
def __init__(self, llm_type):
|
def __init__(self, llm_type):
|
||||||
@@ -20,7 +22,8 @@ class ToolActionParser:
|
|||||||
return tool_id, action_name, call_args
|
return tool_id, action_name, call_args
|
||||||
|
|
||||||
def _parse_google_llm(self, call):
|
def _parse_google_llm(self, call):
|
||||||
call_args = json.loads(call.args)
|
call = MessageToDict(call._pb)
|
||||||
tool_id = call.name.split("_")[-1]
|
call_args = call["args"]
|
||||||
action_name = call.name.rsplit("_", 1)[0]
|
tool_id = call["name"].split("_")[-1]
|
||||||
|
action_name = call["name"].rsplit("_", 1)[0]
|
||||||
return tool_id, action_name, call_args
|
return tool_id, action_name, call_args
|
||||||
|
|||||||
Reference in New Issue
Block a user