mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-30 17:13:15 +00:00
refactor: folder restructure for agent based workflow
This commit is contained in:
@@ -1,102 +0,0 @@
|
||||
import uuid
|
||||
from typing import Dict, Generator
|
||||
|
||||
from application.retriever.base import BaseRetriever
|
||||
from application.tools.base_agent import BaseAgent
|
||||
|
||||
|
||||
class ClassicAgent(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
llm_name,
|
||||
gpt_model,
|
||||
api_key,
|
||||
user_api_key=None,
|
||||
prompt="",
|
||||
chat_history=None,
|
||||
):
|
||||
super().__init__(llm_name, gpt_model, api_key, user_api_key)
|
||||
self.prompt = prompt
|
||||
self.chat_history = chat_history if chat_history is not None else []
|
||||
|
||||
def gen(self, query: str, retriever: BaseRetriever) -> Generator[Dict, None, None]:
|
||||
|
||||
retrieved_data = retriever.search(query)
|
||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
|
||||
if len(self.chat_history) > 0:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": i["response"]}
|
||||
)
|
||||
if "tool_calls" in i:
|
||||
for tool_call in i["tool_calls"]:
|
||||
call_id = tool_call.get("call_id")
|
||||
if call_id is None or call_id == "None":
|
||||
call_id = str(uuid.uuid4())
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages_combine.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
messages_combine.append({"role": "user", "content": query})
|
||||
|
||||
tools_dict = self._get_user_tools()
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
resp = self.llm.gen(
|
||||
model=self.gpt_model, messages=messages_combine, tools=self.tools
|
||||
)
|
||||
|
||||
if isinstance(resp, str):
|
||||
yield {"answer": resp}
|
||||
return
|
||||
if (
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
yield {"answer": resp.message.content}
|
||||
return
|
||||
|
||||
resp = self.llm_handler.handle_response(
|
||||
self, resp, tools_dict, messages_combine
|
||||
)
|
||||
|
||||
if isinstance(resp, str):
|
||||
yield {"answer": resp}
|
||||
elif (
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
yield {"answer": resp.message.content}
|
||||
else:
|
||||
completion = self.llm.gen_stream(
|
||||
model=self.gpt_model, messages=messages_combine, tools=self.tools
|
||||
)
|
||||
for line in completion:
|
||||
yield {"answer": line}
|
||||
|
||||
yield {"tool_calls": self.tool_calls.copy()}
|
||||
@@ -1,21 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
@abstractmethod
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_actions_metadata(self):
|
||||
"""
|
||||
Returns a list of JSON objects describing the actions supported by the tool.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_config_requirements(self):
|
||||
"""
|
||||
Returns a dictionary describing the configuration requirements for the tool.
|
||||
"""
|
||||
pass
|
||||
@@ -1,140 +0,0 @@
|
||||
from typing import Dict, Generator
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.tools.llm_handler import get_llm_handler
|
||||
from application.tools.tool_action_parser import ToolActionParser
|
||||
from application.tools.tool_manager import ToolManager
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name, api_key=api_key, user_api_key=user_api_key
|
||||
)
|
||||
self.llm_handler = get_llm_handler(llm_name)
|
||||
self.gpt_model = gpt_model
|
||||
self.tools = []
|
||||
self.tool_config = {}
|
||||
self.tool_calls = []
|
||||
|
||||
def gen(self, query: str) -> Generator[Dict, None, None]:
|
||||
raise NotImplementedError('Method "gen" must be implemented in the child class')
|
||||
|
||||
def _get_user_tools(self, user="local"):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo["docsgpt"]
|
||||
user_tools_collection = db["user_tools"]
|
||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
||||
user_tools = list(user_tools)
|
||||
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
|
||||
return tools_by_id
|
||||
|
||||
def _build_tool_parameters(self, action):
|
||||
params = {"type": "object", "properties": {}, "required": []}
|
||||
for param_type in ["query_params", "headers", "body", "parameters"]:
|
||||
if param_type in action and action[param_type].get("properties"):
|
||||
for k, v in action[param_type]["properties"].items():
|
||||
if v.get("filled_by_llm", True):
|
||||
params["properties"][k] = {
|
||||
key: value
|
||||
for key, value in v.items()
|
||||
if key != "filled_by_llm" and key != "value"
|
||||
}
|
||||
|
||||
params["required"].append(k)
|
||||
return params
|
||||
|
||||
def _prepare_tools(self, tools_dict):
|
||||
self.tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"{action['name']}_{tool_id}",
|
||||
"description": action["description"],
|
||||
"parameters": self._build_tool_parameters(action),
|
||||
},
|
||||
}
|
||||
for tool_id, tool in tools_dict.items()
|
||||
if (
|
||||
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
|
||||
or (tool["name"] != "api_tool" and "actions" in tool)
|
||||
)
|
||||
for action in (
|
||||
tool["config"]["actions"].values()
|
||||
if tool["name"] == "api_tool"
|
||||
else tool["actions"]
|
||||
)
|
||||
if action.get("active", True)
|
||||
]
|
||||
|
||||
def _execute_tool_action(self, tools_dict, call):
|
||||
parser = ToolActionParser(self.llm.__class__.__name__)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = (
|
||||
tool_data["config"]["actions"][action_name]
|
||||
if tool_data["name"] == "api_tool"
|
||||
else next(
|
||||
action
|
||||
for action in tool_data["actions"]
|
||||
if action["name"] == action_name
|
||||
)
|
||||
)
|
||||
|
||||
query_params, headers, body, parameters = {}, {}, {}, {}
|
||||
param_types = {
|
||||
"query_params": query_params,
|
||||
"headers": headers,
|
||||
"body": body,
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and action_data[param_type].get("properties"):
|
||||
for param, details in action_data[param_type]["properties"].items():
|
||||
if param not in call_args and "value" in details:
|
||||
target_dict[param] = details["value"]
|
||||
|
||||
for param, value in call_args.items():
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and param in action_data[param_type].get(
|
||||
"properties", {}
|
||||
):
|
||||
target_dict[param] = value
|
||||
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(
|
||||
tool_data["name"],
|
||||
tool_config=(
|
||||
{
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else tool_data["config"]
|
||||
),
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
print(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
call_id = getattr(call, "id", None)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": tool_data["name"],
|
||||
"call_id": call_id if call_id is not None else "None",
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
"result": result,
|
||||
}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
|
||||
return result, call_id
|
||||
@@ -1,71 +0,0 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
from application.tools.base import Tool
|
||||
|
||||
|
||||
class APITool(Tool):
|
||||
"""
|
||||
API Tool
|
||||
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.url = config.get("url", "")
|
||||
self.method = config.get("method", "GET")
|
||||
self.headers = config.get("headers", {"Content-Type": "application/json"})
|
||||
self.query_params = config.get("query_params", {})
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
return self._make_api_call(
|
||||
self.url, self.method, self.headers, self.query_params, kwargs
|
||||
)
|
||||
|
||||
def _make_api_call(self, url, method, headers, query_params, body):
|
||||
if query_params:
|
||||
url = f"{url}?{requests.compat.urlencode(query_params)}"
|
||||
if isinstance(body, dict):
|
||||
body = json.dumps(body)
|
||||
try:
|
||||
print(f"Making API call: {method} {url} with body: {body}")
|
||||
response = requests.request(method, url, headers=headers, data=body)
|
||||
response.raise_for_status()
|
||||
|
||||
content_type = response.headers.get(
|
||||
"Content-Type", "application/json"
|
||||
).lower()
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"API call returned invalid JSON. Error: {e}",
|
||||
"data": response.text,
|
||||
}
|
||||
elif "text/" in content_type or "application/xml" in content_type:
|
||||
data = response.text
|
||||
elif not response.content:
|
||||
data = None
|
||||
else:
|
||||
print(f"Unsupported content type: {content_type}")
|
||||
data = response.content
|
||||
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"data": data,
|
||||
"message": "API call successful.",
|
||||
}
|
||||
except requests.exceptions.RequestException as e:
|
||||
return {
|
||||
"status_code": response.status_code if response else None,
|
||||
"message": f"API call failed: {str(e)}",
|
||||
}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return []
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {}
|
||||
@@ -1,77 +0,0 @@
|
||||
import requests
|
||||
from application.tools.base import Tool
|
||||
|
||||
|
||||
class CryptoPriceTool(Tool):
|
||||
"""
|
||||
CryptoPrice
|
||||
A tool for retrieving cryptocurrency prices using the CryptoCompare public API
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {"cryptoprice_get": self._get_price}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def _get_price(self, symbol, currency):
|
||||
"""
|
||||
Fetches the current price of a given cryptocurrency symbol in the specified currency.
|
||||
Example:
|
||||
symbol = "BTC"
|
||||
currency = "USD"
|
||||
returns price in USD.
|
||||
"""
|
||||
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
# data will be like {"USD": <price>} if the call is successful
|
||||
if currency.upper() in data:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"price": data[currency.upper()],
|
||||
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": "Failed to retrieve price.",
|
||||
}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "cryptoprice_get",
|
||||
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": "The cryptocurrency symbol (e.g. BTC)",
|
||||
},
|
||||
"currency": {
|
||||
"type": "string",
|
||||
"description": "The currency in which you want the price (e.g. USD)",
|
||||
},
|
||||
},
|
||||
"required": ["symbol", "currency"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
# No specific configuration needed for this tool as it just queries a public endpoint
|
||||
return {}
|
||||
@@ -1,86 +0,0 @@
|
||||
import requests
|
||||
from application.tools.base import Tool
|
||||
|
||||
|
||||
class TelegramTool(Tool):
|
||||
"""
|
||||
Telegram Bot
|
||||
A flexible Telegram tool for performing various actions (e.g., sending messages, images).
|
||||
Requires a bot token and chat ID for configuration
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.token = config.get("token", "")
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {
|
||||
"telegram_send_message": self._send_message,
|
||||
"telegram_send_image": self._send_image,
|
||||
}
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
|
||||
def _send_message(self, text, chat_id):
|
||||
print(f"Sending message: {text}")
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
||||
payload = {"chat_id": chat_id, "text": text}
|
||||
response = requests.post(url, data=payload)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def _send_image(self, image_url, chat_id):
|
||||
print(f"Sending image: {image_url}")
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
||||
payload = {"chat_id": chat_id, "photo": image_url}
|
||||
response = requests.post(url, data=payload)
|
||||
return {"status_code": response.status_code, "message": "Image sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "telegram_send_message",
|
||||
"description": "Send a notification to Telegram chat",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to send in the notification",
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Chat ID to send the notification to",
|
||||
},
|
||||
},
|
||||
"required": ["text"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "telegram_send_image",
|
||||
"description": "Send an image to the Telegram chat",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image_url": {
|
||||
"type": "string",
|
||||
"description": "URL of the image to send",
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Chat ID to send the image to",
|
||||
},
|
||||
},
|
||||
"required": ["image_url"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {
|
||||
"token": {"type": "string", "description": "Bot token for authentication"},
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class LLMHandler(ABC):
|
||||
@abstractmethod
|
||||
def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAILLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages):
|
||||
while resp.finish_reason == "tool_calls":
|
||||
message = json.loads(resp.model_dump_json())["message"]
|
||||
keys_to_remove = {"audio", "function_call", "refusal"}
|
||||
filtered_data = {
|
||||
k: v for k, v in message.items() if k not in keys_to_remove
|
||||
}
|
||||
messages.append(filtered_data)
|
||||
|
||||
tool_calls = resp.message.tool_calls
|
||||
for call in tool_calls:
|
||||
try:
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, call
|
||||
)
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": call.function.name,
|
||||
"args": call.function.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": call.function.name,
|
||||
"response": {"result": tool_response},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
resp = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
class GoogleLLMHandler(LLMHandler):
|
||||
def handle_response(self, agent, resp, tools_dict, messages):
|
||||
from google.genai import types
|
||||
|
||||
while True:
|
||||
response = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
if response.candidates and response.candidates[0].content.parts:
|
||||
tool_call_found = False
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.function_call:
|
||||
tool_call_found = True
|
||||
tool_response, call_id = agent._execute_tool_action(
|
||||
tools_dict, part.function_call
|
||||
)
|
||||
function_response_part = types.Part.from_function_response(
|
||||
name=part.function_call.name,
|
||||
response={"result": tool_response},
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{"role": "model", "content": [part.to_json_dict()]}
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [function_response_part.to_json_dict()],
|
||||
}
|
||||
)
|
||||
|
||||
if (
|
||||
not tool_call_found
|
||||
and response.candidates[0].content.parts
|
||||
and response.candidates[0].content.parts[0].text
|
||||
):
|
||||
return response.candidates[0].content.parts[0].text
|
||||
elif not tool_call_found:
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
else:
|
||||
return response
|
||||
|
||||
|
||||
def get_llm_handler(llm_type):
|
||||
handlers = {
|
||||
"openai": OpenAILLMHandler(),
|
||||
"google": GoogleLLMHandler(),
|
||||
}
|
||||
return handlers.get(llm_type, OpenAILLMHandler())
|
||||
@@ -1,26 +0,0 @@
|
||||
import json
|
||||
|
||||
|
||||
class ToolActionParser:
|
||||
def __init__(self, llm_type):
|
||||
self.llm_type = llm_type
|
||||
self.parsers = {
|
||||
"OpenAILLM": self._parse_openai_llm,
|
||||
"GoogleLLM": self._parse_google_llm,
|
||||
}
|
||||
|
||||
def parse_args(self, call):
|
||||
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
|
||||
return parser(call)
|
||||
|
||||
def _parse_openai_llm(self, call):
|
||||
call_args = json.loads(call.function.arguments)
|
||||
tool_id = call.function.name.split("_")[-1]
|
||||
action_name = call.function.name.rsplit("_", 1)[0]
|
||||
return tool_id, action_name, call_args
|
||||
|
||||
def _parse_google_llm(self, call):
|
||||
call_args = call.args
|
||||
tool_id = call.name.split("_")[-1]
|
||||
action_name = call.name.rsplit("_", 1)[0]
|
||||
return tool_id, action_name, call_args
|
||||
@@ -1,46 +0,0 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import pkgutil
|
||||
|
||||
from application.tools.base import Tool
|
||||
|
||||
|
||||
class ToolManager:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.tools = {}
|
||||
self.load_tools()
|
||||
|
||||
def load_tools(self):
|
||||
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
|
||||
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
|
||||
if name == "base" or name.startswith("__"):
|
||||
continue
|
||||
module = importlib.import_module(
|
||||
f"application.tools.implementations.{name}"
|
||||
)
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
tool_config = self.config.get(name, {})
|
||||
self.tools[name] = obj(tool_config)
|
||||
|
||||
def load_tool(self, tool_name, tool_config):
|
||||
self.config[tool_name] = tool_config
|
||||
module = importlib.import_module(
|
||||
f"application.tools.implementations.{tool_name}"
|
||||
)
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
return obj(tool_config)
|
||||
|
||||
def execute_action(self, tool_name, action_name, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||
|
||||
def get_all_actions_metadata(self):
|
||||
metadata = []
|
||||
for tool in self.tools.values():
|
||||
metadata.extend(tool.get_actions_metadata())
|
||||
return metadata
|
||||
Reference in New Issue
Block a user