mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge pull request #1581 from siiddhantt/refactor/parser-and-handler-in-tools
refactor: tool agent for action parser and handlers
This commit is contained in:
@@ -33,7 +33,7 @@ def gen_cache_key(messages, model="docgpt", tools=None):
|
||||
if not all(isinstance(msg, dict) for msg in messages):
|
||||
raise ValueError("All messages must be dictionaries.")
|
||||
messages_str = json.dumps(messages)
|
||||
tools_str = json.dumps(tools) if tools else ""
|
||||
tools_str = json.dumps(str(tools)) if tools else ""
|
||||
combined = f"{model}_{messages_str}_{tools_str}"
|
||||
cache_key = get_hash(combined)
|
||||
return cache_key
|
||||
@@ -68,8 +68,8 @@ def gen_cache(func):
|
||||
|
||||
|
||||
def stream_cache(func):
|
||||
def wrapper(self, model, messages, stream, *args, **kwargs):
|
||||
cache_key = gen_cache_key(messages)
|
||||
def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
|
||||
cache_key = gen_cache_key(messages, model, tools)
|
||||
logger.info(f"Stream cache key: {cache_key}")
|
||||
|
||||
redis_client = get_redis_instance()
|
||||
@@ -86,7 +86,7 @@ def stream_cache(func):
|
||||
except redis.ConnectionError as e:
|
||||
logger.error(f"Redis connection error: {e}")
|
||||
|
||||
result = func(self, model, messages, stream, *args, **kwargs)
|
||||
result = func(self, model, messages, stream, tools=tools, *args, **kwargs)
|
||||
stream_cache_data = []
|
||||
|
||||
for chunk in result:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from application.cache import gen_cache, stream_cache
|
||||
from application.usage import gen_token_usage, stream_token_usage
|
||||
from application.cache import stream_cache, gen_cache
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
@@ -18,18 +19,38 @@ class BaseLLM(ABC):
|
||||
|
||||
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
|
||||
decorators = [gen_token_usage, gen_cache]
|
||||
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
|
||||
return self._apply_decorator(
|
||||
self._raw_gen,
|
||||
decorators=decorators,
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def gen_stream(self, model, messages, stream=True, *args, **kwargs):
|
||||
def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
|
||||
decorators = [stream_cache, stream_token_usage]
|
||||
return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
|
||||
|
||||
return self._apply_decorator(
|
||||
self._raw_gen_stream,
|
||||
decorators=decorators,
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def supports_tools(self):
|
||||
return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
|
||||
return hasattr(self, "_supports_tools") and callable(
|
||||
getattr(self, "_supports_tools")
|
||||
)
|
||||
|
||||
def _supports_tools(self):
|
||||
raise NotImplementedError("Subclass must implement _supports_tools method")
|
||||
raise NotImplementedError("Subclass must implement _supports_tools method")
|
||||
|
||||
@@ -1,69 +1,85 @@
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from application.llm.base import BaseLLM
|
||||
from application.core.settings import settings
|
||||
import logging
|
||||
|
||||
|
||||
class GoogleLLM(BaseLLM):
|
||||
|
||||
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.api_key = settings.API_KEY
|
||||
self.api_key = api_key
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
return [
|
||||
{
|
||||
"role": "model" if message["role"] == "system" else message["role"],
|
||||
"parts": [message["content"]],
|
||||
}
|
||||
for message in messages[1:]
|
||||
]
|
||||
|
||||
def _clean_tools_format(self, tools_data):
|
||||
"""
|
||||
Cleans the tools data format, converting string type representations
|
||||
to the expected dictionary structure for google-generativeai.
|
||||
"""
|
||||
if isinstance(tools_data, list):
|
||||
return [self._clean_tools_format(item) for item in tools_data]
|
||||
elif isinstance(tools_data, dict):
|
||||
if 'function' in tools_data and 'type' in tools_data and tools_data['type'] == 'function':
|
||||
# Handle the case where tools are nested under 'function'
|
||||
cleaned_function = self._clean_tools_format(tools_data['function'])
|
||||
return {'function_declarations': [cleaned_function]}
|
||||
elif 'function' in tools_data and 'type_' in tools_data and tools_data['type_'] == 'function':
|
||||
# Handle the case where tools are nested under 'function' and type is already 'type_'
|
||||
cleaned_function = self._clean_tools_format(tools_data['function'])
|
||||
return {'function_declarations': [cleaned_function]}
|
||||
else:
|
||||
new_tools_data = {}
|
||||
for key, value in tools_data.items():
|
||||
if key == 'type':
|
||||
if value == 'string':
|
||||
new_tools_data['type_'] = 'STRING' # Keep as string for now
|
||||
elif value == 'object':
|
||||
new_tools_data['type_'] = 'OBJECT' # Keep as string for now
|
||||
elif key == 'additionalProperties':
|
||||
continue
|
||||
elif key == 'properties':
|
||||
if isinstance(value, dict):
|
||||
new_properties = {}
|
||||
for prop_name, prop_value in value.items():
|
||||
if isinstance(prop_value, dict) and 'type' in prop_value:
|
||||
if prop_value['type'] == 'string':
|
||||
new_properties[prop_name] = {'type_': 'STRING', 'description': prop_value.get('description')}
|
||||
# Add more type mappings as needed
|
||||
else:
|
||||
new_properties[prop_name] = self._clean_tools_format(prop_value)
|
||||
new_tools_data[key] = new_properties
|
||||
else:
|
||||
new_tools_data[key] = self._clean_tools_format(value)
|
||||
cleaned_messages = []
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
|
||||
else:
|
||||
new_tools_data[key] = self._clean_tools_format(value)
|
||||
return new_tools_data
|
||||
else:
|
||||
return tools_data
|
||||
if role == "assistant":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
parts = [types.Part.from_text(content)]
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if "text" in item:
|
||||
parts.append(types.Part.from_text(item["text"]))
|
||||
elif "function_call" in item:
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=item["function_call"]["name"],
|
||||
args=item["function_call"]["args"],
|
||||
)
|
||||
)
|
||||
elif "function_response" in item:
|
||||
parts.append(
|
||||
types.Part.from_function_response(
|
||||
name=item["function_response"]["name"],
|
||||
response=item["function_response"]["response"],
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content dictionary format:{item}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
|
||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||
|
||||
return cleaned_messages
|
||||
|
||||
def _clean_tools_format(self, tools_list):
|
||||
genai_tools = []
|
||||
for tool_data in tools_list:
|
||||
if tool_data["type"] == "function":
|
||||
function = tool_data["function"]
|
||||
genai_function = dict(
|
||||
name=function["name"],
|
||||
description=function["description"],
|
||||
parameters={
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
k: {
|
||||
**v,
|
||||
"type": v["type"].upper() if v["type"] else None,
|
||||
}
|
||||
for k, v in function["parameters"]["properties"].items()
|
||||
},
|
||||
"required": (
|
||||
function["parameters"]["required"]
|
||||
if "required" in function["parameters"]
|
||||
else []
|
||||
),
|
||||
},
|
||||
)
|
||||
genai_tool = types.Tool(function_declarations=[genai_function])
|
||||
genai_tools.append(genai_tool)
|
||||
|
||||
return genai_tools
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
@@ -73,36 +89,29 @@ class GoogleLLM(BaseLLM):
|
||||
stream=False,
|
||||
tools=None,
|
||||
formatting="openai",
|
||||
**kwargs
|
||||
):
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
**kwargs,
|
||||
):
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
if formatting == "openai":
|
||||
messages = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
messages = messages[1:]
|
||||
|
||||
|
||||
config = {
|
||||
}
|
||||
model = 'gemini-2.0-flash-exp'
|
||||
if formatting=="raw":
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
response = client.models.generate_content(
|
||||
model=model,
|
||||
contents=messages
|
||||
contents=messages,
|
||||
config=config,
|
||||
)
|
||||
|
||||
return response
|
||||
else:
|
||||
model = genai.GenerativeModel(
|
||||
model_name=model,
|
||||
generation_config=config,
|
||||
system_instruction=messages[0]["content"],
|
||||
tools=self._clean_tools_format(tools)
|
||||
)
|
||||
chat_session = model.start_chat(
|
||||
history=self._clean_messages_google(messages)[:-1]
|
||||
response = client.models.generate_content(
|
||||
model=model, contents=messages, config=config
|
||||
)
|
||||
response = chat_session.send_message(
|
||||
self._clean_messages_google(messages)[-1]
|
||||
)
|
||||
logging.info(response)
|
||||
return response.text
|
||||
|
||||
def _raw_gen_stream(
|
||||
@@ -112,27 +121,29 @@ class GoogleLLM(BaseLLM):
|
||||
messages,
|
||||
stream=True,
|
||||
tools=None,
|
||||
**kwargs
|
||||
):
|
||||
import google.generativeai as genai
|
||||
genai.configure(api_key=self.api_key)
|
||||
config = {
|
||||
}
|
||||
model = genai.GenerativeModel(
|
||||
model_name=model,
|
||||
generation_config=config,
|
||||
system_instruction=messages[0]["content"]
|
||||
)
|
||||
chat_session = model.start_chat(
|
||||
history=self._clean_messages_google(messages)[:-1],
|
||||
formatting="openai",
|
||||
**kwargs,
|
||||
):
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
if formatting == "openai":
|
||||
messages = self._clean_messages_google(messages)
|
||||
config = types.GenerateContentConfig()
|
||||
if messages[0].role == "system":
|
||||
config.system_instruction = messages[0].parts[0].text
|
||||
messages = messages[1:]
|
||||
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
|
||||
response = client.models.generate_content_stream(
|
||||
model=model,
|
||||
contents=messages,
|
||||
config=config,
|
||||
)
|
||||
response = chat_session.send_message(
|
||||
self._clean_messages_google(messages)[-1]
|
||||
, stream=stream
|
||||
)
|
||||
for line in response:
|
||||
if line.text is not None:
|
||||
yield line.text
|
||||
|
||||
for chunk in response:
|
||||
if chunk.text is not None:
|
||||
yield chunk.text
|
||||
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
return True
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from application.llm.base import BaseLLM
|
||||
from application.core.settings import settings
|
||||
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
@@ -10,10 +9,7 @@ class OpenAILLM(BaseLLM):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
if settings.OPENAI_BASE_URL:
|
||||
self.client = OpenAI(
|
||||
api_key=api_key,
|
||||
base_url=settings.OPENAI_BASE_URL
|
||||
)
|
||||
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
|
||||
else:
|
||||
self.client = OpenAI(api_key=api_key)
|
||||
self.api_key = api_key
|
||||
@@ -27,8 +23,8 @@ class OpenAILLM(BaseLLM):
|
||||
stream=False,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs
|
||||
):
|
||||
**kwargs,
|
||||
):
|
||||
if tools:
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, tools=tools, **kwargs
|
||||
@@ -48,18 +44,16 @@ class OpenAILLM(BaseLLM):
|
||||
stream=True,
|
||||
tools=None,
|
||||
engine=settings.AZURE_DEPLOYMENT_NAME,
|
||||
**kwargs
|
||||
):
|
||||
**kwargs,
|
||||
):
|
||||
response = self.client.chat.completions.create(
|
||||
model=model, messages=messages, stream=stream, **kwargs
|
||||
)
|
||||
|
||||
for line in response:
|
||||
# import sys
|
||||
# print(line.choices[0].delta.content, file=sys.stderr)
|
||||
if line.choices[0].delta.content is not None:
|
||||
yield line.choices[0].delta.content
|
||||
|
||||
|
||||
def _supports_tools(self):
|
||||
return True
|
||||
|
||||
|
||||
@@ -14,6 +14,8 @@ esutils==1.0.1
|
||||
Flask==3.0.3
|
||||
faiss-cpu==1.9.0.post1
|
||||
flask-restx==1.3.0
|
||||
google-genai==0.5.0
|
||||
google-generativeai==0.8.3
|
||||
gTTS==2.3.2
|
||||
gunicorn==23.0.0
|
||||
html2text==2024.2.26
|
||||
|
||||
@@ -74,12 +74,10 @@ class BraveRetSearch(BaseRetriever):
|
||||
if len(self.chat_history) > 1:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append(
|
||||
{"role": "user", "content": i["prompt"]}
|
||||
)
|
||||
messages_combine.append(
|
||||
{"role": "system", "content": i["response"]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": i["response"]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
|
||||
@@ -5,7 +5,6 @@ from application.tools.agent import Agent
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
|
||||
class ClassicRAG(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
@@ -74,13 +73,11 @@ class ClassicRAG(BaseRetriever):
|
||||
|
||||
if len(self.chat_history) > 1:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append(
|
||||
{"role": "user", "content": i["prompt"]}
|
||||
)
|
||||
messages_combine.append(
|
||||
{"role": "system", "content": i["response"]}
|
||||
)
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": i["response"]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
# llm = LLMCreator.create_llm(
|
||||
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
|
||||
|
||||
@@ -91,11 +91,9 @@ class DuckDuckSearch(BaseRetriever):
|
||||
if len(self.chat_history) > 1:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append(
|
||||
{"role": "user", "content": i["prompt"]}
|
||||
)
|
||||
messages_combine.append(
|
||||
{"role": "system", "content": i["response"]}
|
||||
{"role": "assistant", "content": i["response"]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": self.question})
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.tools.llm_handler import get_llm_handler
|
||||
from application.tools.tool_action_parser import ToolActionParser
|
||||
from application.tools.tool_manager import ToolManager
|
||||
|
||||
|
||||
@@ -12,6 +11,7 @@ class Agent:
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name, api_key=api_key, user_api_key=user_api_key
|
||||
)
|
||||
self.llm_handler = get_llm_handler(llm_name)
|
||||
self.gpt_model = gpt_model
|
||||
# Static tool configuration (to be replaced later)
|
||||
self.tools = []
|
||||
@@ -61,10 +61,8 @@ class Agent:
|
||||
]
|
||||
|
||||
def _execute_tool_action(self, tools_dict, call):
|
||||
call_id = call.id
|
||||
call_args = json.loads(call.function.arguments)
|
||||
tool_id = call.function.name.split("_")[-1]
|
||||
action_name = call.function.name.rsplit("_", 1)[0]
|
||||
parser = ToolActionParser(self.llm.__class__.__name__)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = next(
|
||||
@@ -78,26 +76,9 @@ class Agent:
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
return tool.execute_action(action_name, **call_args), call_id
|
||||
|
||||
def _execute_tool_action_google(self, tools_dict, call):
|
||||
call_args = json.loads(call.args)
|
||||
tool_id = call.name.split("_")[-1]
|
||||
action_name = call.name.rsplit("_", 1)[0]
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = next(
|
||||
action for action in tool_data["actions"] if action["name"] == action_name
|
||||
)
|
||||
|
||||
for param, details in action_data["parameters"]["properties"].items():
|
||||
if param not in call_args and "value" in details:
|
||||
call_args[param] = details["value"]
|
||||
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
return tool.execute_action(action_name, **call_args)
|
||||
result = tool.execute_action(action_name, **call_args)
|
||||
call_id = getattr(call, "id", None)
|
||||
return result, call_id
|
||||
|
||||
def _simple_tool_agent(self, messages):
|
||||
tools_dict = self._get_user_tools()
|
||||
@@ -108,55 +89,15 @@ class Agent:
|
||||
if isinstance(resp, str):
|
||||
yield resp
|
||||
return
|
||||
if resp.message.content:
|
||||
if hasattr(resp, "message") and hasattr(resp.message, "content"):
|
||||
yield resp.message.content
|
||||
return
|
||||
# check if self.llm class is GoogleLLM
|
||||
while self.llm.__class__.__name__ == "GoogleLLM" and resp.content.parts[0].function_call:
|
||||
from google.genai import types
|
||||
|
||||
function_call_part = resp.candidates[0].content.parts[0]
|
||||
tool_response = self._execute_tool_action_google(tools_dict, function_call_part.function_call)
|
||||
function_response_part = types.Part.from_function_response(
|
||||
name=function_call_part.function_call.name,
|
||||
response=tool_response
|
||||
)
|
||||
resp = self.llm_handler.handle_response(self, resp, tools_dict, messages)
|
||||
|
||||
while self.llm.__class__.__name__ == "OpenAILLM" and resp.finish_reason == "tool_calls":
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||
filtered_data = {
|
||||
k: v for k, v in message.items() if k not in keys_to_remove
|
||||
}
|
||||
messages.append(filtered_data)
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
try:
|
||||
tool_response, call_id = self._execute_tool_action(tools_dict, call)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": str(tool_response),
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call.id,
|
||||
}
|
||||
)
|
||||
# Generate a new response from the LLM after processing tools
|
||||
resp = self.llm.gen(
|
||||
model=self.gpt_model, messages=messages, tools=self.tools
|
||||
)
|
||||
|
||||
# If no tool calls are needed, generate the final response
|
||||
if isinstance(resp, str):
|
||||
yield resp
|
||||
elif resp.message.content:
|
||||
elif hasattr(resp, "message") and hasattr(resp.message, "content"):
|
||||
yield resp.message.content
|
||||
else:
|
||||
completion = self.llm.gen_stream(
|
||||
@@ -168,7 +109,6 @@ class Agent:
|
||||
return
|
||||
|
||||
def gen(self, messages):
|
||||
# Generate initial response from the LLM
|
||||
if self.llm.supports_tools():
|
||||
resp = self._simple_tool_agent(messages)
|
||||
for line in resp:
|
||||
|
||||
97
application/tools/llm_handler.py
Normal file
97
application/tools/llm_handler.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class LLMHandler(ABC):
|
||||
@abstractmethod
|
||||
def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAILLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages):
|
||||
while resp.finish_reason == "tool_calls":
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||
filtered_data = {
|
||||
k: v for k, v in message.items() if k not in keys_to_remove
|
||||
}
|
||||
messages.append(filtered_data)
|
||||
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
try:
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, call
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": str(tool_response),
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
resp = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
class GoogleLLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages):
|
||||
from google.genai import types
|
||||
|
||||
while True:
|
||||
response = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
if response.candidates and response.candidates[0].content.parts:
|
||||
tool_call_found = False
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.function_call:
|
||||
tool_call_found = True
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, part.function_call
|
||||
)
|
||||
function_response_part = types.Part.from_function_response(
|
||||
name=part.function_call.name,
|
||||
response={"result": tool_response},
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{"role": "model", "content": [part.to_json_dict()]}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [function_response_part.to_json_dict()],
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
not tool_call_found
|
||||
and response.candidates[0].content.parts
|
||||
and response.candidates[0].content.parts[0].text
|
||||
):
|
||||
return response.candidates[0].content.parts[0].text
|
||||
elif not tool_call_found:
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
else:
|
||||
return response
|
||||
|
||||
|
||||
def get_llm_handler(llm_type):
|
||||
handlers = {
|
||||
"openai": OpenAILLMHandler(),
|
||||
"google": GoogleLLMHandler(),
|
||||
}
|
||||
return handlers.get(llm_type, OpenAILLMHandler())
|
||||
26
application/tools/tool_action_parser.py
Normal file
26
application/tools/tool_action_parser.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import json
|
||||
|
||||
|
||||
class ToolActionParser:
|
||||
def __init__(self, llm_type):
|
||||
self.llm_type = llm_type
|
||||
self.parsers = {
|
||||
"OpenAILLM": self._parse_openai_llm,
|
||||
"GoogleLLM": self._parse_google_llm,
|
||||
}
|
||||
|
||||
def parse_args(self, call):
|
||||
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
|
||||
return parser(call)
|
||||
|
||||
def _parse_openai_llm(self, call):
|
||||
call_args = json.loads(call.function.arguments)
|
||||
tool_id = call.function.name.split("_")[-1]
|
||||
action_name = call.function.name.rsplit("_", 1)[0]
|
||||
return tool_id, action_name, call_args
|
||||
|
||||
def _parse_google_llm(self, call):
|
||||
call_args = call.args
|
||||
tool_id = call.name.split("_")[-1]
|
||||
action_name = call.name.rsplit("_", 1)[0]
|
||||
return tool_id, action_name, call_args
|
||||
Reference in New Issue
Block a user