fix: Google parser, LLM handler, and other errors

This commit is contained in:
Siddhant Rai
2025-01-17 09:22:41 +05:30
parent 811dfecf98
commit c97d1e3363
4 changed files with 141 additions and 102 deletions

View File

@@ -1,60 +1,77 @@
from application.llm.base import BaseLLM import google.generativeai as genai
from application.core.settings import settings from application.core.settings import settings
import logging from application.llm.base import BaseLLM
class GoogleLLM(BaseLLM):
    """LLM wrapper around Google's generative AI (Gemini) models."""

    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
        """Configure the google-generativeai client.

        NOTE(review): the key actually used is ``settings.API_KEY``; the
        ``api_key`` argument is accepted for interface compatibility but
        ignored — confirm this is intentional.
        """
        super().__init__(*args, **kwargs)
        self.api_key = settings.API_KEY
        # Fix: the refactor dropped this assignment even though the
        # parameter is still accepted; keep the per-user key available
        # to callers/handlers that read it.
        self.user_api_key = user_api_key
        genai.configure(api_key=self.api_key)
def _clean_messages_google(self, messages): def _clean_messages_google(self, messages):
return [ cleaned_messages = []
{ for message in messages[1:]:
"role": "model" if message["role"] == "system" else message["role"], cleaned_messages.append(
"parts": [message["content"]], {
} "role": "model" if message["role"] == "system" else message["role"],
for message in messages[1:] "parts": [message["content"]],
] }
)
return cleaned_messages
def _clean_tools_format(self, tools_data): def _clean_tools_format(self, tools_data):
"""
Cleans the tools data format, converting string type representations
to the expected dictionary structure for google-generativeai.
"""
if isinstance(tools_data, list): if isinstance(tools_data, list):
return [self._clean_tools_format(item) for item in tools_data] return [self._clean_tools_format(item) for item in tools_data]
elif isinstance(tools_data, dict): elif isinstance(tools_data, dict):
if 'function' in tools_data and 'type' in tools_data and tools_data['type'] == 'function': if (
"function" in tools_data
and "type" in tools_data
and tools_data["type"] == "function"
):
# Handle the case where tools are nested under 'function' # Handle the case where tools are nested under 'function'
cleaned_function = self._clean_tools_format(tools_data['function']) cleaned_function = self._clean_tools_format(tools_data["function"])
return {'function_declarations': [cleaned_function]} return {"function_declarations": [cleaned_function]}
elif 'function' in tools_data and 'type_' in tools_data and tools_data['type_'] == 'function': elif (
"function" in tools_data
and "type_" in tools_data
and tools_data["type_"] == "function"
):
# Handle the case where tools are nested under 'function' and type is already 'type_' # Handle the case where tools are nested under 'function' and type is already 'type_'
cleaned_function = self._clean_tools_format(tools_data['function']) cleaned_function = self._clean_tools_format(tools_data["function"])
return {'function_declarations': [cleaned_function]} return {"function_declarations": [cleaned_function]}
else: else:
new_tools_data = {} new_tools_data = {}
for key, value in tools_data.items(): for key, value in tools_data.items():
if key == 'type': if key == "type":
if value == 'string': if value == "string":
new_tools_data['type_'] = 'STRING' # Keep as string for now new_tools_data["type_"] = "STRING"
elif value == 'object': elif value == "object":
new_tools_data['type_'] = 'OBJECT' # Keep as string for now new_tools_data["type_"] = "OBJECT"
elif key == 'additionalProperties': elif key == "additionalProperties":
continue continue
elif key == 'properties': elif key == "properties":
if isinstance(value, dict): if isinstance(value, dict):
new_properties = {} new_properties = {}
for prop_name, prop_value in value.items(): for prop_name, prop_value in value.items():
if isinstance(prop_value, dict) and 'type' in prop_value: if (
if prop_value['type'] == 'string': isinstance(prop_value, dict)
new_properties[prop_name] = {'type_': 'STRING', 'description': prop_value.get('description')} and "type" in prop_value
):
if prop_value["type"] == "string":
new_properties[prop_name] = {
"type_": "STRING",
"description": prop_value.get(
"description"
),
}
# Add more type mappings as needed # Add more type mappings as needed
else: else:
new_properties[prop_name] = self._clean_tools_format(prop_value) new_properties[prop_name] = (
self._clean_tools_format(prop_value)
)
new_tools_data[key] = new_properties new_tools_data[key] = new_properties
else: else:
new_tools_data[key] = self._clean_tools_format(value) new_tools_data[key] = self._clean_tools_format(value)
@@ -74,65 +91,64 @@ class GoogleLLM(BaseLLM):
tools=None, tools=None,
formatting="openai", formatting="openai",
**kwargs **kwargs
): ):
from google import genai config = {}
from google.genai import types model_name = "gemini-2.0-flash-exp"
client = genai.Client(api_key=self.api_key)
if formatting == "raw":
config = { client = genai.GenerativeModel(model_name=model_name)
} response = client.generate_content(contents=messages)
model = 'gemini-2.0-flash-exp'
if formatting=="raw":
response = client.models.generate_content(
model=model,
contents=messages
)
else:
model = genai.GenerativeModel(
model_name=model,
generation_config=config,
system_instruction=messages[0]["content"],
tools=self._clean_tools_format(tools)
)
chat_session = model.start_chat(
history=self._clean_messages_google(messages)[:-1]
)
response = chat_session.send_message(
self._clean_messages_google(messages)[-1]
)
logging.info(response)
return response.text return response.text
else:
if tools:
client = genai.GenerativeModel(
model_name=model_name,
generation_config=config,
system_instruction=messages[0]["content"],
tools=self._clean_tools_format(tools),
)
chat_session = gen_model.start_chat(
history=self._clean_messages_google(messages)[:-1]
)
response = chat_session.send_message(
self._clean_messages_google(messages)[-1]
)
return response
else:
gen_model = genai.GenerativeModel(
model_name=model_name,
generation_config=config,
system_instruction=messages[0]["content"],
)
chat_session = gen_model.start_chat(
history=self._clean_messages_google(messages)[:-1]
)
response = chat_session.send_message(
self._clean_messages_google(messages)[-1]
)
return response.text
def _raw_gen_stream(
    self, baseself, model, messages, stream=True, tools=None, **kwargs
):
    """Stream a Gemini chat response, yielding text chunks as they arrive.

    NOTE(review): the ``model`` argument is ignored — the model name is
    hard-coded below; confirm this is intentional.
    """
    generation_config = {}
    gemini_model = "gemini-2.0-flash-exp"

    # Build the chat session seeded with everything but the last message.
    session = genai.GenerativeModel(
        model_name=gemini_model,
        generation_config=generation_config,
        system_instruction=messages[0]["content"],
        tools=self._clean_tools_format(tools),
    ).start_chat(history=self._clean_messages_google(messages)[:-1])

    reply = session.send_message(
        self._clean_messages_google(messages)[-1], stream=stream
    )
    for piece in reply:
        if piece.text is not None:
            yield piece.text
