fix: GoogleLLM, agent and handler according to the new genai SDK

This commit is contained in:
Siddhant Rai
2025-01-18 19:56:25 +05:30
parent ec270a3b54
commit 904b0bf2da
4 changed files with 116 additions and 158 deletions

View File

@@ -47,43 +47,41 @@ class OpenAILLMHandler(LLMHandler):
class GoogleLLMHandler(LLMHandler):
def handle_response(self, agent, resp, tools_dict, messages):
import google.generativeai as genai
from google.genai import types
while (
hasattr(resp.candidates[0].content.parts[0], "function_call")
and resp.candidates[0].content.parts[0].function_call
):
responses = {}
for part in resp.candidates[0].content.parts:
if hasattr(part, "function_call") and part.function_call:
function_call_part = part
messages.append(
genai.protos.Part(
function_call=genai.protos.FunctionCall(
name=function_call_part.function_call.name,
args=function_call_part.function_call.args,
)
)
)
tool_response, call_id = agent._execute_tool_action(
tools_dict, function_call_part.function_call
)
responses[function_call_part.function_call.name] = tool_response
response_parts = [
genai.protos.Part(
function_response=genai.protos.FunctionResponse(
name=tool_name, response={"result": response}
)
)
for tool_name, response in responses.items()
]
if response_parts:
messages.append(response_parts)
resp = agent.llm.gen(
while True:
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
if response.candidates and response.candidates[0].content.parts:
tool_call_found = False
for part in response.candidates[0].content.parts:
if part.function_call:
tool_call_found = True
tool_response, call_id = agent._execute_tool_action(
tools_dict, part.function_call
)
return resp.text
function_response_part = types.Part.from_function_response(
name=part.function_call.name,
response={"result": tool_response},
)
messages.append({"role": "model", "content": [part]})
messages.append(
{"role": "tool", "content": [function_response_part]}
)
if (
not tool_call_found
and response.candidates[0].content.parts
and response.candidates[0].content.parts[0].text
):
return response.candidates[0].content.parts[0].text
elif not tool_call_found:
return response.candidates[0].content.parts
else:
return response
def get_llm_handler(llm_type):