Merge pull request #1581 from siiddhantt/refactor/parser-and-handler-in-tools

refactor: tool agent for action parser and handlers
Alex committed 2025-01-21 09:48:22 +00:00 (committed by GitHub)
11 changed files with 299 additions and 215 deletions

View File

@@ -33,7 +33,7 @@ def gen_cache_key(messages, model="docgpt", tools=None):
     if not all(isinstance(msg, dict) for msg in messages):
         raise ValueError("All messages must be dictionaries.")
     messages_str = json.dumps(messages)
-    tools_str = json.dumps(tools) if tools else ""
+    tools_str = json.dumps(str(tools)) if tools else ""
     combined = f"{model}_{messages_str}_{tools_str}"
     cache_key = get_hash(combined)
     return cache_key
@@ -68,8 +68,8 @@ def gen_cache(func):

 def stream_cache(func):
-    def wrapper(self, model, messages, stream, *args, **kwargs):
-        cache_key = gen_cache_key(messages)
+    def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
+        cache_key = gen_cache_key(messages, model, tools)
         logger.info(f"Stream cache key: {cache_key}")

         redis_client = get_redis_instance()
@@ -86,7 +86,7 @@ def stream_cache(func):
         except redis.ConnectionError as e:
             logger.error(f"Redis connection error: {e}")

-        result = func(self, model, messages, stream, *args, **kwargs)
+        result = func(self, model, messages, stream, tools=tools, *args, **kwargs)
         stream_cache_data = []

         for chunk in result:
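
A note on the gen_cache_key change: wrapping tools in str() before json.dumps() presumably makes keying total — a tools value that json.dumps cannot serialize directly no longer raises. A runnable sketch of the keying scheme (get_hash here is a stand-in for the project's helper, assumed to be a plain digest):

import hashlib
import json

def get_hash(text):
    # Stand-in for the repo's get_hash helper (assumption: digest-based).
    return hashlib.md5(text.encode()).hexdigest()

def gen_cache_key(messages, model="docgpt", tools=None):
    if not all(isinstance(msg, dict) for msg in messages):
        raise ValueError("All messages must be dictionaries.")
    messages_str = json.dumps(messages)
    # str() first, so non-JSON-serializable tool entries cannot break keying.
    tools_str = json.dumps(str(tools)) if tools else ""
    return get_hash(f"{model}_{messages_str}_{tools_str}")

# Same messages/model/tools always map to the same cache entry.
print(gen_cache_key([{"role": "user", "content": "hi"}], "docgpt",
                    tools=[{"type": "function"}]))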

View File

@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
-from application.cache import gen_cache, stream_cache
 from application.usage import gen_token_usage, stream_token_usage
+from application.cache import stream_cache, gen_cache

 class BaseLLM(ABC):
@@ -18,18 +19,38 @@ class BaseLLM(ABC):
     def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
         decorators = [gen_token_usage, gen_cache]
-        return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
+        return self._apply_decorator(
+            self._raw_gen,
+            decorators=decorators,
+            model=model,
+            messages=messages,
+            stream=stream,
+            tools=tools,
+            *args,
+            **kwargs
+        )

     @abstractmethod
     def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
         pass

-    def gen_stream(self, model, messages, stream=True, *args, **kwargs):
+    def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
         decorators = [stream_cache, stream_token_usage]
-        return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
+        return self._apply_decorator(
+            self._raw_gen_stream,
+            decorators=decorators,
+            model=model,
+            messages=messages,
+            stream=stream,
+            tools=tools,
+            *args,
+            **kwargs
+        )

     def supports_tools(self):
-        return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
+        return hasattr(self, "_supports_tools") and callable(
+            getattr(self, "_supports_tools")
+        )

     def _supports_tools(self):
         raise NotImplementedError("Subclass must implement _supports_tools method")
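
One behavioral subtlety in supports_tools(): because BaseLLM itself defines a _supports_tools stub, the hasattr/callable check passes for every subclass, not only for subclasses that override it. A minimal sketch (illustrative class names, not from this PR) reproducing that check:

class Base:
    def supports_tools(self):
        return hasattr(self, "_supports_tools") and callable(
            getattr(self, "_supports_tools")
        )

    def _supports_tools(self):
        raise NotImplementedError("Subclass must implement _supports_tools method")

class NoTools(Base):
    pass  # does not override _supports_tools

# The inherited stub still satisfies hasattr()/callable(), so this prints True;
# the stub only raises if something calls _supports_tools() directly.
print(NoTools().supports_tools())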

View File

@@ -1,69 +1,85 @@
+from google import genai
+from google.genai import types
 from application.llm.base import BaseLLM
 from application.core.settings import settings
+import logging

 class GoogleLLM(BaseLLM):

     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.api_key = settings.API_KEY
+        self.api_key = api_key
         self.user_api_key = user_api_key

     def _clean_messages_google(self, messages):
-        return [
-            {
-                "role": "model" if message["role"] == "system" else message["role"],
-                "parts": [message["content"]],
-            }
-            for message in messages[1:]
-        ]
-
-    def _clean_tools_format(self, tools_data):
-        """
-        Cleans the tools data format, converting string type representations
-        to the expected dictionary structure for google-generativeai.
-        """
-        if isinstance(tools_data, list):
-            return [self._clean_tools_format(item) for item in tools_data]
-        elif isinstance(tools_data, dict):
-            if 'function' in tools_data and 'type' in tools_data and tools_data['type'] == 'function':
-                # Handle the case where tools are nested under 'function'
-                cleaned_function = self._clean_tools_format(tools_data['function'])
-                return {'function_declarations': [cleaned_function]}
-            elif 'function' in tools_data and 'type_' in tools_data and tools_data['type_'] == 'function':
-                # Handle the case where tools are nested under 'function' and type is already 'type_'
-                cleaned_function = self._clean_tools_format(tools_data['function'])
-                return {'function_declarations': [cleaned_function]}
-            else:
-                new_tools_data = {}
-                for key, value in tools_data.items():
-                    if key == 'type':
-                        if value == 'string':
-                            new_tools_data['type_'] = 'STRING'  # Keep as string for now
-                        elif value == 'object':
-                            new_tools_data['type_'] = 'OBJECT'  # Keep as string for now
-                    elif key == 'additionalProperties':
-                        continue
-                    elif key == 'properties':
-                        if isinstance(value, dict):
-                            new_properties = {}
-                            for prop_name, prop_value in value.items():
-                                if isinstance(prop_value, dict) and 'type' in prop_value:
-                                    if prop_value['type'] == 'string':
-                                        new_properties[prop_name] = {'type_': 'STRING', 'description': prop_value.get('description')}
-                                    # Add more type mappings as needed
-                                else:
-                                    new_properties[prop_name] = self._clean_tools_format(prop_value)
-                            new_tools_data[key] = new_properties
-                        else:
-                            new_tools_data[key] = self._clean_tools_format(value)
-                    else:
-                        new_tools_data[key] = self._clean_tools_format(value)
-                return new_tools_data
-        else:
-            return tools_data
+        cleaned_messages = []
+        for message in messages:
+            role = message.get("role")
+            content = message.get("content")
+
+            if role == "assistant":
+                role = "model"
+
+            parts = []
+            if role and content is not None:
+                if isinstance(content, str):
+                    parts = [types.Part.from_text(content)]
+                elif isinstance(content, list):
+                    for item in content:
+                        if "text" in item:
+                            parts.append(types.Part.from_text(item["text"]))
+                        elif "function_call" in item:
+                            parts.append(
+                                types.Part.from_function_call(
+                                    name=item["function_call"]["name"],
+                                    args=item["function_call"]["args"],
+                                )
+                            )
+                        elif "function_response" in item:
+                            parts.append(
+                                types.Part.from_function_response(
+                                    name=item["function_response"]["name"],
+                                    response=item["function_response"]["response"],
+                                )
+                            )
+                        else:
+                            raise ValueError(
+                                f"Unexpected content dictionary format:{item}"
+                            )
+                else:
+                    raise ValueError(f"Unexpected content type: {type(content)}")
+
+            cleaned_messages.append(types.Content(role=role, parts=parts))
+        return cleaned_messages
+
+    def _clean_tools_format(self, tools_list):
+        genai_tools = []
+        for tool_data in tools_list:
+            if tool_data["type"] == "function":
+                function = tool_data["function"]
+                genai_function = dict(
+                    name=function["name"],
+                    description=function["description"],
+                    parameters={
+                        "type": "OBJECT",
+                        "properties": {
+                            k: {
+                                **v,
+                                "type": v["type"].upper() if v["type"] else None,
+                            }
+                            for k, v in function["parameters"]["properties"].items()
+                        },
+                        "required": (
+                            function["parameters"]["required"]
+                            if "required" in function["parameters"]
+                            else []
+                        ),
+                    },
+                )
+                genai_tool = types.Tool(function_declarations=[genai_function])
+                genai_tools.append(genai_tool)
+        return genai_tools

     def _raw_gen(
         self,
@@ -73,36 +89,29 @@ class GoogleLLM(BaseLLM):
         stream=False,
         tools=None,
         formatting="openai",
-        **kwargs
+        **kwargs,
     ):
-        from google import genai
-        from google.genai import types
-
         client = genai.Client(api_key=self.api_key)
         if formatting == "openai":
             messages = self._clean_messages_google(messages)
-        config = {
-        }
-        model = 'gemini-2.0-flash-exp'
-        if formatting=="raw":
+
+        config = types.GenerateContentConfig()
+        if messages[0].role == "system":
+            config.system_instruction = messages[0].parts[0].text
+            messages = messages[1:]
+
+        if tools:
+            cleaned_tools = self._clean_tools_format(tools)
+            config.tools = cleaned_tools
             response = client.models.generate_content(
                 model=model,
-                contents=messages
+                contents=messages,
+                config=config,
             )
             return response
         else:
-            model = genai.GenerativeModel(
-                model_name=model,
-                generation_config=config,
-                system_instruction=messages[0]["content"],
-                tools=self._clean_tools_format(tools)
-            )
-            chat_session = model.start_chat(
-                history=self._clean_messages_google(messages)[:-1]
-            )
-            response = chat_session.send_message(
-                self._clean_messages_google(messages)[-1]
-            )
+            response = client.models.generate_content(
+                model=model, contents=messages, config=config
+            )
+            logging.info(response)
             return response.text

     def _raw_gen_stream(
@@ -112,27 +121,29 @@ class GoogleLLM(BaseLLM):
         messages,
         stream=True,
         tools=None,
-        **kwargs
+        formatting="openai",
+        **kwargs,
     ):
-        import google.generativeai as genai
-
-        genai.configure(api_key=self.api_key)
-        config = {
-        }
-        model = genai.GenerativeModel(
-            model_name=model,
-            generation_config=config,
-            system_instruction=messages[0]["content"]
-        )
-        chat_session = model.start_chat(
-            history=self._clean_messages_google(messages)[:-1],
-        )
-        response = chat_session.send_message(
-            self._clean_messages_google(messages)[-1]
-            , stream=stream
-        )
-        for line in response:
-            if line.text is not None:
-                yield line.text
+        client = genai.Client(api_key=self.api_key)
+        if formatting == "openai":
+            messages = self._clean_messages_google(messages)
+
+        config = types.GenerateContentConfig()
+        if messages[0].role == "system":
+            config.system_instruction = messages[0].parts[0].text
+            messages = messages[1:]
+
+        if tools:
+            cleaned_tools = self._clean_tools_format(tools)
+            config.tools = cleaned_tools
+
+        response = client.models.generate_content_stream(
+            model=model,
+            contents=messages,
+            config=config,
+        )
+        for chunk in response:
+            if chunk.text is not None:
+                yield chunk.text

     def _supports_tools(self):
         return True
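
For reference, the shape the new _clean_tools_format produces: an OpenAI-style function spec becomes a genai function declaration with upper-cased parameter types and a defaulted required list, then gets wrapped in types.Tool(function_declarations=[...]). A data-only sketch (the tool name and fields are hypothetical):

# Hypothetical OpenAI-style tool spec; real names come from configured tools.
openai_tool = {
    "type": "function",
    "function": {
        "name": "read_webpage_6783",
        "description": "Fetch a web page and return its text.",
        "parameters": {
            "type": "object",
            "properties": {
                "url": {"type": "string", "description": "Page to fetch"},
            },
            "required": ["url"],
        },
    },
}

# What _clean_tools_format builds per tool: "type" values are upper-cased
# via v["type"].upper(), other keys pass through via **v, and "required"
# defaults to [] when absent.
genai_function = {
    "name": "read_webpage_6783",
    "description": "Fetch a web page and return its text.",
    "parameters": {
        "type": "OBJECT",
        "properties": {
            "url": {"type": "STRING", "description": "Page to fetch"},
        },
        "required": ["url"],
    },
}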

View File

@@ -1,6 +1,5 @@
-from application.llm.base import BaseLLM
 from application.core.settings import settings
+from application.llm.base import BaseLLM

 class OpenAILLM(BaseLLM):
@@ -10,10 +9,7 @@ class OpenAILLM(BaseLLM):
         super().__init__(*args, **kwargs)
         if settings.OPENAI_BASE_URL:
-            self.client = OpenAI(
-                api_key=api_key,
-                base_url=settings.OPENAI_BASE_URL
-            )
+            self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
         else:
             self.client = OpenAI(api_key=api_key)
         self.api_key = api_key
@@ -27,7 +23,7 @@ class OpenAILLM(BaseLLM):
         stream=False,
         tools=None,
         engine=settings.AZURE_DEPLOYMENT_NAME,
-        **kwargs
+        **kwargs,
     ):
         if tools:
             response = self.client.chat.completions.create(
@@ -48,15 +44,13 @@ class OpenAILLM(BaseLLM):
         stream=True,
         tools=None,
         engine=settings.AZURE_DEPLOYMENT_NAME,
-        **kwargs
+        **kwargs,
     ):
         response = self.client.chat.completions.create(
             model=model, messages=messages, stream=stream, **kwargs
         )
         for line in response:
-            # import sys
-            # print(line.choices[0].delta.content, file=sys.stderr)
             if line.choices[0].delta.content is not None:
                 yield line.choices[0].delta.content

View File

@@ -14,6 +14,8 @@ esutils==1.0.1
 Flask==3.0.3
 faiss-cpu==1.9.0.post1
 flask-restx==1.3.0
+google-genai==0.5.0
+google-generativeai==0.8.3
 gTTS==2.3.2
 gunicorn==23.0.0
 html2text==2024.2.26

View File

@@ -74,12 +74,10 @@ class BraveRetSearch(BaseRetriever):
         if len(self.chat_history) > 1:
             for i in self.chat_history:
                 if "prompt" in i and "response" in i:
-                    messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
-                    messages_combine.append(
-                        {"role": "system", "content": i["response"]}
-                    )
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
+                    messages_combine.append(
+                        {"role": "assistant", "content": i["response"]}
+                    )
         messages_combine.append({"role": "user", "content": self.question})
         llm = LLMCreator.create_llm(

View File

@@ -5,7 +5,6 @@ from application.tools.agent import Agent
 from application.vectorstore.vector_creator import VectorCreator

 class ClassicRAG(BaseRetriever):
     def __init__(
@@ -74,13 +73,11 @@ class ClassicRAG(BaseRetriever):
         if len(self.chat_history) > 1:
             for i in self.chat_history:
-                if "prompt" in i and "response" in i:
-                    messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
-                    messages_combine.append(
-                        {"role": "system", "content": i["response"]}
-                    )
+                if "prompt" in i and "response" in i:
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
+                    messages_combine.append(
+                        {"role": "assistant", "content": i["response"]}
+                    )
         messages_combine.append({"role": "user", "content": self.question})
         # llm = LLMCreator.create_llm(
         #     settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
View File

@@ -91,11 +91,9 @@ class DuckDuckSearch(BaseRetriever):
         if len(self.chat_history) > 1:
             for i in self.chat_history:
                 if "prompt" in i and "response" in i:
-                    messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
                     messages_combine.append(
-                        {"role": "system", "content": i["response"]}
+                        {"role": "assistant", "content": i["response"]}
                     )
         messages_combine.append({"role": "user", "content": self.question})
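
The same fix lands in all three retrievers above: replayed history answers now carry role "assistant" rather than "system". A standalone sketch of the corrected conversion (the helper name is ours, not the repo's):

def history_to_messages(chat_history):
    # Mirrors the retriever loops above: each stored Q/A pair becomes a
    # user turn plus an assistant turn (previously mislabeled "system").
    messages = []
    if len(chat_history) > 1:
        for i in chat_history:
            if "prompt" in i and "response" in i:
                messages.append({"role": "user", "content": i["prompt"]})
                messages.append({"role": "assistant", "content": i["response"]})
    return messages

print(history_to_messages([
    {"prompt": "What is DocsGPT?", "response": "A RAG assistant."},
    {"prompt": "Is it open source?", "response": "Yes."},
]))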

View File

@@ -1,8 +1,7 @@
 import json
 import logging

 from application.core.mongo_db import MongoDB
 from application.llm.llm_creator import LLMCreator
+from application.tools.llm_handler import get_llm_handler
+from application.tools.tool_action_parser import ToolActionParser
 from application.tools.tool_manager import ToolManager
@@ -12,6 +11,7 @@ class Agent:
         self.llm = LLMCreator.create_llm(
             llm_name, api_key=api_key, user_api_key=user_api_key
         )
+        self.llm_handler = get_llm_handler(llm_name)
         self.gpt_model = gpt_model
         # Static tool configuration (to be replaced later)
         self.tools = []
@@ -61,10 +61,8 @@
         ]

     def _execute_tool_action(self, tools_dict, call):
-        call_id = call.id
-        call_args = json.loads(call.function.arguments)
-        tool_id = call.function.name.split("_")[-1]
-        action_name = call.function.name.rsplit("_", 1)[0]
+        parser = ToolActionParser(self.llm.__class__.__name__)
+        tool_id, action_name, call_args = parser.parse_args(call)
         tool_data = tools_dict[tool_id]
         action_data = next(
@@ -78,26 +76,9 @@
         tm = ToolManager(config={})
         tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
         print(f"Executing tool: {action_name} with args: {call_args}")
-        return tool.execute_action(action_name, **call_args), call_id
-
-    def _execute_tool_action_google(self, tools_dict, call):
-        call_args = json.loads(call.args)
-        tool_id = call.name.split("_")[-1]
-        action_name = call.name.rsplit("_", 1)[0]
-        tool_data = tools_dict[tool_id]
-        action_data = next(
-            action for action in tool_data["actions"] if action["name"] == action_name
-        )
-        for param, details in action_data["parameters"]["properties"].items():
-            if param not in call_args and "value" in details:
-                call_args[param] = details["value"]
-        tm = ToolManager(config={})
-        tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
-        print(f"Executing tool: {action_name} with args: {call_args}")
-        return tool.execute_action(action_name, **call_args)
+        result = tool.execute_action(action_name, **call_args)
+        call_id = getattr(call, "id", None)
+        return result, call_id

     def _simple_tool_agent(self, messages):
         tools_dict = self._get_user_tools()
@@ -108,55 +89,15 @@
         if isinstance(resp, str):
             yield resp
             return
-        if resp.message.content:
+        if hasattr(resp, "message") and hasattr(resp.message, "content"):
             yield resp.message.content
             return

-        # check if self.llm class is GoogleLLM
-        while self.llm.__class__.__name__ == "GoogleLLM" and resp.content.parts[0].function_call:
-            from google.genai import types
-
-            function_call_part = resp.candidates[0].content.parts[0]
-            tool_response = self._execute_tool_action_google(tools_dict, function_call_part.function_call)
-            function_response_part = types.Part.from_function_response(
-                name=function_call_part.function_call.name,
-                response=tool_response
-            )
-
-        while self.llm.__class__.__name__ == "OpenAILLM" and resp.finish_reason == "tool_calls":
-            message = json.loads(resp.model_dump_json())["message"]
-            keys_to_remove = {"audio", "function_call", "refusal"}
-            filtered_data = {
-                k: v for k, v in message.items() if k not in keys_to_remove
-            }
-            messages.append(filtered_data)
-            tool_calls = resp.message.tool_calls
-            for call in tool_calls:
-                try:
-                    tool_response, call_id = self._execute_tool_action(tools_dict, call)
-                    messages.append(
-                        {
-                            "role": "tool",
-                            "content": str(tool_response),
-                            "tool_call_id": call_id,
-                        }
-                    )
-                except Exception as e:
-                    messages.append(
-                        {
-                            "role": "tool",
-                            "content": f"Error executing tool: {str(e)}",
-                            "tool_call_id": call.id,
-                        }
-                    )
-            # Generate a new response from the LLM after processing tools
-            resp = self.llm.gen(
-                model=self.gpt_model, messages=messages, tools=self.tools
-            )
+        resp = self.llm_handler.handle_response(self, resp, tools_dict, messages)

         # If no tool calls are needed, generate the final response
         if isinstance(resp, str):
             yield resp
-        elif resp.message.content:
+        elif hasattr(resp, "message") and hasattr(resp.message, "content"):
             yield resp.message.content
         else:
             completion = self.llm.gen_stream(
@@ -168,7 +109,6 @@
             return

     def gen(self, messages):
-        # Generate initial response from the LLM
         if self.llm.supports_tools():
             resp = self._simple_tool_agent(messages)
             for line in resp:
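
With the parser in place, _execute_tool_action no longer inspects provider-specific call shapes itself. A quick sketch of the OpenAI path through ToolActionParser, using a mocked call object (the tool name is hypothetical; real names end in the tool id):

import json
from types import SimpleNamespace

from application.tools.tool_action_parser import ToolActionParser

# Mocked OpenAI tool call as the agent receives it.
call = SimpleNamespace(
    id="call_1",
    function=SimpleNamespace(
        name="read_webpage_6783",
        arguments=json.dumps({"url": "https://example.com"}),
    ),
)

parser = ToolActionParser("OpenAILLM")
tool_id, action_name, call_args = parser.parse_args(call)
print(tool_id, action_name, call_args)
# -> 6783 read_webpage {'url': 'https://example.com'}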

View File

@@ -0,0 +1,97 @@
+import json
+from abc import ABC, abstractmethod
+
+
+class LLMHandler(ABC):
+    @abstractmethod
+    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
+        pass
+
+
+class OpenAILLMHandler(LLMHandler):
+    def handle_response(self, agent, resp, tools_dict, messages):
+        while resp.finish_reason == "tool_calls":
+            message = json.loads(resp.model_dump_json())["message"]
+            keys_to_remove = {"audio", "function_call", "refusal"}
+            filtered_data = {
+                k: v for k, v in message.items() if k not in keys_to_remove
+            }
+            messages.append(filtered_data)
+
+            tool_calls = resp.message.tool_calls
+            for call in tool_calls:
+                try:
+                    tool_response, call_id = agent._execute_tool_action(
+                        tools_dict, call
+                    )
+                    messages.append(
+                        {
+                            "role": "tool",
+                            "content": str(tool_response),
+                            "tool_call_id": call_id,
+                        }
+                    )
+                except Exception as e:
+                    messages.append(
+                        {
+                            "role": "tool",
+                            "content": f"Error executing tool: {str(e)}",
+                            "tool_call_id": call_id,
+                        }
+                    )
+            resp = agent.llm.gen(
+                model=agent.gpt_model, messages=messages, tools=agent.tools
+            )
+        return resp
+
+
+class GoogleLLMHandler(LLMHandler):
+    def handle_response(self, agent, resp, tools_dict, messages):
+        from google.genai import types
+
+        while True:
+            response = agent.llm.gen(
+                model=agent.gpt_model, messages=messages, tools=agent.tools
+            )
+            if response.candidates and response.candidates[0].content.parts:
+                tool_call_found = False
+                for part in response.candidates[0].content.parts:
+                    if part.function_call:
+                        tool_call_found = True
+                        tool_response, call_id = agent._execute_tool_action(
+                            tools_dict, part.function_call
+                        )
+                        function_response_part = types.Part.from_function_response(
+                            name=part.function_call.name,
+                            response={"result": tool_response},
+                        )
+                        messages.append(
+                            {"role": "model", "content": [part.to_json_dict()]}
+                        )
+                        messages.append(
+                            {
+                                "role": "tool",
+                                "content": [function_response_part.to_json_dict()],
+                            }
+                        )
+
+                if (
+                    not tool_call_found
+                    and response.candidates[0].content.parts
+                    and response.candidates[0].content.parts[0].text
+                ):
+                    return response.candidates[0].content.parts[0].text
+                elif not tool_call_found:
+                    return response.candidates[0].content.parts
+            else:
+                return response
+
+
+def get_llm_handler(llm_type):
+    handlers = {
+        "openai": OpenAILLMHandler(),
+        "google": GoogleLLMHandler(),
+    }
+    return handlers.get(llm_type, OpenAILLMHandler())
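
Handler selection keys off the llm_name the Agent was constructed with, falling back to the OpenAI handler for unrecognized names — assuming the LLM creator registers lowercase names like "openai" and "google":

from application.tools.llm_handler import get_llm_handler

print(type(get_llm_handler("google")).__name__)   # GoogleLLMHandler
print(type(get_llm_handler("unknown")).__name__)  # OpenAILLMHandler (fallback)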

View File

@@ -0,0 +1,26 @@
+import json
+
+
+class ToolActionParser:
+    def __init__(self, llm_type):
+        self.llm_type = llm_type
+        self.parsers = {
+            "OpenAILLM": self._parse_openai_llm,
+            "GoogleLLM": self._parse_google_llm,
+        }
+
+    def parse_args(self, call):
+        parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
+        return parser(call)
+
+    def _parse_openai_llm(self, call):
+        call_args = json.loads(call.function.arguments)
+        tool_id = call.function.name.split("_")[-1]
+        action_name = call.function.name.rsplit("_", 1)[0]
+        return tool_id, action_name, call_args
+
+    def _parse_google_llm(self, call):
+        call_args = call.args
+        tool_id = call.name.split("_")[-1]
+        action_name = call.name.rsplit("_", 1)[0]
+        return tool_id, action_name, call_args
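
The Google path differs from the OpenAI one in two ways the parser now encodes in a single place: args arrives as an already-parsed dict (no json.loads, unlike the removed _execute_tool_action_google), and the call object exposes name at the top level rather than under function. A mocked example (the action/tool names are hypothetical):

from types import SimpleNamespace

from application.tools.tool_action_parser import ToolActionParser

# Mocked google-genai function_call: args is already a dict.
call = SimpleNamespace(name="send_message_42", args={"text": "hello"})

tool_id, action_name, call_args = ToolActionParser("GoogleLLM").parse_args(call)
print(tool_id, action_name, call_args)  # -> 42 send_message {'text': 'hello'}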