diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 627f5665..efb0242d 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -1791,13 +1791,12 @@ class TextToSpeech(Resource): @user_ns.route("/api/create_tool") class CreateTool(Resource): - # write code such that it will accept tool_name, took_config and tool_actions create_tool_model = api.model( "CreateToolModel", { - "tool_name": fields.String(required=True, description="Name of the tool"), - "tool_config": fields.Raw(required=True, description="Configuration of the tool"), - "tool_actions": fields.List(required=True, description="Actions the tool can perform"), + "name": fields.String(required=True, description="Name of the tool"), + "config": fields.Raw(required=True, description="Configuration of the tool"), + "actions": fields.List(required=True, description="Actions the tool can perform"), "status": fields.Boolean(required=True, description="Status of the tool") }, @@ -1807,7 +1806,7 @@ class CreateTool(Resource): @api.doc(description="Create a new tool") def post(self): data = request.get_json() - required_fields = ["tool_name", "tool_config", "tool_actions", "status"] + required_fields = ["name", "config", "actions", "status"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields @@ -1815,9 +1814,9 @@ class CreateTool(Resource): user = "local" try: new_tool = { - "tool_name": data["tool_name"], - "tool_config": data["tool_config"], - "tool_actions": data["tool_actions"], + "name": data["name"], + "config": data["config"], + "actions": data["actions"], "user": user, "status": data["status"], } @@ -1833,8 +1832,8 @@ class UpdateToolConfig(Resource): update_tool_config_model = api.model( "UpdateToolConfigModel", { - "tool_id": fields.String(required=True, description="Tool ID"), - "tool_config": fields.Raw(required=True, description="Configuration of the tool"), + "id": fields.String(required=True, description="Tool ID"), + "config": fields.Raw(required=True, description="Configuration of the tool"), }, ) @@ -1842,15 +1841,15 @@ class UpdateToolConfig(Resource): @api.doc(description="Update the configuration of a tool") def post(self): data = request.get_json() - required_fields = ["tool_id", "tool_config"] + required_fields = ["id", "config"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields try: user_tools_collection.update_one( - {"_id": ObjectId(data["tool_id"])}, - {"$set": {"tool_config": data["tool_config"]}}, + {"_id": ObjectId(data["id"])}, + {"$set": {"config": data["config"]}}, ) except Exception as err: return make_response(jsonify({"success": False, "error": str(err)}), 400) @@ -1862,8 +1861,8 @@ class UpdateToolActions(Resource): update_tool_actions_model = api.model( "UpdateToolActionsModel", { - "tool_id": fields.String(required=True, description="Tool ID"), - "tool_actions": fields.List(required=True, description="Actions the tool can perform"), + "id": fields.String(required=True, description="Tool ID"), + "actions": fields.List(required=True, description="Actions the tool can perform"), }, ) @@ -1871,15 +1870,15 @@ class UpdateToolActions(Resource): @api.doc(description="Update the actions of a tool") def post(self): data = request.get_json() - required_fields = ["tool_id", "tool_actions"] + required_fields = ["id", "actions"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields try: user_tools_collection.update_one( - {"_id": ObjectId(data["tool_id"])}, - {"$set": {"tool_actions": data["tool_actions"]}}, + {"_id": ObjectId(data["id"])}, + {"$set": {"actions": data["actions"]}}, ) except Exception as err: return make_response(jsonify({"success": False, "error": str(err)}), 400) @@ -1891,7 +1890,7 @@ class UpdateToolStatus(Resource): update_tool_status_model = api.model( "UpdateToolStatusModel", { - "tool_id": fields.String(required=True, description="Tool ID"), + "id": fields.String(required=True, description="Tool ID"), "status": fields.Boolean(required=True, description="Status of the tool"), }, ) @@ -1900,14 +1899,14 @@ class UpdateToolStatus(Resource): @api.doc(description="Update the status of a tool") def post(self): data = request.get_json() - required_fields = ["tool_id", "status"] + required_fields = ["id", "status"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields try: user_tools_collection.update_one( - {"_id": ObjectId(data["tool_id"])}, + {"_id": ObjectId(data["id"])}, {"$set": {"status": data["status"]}}, ) except Exception as err: @@ -1919,20 +1918,20 @@ class UpdateToolStatus(Resource): class DeleteTool(Resource): delete_tool_model = api.model( "DeleteToolModel", - {"tool_id": fields.String(required=True, description="Tool ID")}, + {"id": fields.String(required=True, description="Tool ID")}, ) @api.expect(delete_tool_model) @api.doc(description="Delete a tool by ID") def post(self): data = request.get_json() - required_fields = ["tool_id"] + required_fields = ["id"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields try: - result = user_tools_collection.delete_one({"_id": ObjectId(data["tool_id"])}) + result = user_tools_collection.delete_one({"_id": ObjectId(data["id"])}) if result.deleted_count == 0: return {"success": False, "message": "Tool not found"}, 404 except Exception as err: diff --git a/application/tools/agent.py b/application/tools/agent.py index 2df14442..af23b99e 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -1,6 +1,7 @@ from application.llm.llm_creator import LLMCreator from application.core.settings import settings from application.tools.tool_manager import ToolManager +from application.core.mongo_db import MongoDB import json tool_tg = { @@ -53,9 +54,31 @@ class Agent: ] self.tool_config = { } + + def _get_user_tools(self, user="local"): + mongo = MongoDB.get_client() + db = mongo["docsgpt"] + user_tools_collection = db["user_tools"] + user_tools = user_tools_collection.find({"user": user, "status": True}) + user_tools = list(user_tools) + for tool in user_tools: + tool.pop("_id") + user_tools = {tool["name"]: tool for tool in user_tools} + return user_tools + + def _simple_tool_agent(self, messages): + tools_dict = self._get_user_tools() + # combine all tool_actions into one list + self.tools.extend([ + { + "type": "function", + "function": tool_action + } + for tool in tools_dict.values() + for tool_action in tool["actions"] + ]) + - def gen(self, messages): - # Generate initial response from the LLM resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools) if isinstance(resp, str): @@ -75,7 +98,7 @@ class Agent: call_id = call.id # Determine the tool name and load it tool_name = call_name.split("_")[0] - tool = tm.load_tool(tool_name, tool_config=self.tool_config) + tool = tm.load_tool(tool_name, tool_config=tools_dict[tool_name]['config']) # Execute the tool's action resp_tool = tool.execute_action(call_name, **call_args) # Append the tool's response to the conversation @@ -96,3 +119,12 @@ class Agent: completion = self.llm.gen_stream(model=self.gpt_model, messages=messages, tools=self.tools) for line in completion: yield line + + def gen(self, messages): + # Generate initial response from the LLM + if self.llm.supports_tools(): + self._simple_tool_agent(messages) + else: + resp = self.llm.gen_stream(model=self.gpt_model, messages=messages) + for line in resp: + yield line \ No newline at end of file