def _supports_tools(self): def _supports_tools(self):
return True return True

View File

@@ -89,7 +89,7 @@ class Agent:
if isinstance(resp, str): if isinstance(resp, str):
yield resp yield resp
return return
if resp.message.content: if hasattr(resp, "message") and hasattr(resp.message, "content"):
yield resp.message.content yield resp.message.content
return return
@@ -98,7 +98,7 @@ class Agent:
# If no tool calls are needed, generate the final response # If no tool calls are needed, generate the final response
if isinstance(resp, str): if isinstance(resp, str):
yield resp yield resp
elif resp.message.content: elif hasattr(resp, "message") and hasattr(resp.message, "content"):
yield resp.message.content yield resp.message.content
else: else:
completion = self.llm.gen_stream( completion = self.llm.gen_stream(

View File

@@ -47,23 +47,43 @@ class OpenAILLMHandler(LLMHandler):
class GoogleLLMHandler(LLMHandler):
    """Drives the tool-call loop for Google Gemini responses."""

    def handle_response(self, agent, resp, tools_dict, messages):
        """Execute requested tool calls until the model stops asking.

        Each round: echo the model's function calls into the history,
        run the tools, append their results as FunctionResponse parts,
        then regenerate. Returns the final response text.
        """
        import google.generativeai as genai

        while (
            hasattr(resp.candidates[0].content.parts[0], "function_call")
            and resp.candidates[0].content.parts[0].function_call
        ):
            tool_outputs = {}
            for part in resp.candidates[0].content.parts:
                if not (hasattr(part, "function_call") and part.function_call):
                    continue
                fn_call = part.function_call
                # Record the model's call in the history before running it.
                messages.append(
                    genai.protos.Part(
                        function_call=genai.protos.FunctionCall(
                            name=fn_call.name,
                            args=fn_call.args,
                        )
                    )
                )
                result, _call_id = agent._execute_tool_action(
                    tools_dict, fn_call
                )
                tool_outputs[fn_call.name] = result

            reply_parts = [
                genai.protos.Part(
                    function_response=genai.protos.FunctionResponse(
                        name=name, response={"result": output}
                    )
                )
                for name, output in tool_outputs.items()
            ]
            if reply_parts:
                messages.append(reply_parts)

            resp = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
        return resp.text
def get_llm_handler(llm_type): def get_llm_handler(llm_type):

View File

@@ -1,5 +1,7 @@
import json import json
from google.protobuf.json_format import MessageToDict
class ToolActionParser: class ToolActionParser:
def __init__(self, llm_type): def __init__(self, llm_type):
@@ -20,7 +22,8 @@ class ToolActionParser:
return tool_id, action_name, call_args return tool_id, action_name, call_args
def _parse_google_llm(self, call):
    """Split a Gemini FunctionCall into (tool_id, action_name, call_args).

    The function name encodes "<action>_<tool_id>"; the args come from
    the protobuf payload converted to a plain dict.
    """
    as_dict = MessageToDict(call._pb)
    full_name = as_dict["name"]
    tool_id = full_name.split("_")[-1]
    action_name = full_name.rsplit("_", 1)[0]
    return tool_id, action_name, as_dict["args"]