mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
95 lines
3.4 KiB
Python
95 lines
3.4 KiB
Python
import json
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
class LLMHandler(ABC):
    """Abstract interface for provider-specific handling of an LLM response.

    Concrete handlers implement :meth:`handle_response`; this base class
    cannot be instantiated directly.
    """

    @abstractmethod
    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Process ``resp`` on behalf of ``agent``.

        The base implementation does nothing and returns ``None``.

        Args:
            agent: The caller driving the conversation.
            resp: The response object produced by the LLM.
            tools_dict: Tool registry made available to the handler.
            messages: The conversation messages.
            **kwargs: Extra provider-specific options.
        """
|
|
|
|
|
|
class OpenAILLMHandler(LLMHandler):
    """Handler that resolves OpenAI-style tool calls for an agent."""

    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Execute requested tools and re-query the model until it stops asking.

        While ``resp.finish_reason`` is ``"tool_calls"``, the assistant message
        is appended to ``messages``, every tool call is executed through
        ``agent._execute_tool_action`` and its result (or an error text) is
        appended as a ``"tool"`` message, and the model is queried again via
        ``agent.llm.gen``.

        Args:
            agent: Object providing ``_execute_tool_action``, ``llm.gen``,
                ``gpt_model`` and ``tools``.
            resp: Response exposing ``finish_reason``, ``message.tool_calls``
                and ``model_dump_json()``.
            tools_dict: Passed through unchanged to
                ``agent._execute_tool_action``.
            messages: Conversation list; mutated in place.
            **kwargs: Accepted for compatibility with
                ``LLMHandler.handle_response``; currently unused.

        Returns:
            The first response whose ``finish_reason`` is not ``"tool_calls"``.
        """
        # Fields dropped from the serialized assistant message before it is
        # replayed to the model.
        keys_to_remove = {"audio", "function_call", "refusal"}

        while resp.finish_reason == "tool_calls":
            message = json.loads(resp.model_dump_json())["message"]
            filtered_data = {
                k: v for k, v in message.items() if k not in keys_to_remove
            }
            messages.append(filtered_data)

            for call in resp.message.tool_calls:
                # Seed the id from the call itself so the error branch below
                # never reads an unbound name or the id of a previous call.
                # NOTE(review): assumes the tool-call object exposes `id`, as
                # OpenAI tool calls do; falls back to None otherwise.
                call_id = getattr(call, "id", None)
                try:
                    tool_response, call_id = agent._execute_tool_action(
                        tools_dict, call
                    )
                    messages.append(
                        {
                            "role": "tool",
                            "content": str(tool_response),
                            "tool_call_id": call_id,
                        }
                    )
                except Exception as e:
                    # Deliberate best-effort: report the failure to the model
                    # as the tool's output instead of aborting the loop.
                    messages.append(
                        {
                            "role": "tool",
                            "content": f"Error executing tool: {str(e)}",
                            "tool_call_id": call_id,
                        }
                    )
            resp = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )
        return resp
|
|
|
|
|
|
class GoogleLLMHandler(LLMHandler):
    """Handler that resolves Google Generative AI function calls for an agent."""

    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
        """Execute requested functions and re-query the model until it stops.

        While the first part of the first candidate carries a function call,
        every function-call part is echoed into ``messages``, executed through
        ``agent._execute_tool_action``, and the collected results are appended
        as one list of ``FunctionResponse`` parts before the model is queried
        again via ``agent.llm.gen``.

        Args:
            agent: Object providing ``_execute_tool_action``, ``llm.gen``,
                ``gpt_model`` and ``tools``.
            resp: Response exposing ``candidates[0].content.parts`` and
                ``text``.
            tools_dict: Passed through unchanged to
                ``agent._execute_tool_action``.
            messages: Conversation list; mutated in place.
            **kwargs: Accepted for compatibility with
                ``LLMHandler.handle_response``; currently unused.

        Returns:
            ``resp.text`` of the first response that does not start with a
            function call.
        """
        import google.generativeai as genai

        while getattr(resp.candidates[0].content.parts[0], "function_call", None):
            # Ordered (name, result) pairs. A dict keyed by function name
            # would drop results when the same function is called more than
            # once in a single turn.
            responses = []
            for part in resp.candidates[0].content.parts:
                function_call = getattr(part, "function_call", None)
                if not function_call:
                    continue
                messages.append(
                    genai.protos.Part(
                        function_call=genai.protos.FunctionCall(
                            name=function_call.name,
                            args=function_call.args,
                        )
                    )
                )
                # The second element returned by the agent (a call id) is not
                # used when building the function response below.
                tool_response, _ = agent._execute_tool_action(
                    tools_dict, function_call
                )
                responses.append((function_call.name, tool_response))

            response_parts = [
                genai.protos.Part(
                    function_response=genai.protos.FunctionResponse(
                        name=tool_name, response={"result": response}
                    )
                )
                for tool_name, response in responses
            ]
            if response_parts:
                messages.append(response_parts)
            resp = agent.llm.gen(
                model=agent.gpt_model, messages=messages, tools=agent.tools
            )

        return resp.text
|
|
|
|
|
|
def get_llm_handler(llm_type):
    """Return a new response handler for the given LLM provider name.

    Args:
        llm_type: Provider key; ``"openai"`` and ``"google"`` are recognized.

    Returns:
        A fresh handler instance. Any unrecognized ``llm_type`` yields an
        ``OpenAILLMHandler``.
    """
    # Map to classes, not instances, so only the selected handler is built.
    handler_classes = {
        "openai": OpenAILLMHandler,
        "google": GoogleLLMHandler,
    }
    return handler_classes.get(llm_type, OpenAILLMHandler)()
|