Files
DocsGPT/application/tools/llm_handler.py
2025-01-17 09:22:41 +05:30

95 lines
3.4 KiB
Python

import json
from abc import ABC, abstractmethod
class LLMHandler(ABC):
    """Abstract interface for provider-specific handling of an LLM response.

    Concrete handlers implement :meth:`handle_response`; the class cannot be
    instantiated until that method is overridden.
    """

    @abstractmethod
    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Process ``resp`` on behalf of ``agent``.

        Args:
            agent: The calling agent.
            resp: The provider response to process.
            tools_dict: Tool lookup data passed on to the agent when a tool
                is executed.
            messages: The conversation history list.
            **kwargs: Extra, implementation-specific options.

        Returns:
            ``None`` in this base implementation; subclasses define their own
            return value.
        """
        return None
class OpenAILLMHandler(LLMHandler):
    """Run the tool-call loop for OpenAI-style chat completion responses."""

    # Fields stripped from the assistant message before it is appended to the
    # conversation history.
    _KEYS_TO_REMOVE = frozenset({"audio", "function_call", "refusal"})

    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Execute requested tools until the model stops asking for them.

        While ``resp.finish_reason`` is ``"tool_calls"``, the assistant
        message is appended to ``messages``, every tool call is executed via
        ``agent._execute_tool_action`` and answered with a ``"tool"`` message,
        and a new response is requested from ``agent.llm.gen``.

        Args:
            agent: Object providing ``_execute_tool_action``, ``llm``,
                ``gpt_model`` and ``tools``.
            resp: Response object exposing ``finish_reason``,
                ``model_dump_json()`` and ``message.tool_calls``.
            tools_dict: Tool lookup data forwarded to
                ``agent._execute_tool_action``.
            messages: Conversation history; mutated in place.
            **kwargs: Accepted for signature compatibility with
                ``LLMHandler.handle_response``; not used here.

        Returns:
            The first response whose ``finish_reason`` is not
            ``"tool_calls"``.
        """
        while resp.finish_reason == "tool_calls":
            message = json.loads(resp.model_dump_json())["message"]
            messages.append(
                {
                    key: value
                    for key, value in message.items()
                    if key not in self._KEYS_TO_REMOVE
                }
            )

            for call in resp.message.tool_calls:
                # Seed the id from the call itself so the error branch never
                # reads an unbound name or the id of a previous call.
                call_id = getattr(call, "id", None)
                try:
                    tool_response, call_id = agent._execute_tool_action(
                        tools_dict, call
                    )
                    content = str(tool_response)
                except Exception as e:
                    # Deliberately broad: a failing tool is reported back to
                    # the model instead of aborting the conversation.
                    content = f"Error executing tool: {str(e)}"
                messages.append(
                    {
                        "role": "tool",
                        "content": content,
                        "tool_call_id": call_id,
                    }
                )

            resp = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
        return resp
class GoogleLLMHandler(LLMHandler):
    """Run the tool-call loop for ``google.generativeai`` responses."""

    @staticmethod
    def _function_call_parts(resp):
        """Return the first candidate's parts that carry a function call.

        Returns an empty list when the response has no candidates or no part
        with a truthy ``function_call``.
        """
        candidates = resp.candidates
        if not candidates:
            return []
        return [
            part
            for part in candidates[0].content.parts
            if getattr(part, "function_call", None)
        ]

    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Execute requested function calls until the model stops asking.

        For every function-call part of the first candidate, a matching
        ``FunctionCall`` part is appended to ``messages`` and the tool is run
        via ``agent._execute_tool_action``. The collected results are then
        appended as one list of ``FunctionResponse`` parts and a new response
        is requested from ``agent.llm.gen``.

        Args:
            agent: Object providing ``_execute_tool_action``, ``llm``,
                ``gpt_model`` and ``tools``.
            resp: Response object exposing ``candidates`` and ``text``.
            tools_dict: Tool lookup data forwarded to
                ``agent._execute_tool_action``.
            messages: Conversation history; mutated in place.
            **kwargs: Accepted for signature compatibility with
                ``LLMHandler.handle_response``; not used here.

        Returns:
            ``resp.text`` of the first response without function-call parts.
        """
        # Kept function-local as in the original; presumably so the Google
        # SDK is only required when this handler is actually used.
        import google.generativeai as genai

        function_call_parts = self._function_call_parts(resp)
        while function_call_parts:
            # A list of (name, result) pairs rather than a dict, so repeated
            # calls to the same function each keep their own response.
            results = []
            for part in function_call_parts:
                function_call = part.function_call
                messages.append(
                    genai.protos.Part(
                        function_call=genai.protos.FunctionCall(
                            name=function_call.name,
                            args=function_call.args,
                        )
                    )
                )
                try:
                    tool_response, _ = agent._execute_tool_action(
                        tools_dict, function_call
                    )
                except Exception as e:
                    # Deliberately broad: a failing tool is reported back to
                    # the model instead of aborting the conversation.
                    tool_response = f"Error executing tool: {str(e)}"
                results.append((function_call.name, tool_response))

            messages.append(
                [
                    genai.protos.Part(
                        function_response=genai.protos.FunctionResponse(
                            name=tool_name, response={"result": tool_result}
                        )
                    )
                    for tool_name, tool_result in results
                ]
            )

            resp = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
            function_call_parts = self._function_call_parts(resp)

        return resp.text
def get_llm_handler(llm_type):
    """Return a new response handler for the given LLM provider key.

    Args:
        llm_type: Provider key; ``"openai"`` and ``"google"`` are recognised.

    Returns:
        A fresh ``GoogleLLMHandler`` for ``"google"``; otherwise a fresh
        ``OpenAILLMHandler``, which is also the fallback for unknown keys.
    """
    handler_classes = {
        "openai": OpenAILLMHandler,
        "google": GoogleLLMHandler,
    }
    # Resolve the class first so only the selected handler is instantiated.
    return handler_classes.get(llm_type, OpenAILLMHandler)()