Merge pull request #1473 from arc53/tool-use

Tools + agent
This commit is contained in:
Alex
2024-12-20 18:17:39 +00:00
committed by GitHub
37 changed files with 1604 additions and 100 deletions

22
.vscode/launch.json vendored
View File

@@ -11,6 +11,26 @@
"skipFiles": [
"<node_internals>/**"
]
}
},
{
"name": "Python Debugger: Flask",
"type": "debugpy",
"request": "launch",
"module": "flask",
"env": {
"FLASK_APP": "application/app.py",
"PYTHONPATH": "${workspaceFolder}",
"FLASK_ENV": "development",
"FLASK_DEBUG": "1",
"FLASK_RUN_PORT": "7091",
"FLASK_RUN_HOST": "0.0.0.0"
},
"args": [
"run",
"--no-debugger"
],
"cwd": "${workspaceFolder}",
},
]
}

View File

@@ -37,7 +37,7 @@ api.add_namespace(answer_ns)
gpt_model = ""
# to have some kind of default behaviour
if settings.LLM_NAME == "openai":
gpt_model = "gpt-3.5-turbo"
gpt_model = "gpt-4o-mini"
elif settings.LLM_NAME == "anthropic":
gpt_model = "claude-2"
elif settings.LLM_NAME == "groq":

View File

@@ -1,14 +1,14 @@
import datetime
import math
import os
import shutil
import uuid
import math
from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, jsonify, make_response, request, redirect
from flask_restx import inputs, fields, Namespace, Resource
from flask import Blueprint, jsonify, make_response, redirect, request
from flask_restx import fields, inputs, Namespace, Resource
from werkzeug.utils import secure_filename
from application.api.user.tasks import ingest, ingest_remote
@@ -16,9 +16,10 @@ from application.api.user.tasks import ingest, ingest_remote
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.extensions import api
from application.tools.tool_manager import ToolManager
from application.tts.google_tts import GoogleTTS
from application.utils import check_required_fields
from application.vectorstore.vector_creator import VectorCreator
from application.tts.google_tts import GoogleTTS
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -30,6 +31,7 @@ api_key_collection = db["api_keys"]
token_usage_collection = db["token_usage"]
shared_conversations_collections = db["shared_conversations"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
user = Blueprint("user", __name__)
user_ns = Namespace("user", description="User related operations", path="/")
@@ -39,6 +41,9 @@ current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
tool_config = {}
tool_manager = ToolManager(config=tool_config)
def generate_minute_range(start_date, end_date):
return {
@@ -1802,3 +1807,295 @@ class TextToSpeech(Resource):
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
@user_ns.route("/api/available_tools")
class AvailableTools(Resource):
@api.doc(description="Get available tools for a user")
def get(self):
try:
tools_metadata = []
for tool_name, tool_instance in tool_manager.tools.items():
doc = tool_instance.__doc__.strip()
lines = doc.split("\n", 1)
name = lines[0].strip()
description = lines[1].strip() if len(lines) > 1 else ""
tools_metadata.append(
{
"name": tool_name,
"displayName": name,
"description": description,
"configRequirements": tool_instance.get_config_requirements(),
"actions": tool_instance.get_actions_metadata(),
}
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True, "data": tools_metadata}), 200)
@user_ns.route("/api/get_tools")
class GetTools(Resource):
@api.doc(description="Get tools created by a user")
def get(self):
try:
user = "local"
tools = user_tools_collection.find({"user": user})
user_tools = []
for tool in tools:
tool["id"] = str(tool["_id"])
tool.pop("_id")
user_tools.append(tool)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True, "tools": user_tools}), 200)
@user_ns.route("/api/create_tool")
class CreateTool(Resource):
@api.expect(
api.model(
"CreateToolModel",
{
"name": fields.String(required=True, description="Name of the tool"),
"displayName": fields.String(
required=True, description="Display name for the tool"
),
"description": fields.String(
required=True, description="Tool description"
),
"config": fields.Raw(
required=True, description="Configuration of the tool"
),
"actions": fields.List(
fields.Raw,
required=True,
description="Actions the tool can perform",
),
"status": fields.Boolean(
required=True, description="Status of the tool"
),
},
)
)
@api.doc(description="Create a new tool")
def post(self):
data = request.get_json()
required_fields = [
"name",
"displayName",
"description",
"actions",
"config",
"status",
]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
user = "local"
transformed_actions = []
for action in data["actions"]:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
try:
new_tool = {
"user": user,
"name": data["name"],
"displayName": data["displayName"],
"description": data["description"],
"actions": transformed_actions,
"config": data["config"],
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
new_id = str(resp.inserted_id)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"id": new_id}), 200)
@user_ns.route("/api/update_tool")
class UpdateTool(Resource):
@api.expect(
api.model(
"UpdateToolModel",
{
"id": fields.String(required=True, description="Tool ID"),
"name": fields.String(description="Name of the tool"),
"displayName": fields.String(description="Display name for the tool"),
"description": fields.String(description="Tool description"),
"config": fields.Raw(description="Configuration of the tool"),
"actions": fields.List(
fields.Raw, description="Actions the tool can perform"
),
"status": fields.Boolean(description="Status of the tool"),
},
)
)
@api.doc(description="Update a tool by ID")
def post(self):
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
update_data = {}
if "name" in data:
update_data["name"] = data["name"]
if "displayName" in data:
update_data["displayName"] = data["displayName"]
if "description" in data:
update_data["description"] = data["description"]
if "actions" in data:
update_data["actions"] = data["actions"]
if "config" in data:
update_data["config"] = data["config"]
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": "local"},
{"$set": update_data},
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/update_tool_config")
class UpdateToolConfig(Resource):
@api.expect(
api.model(
"UpdateToolConfigModel",
{
"id": fields.String(required=True, description="Tool ID"),
"config": fields.Raw(
required=True, description="Configuration of the tool"
),
},
)
)
@api.doc(description="Update the configuration of a tool")
def post(self):
data = request.get_json()
required_fields = ["id", "config"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"])},
{"$set": {"config": data["config"]}},
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/update_tool_actions")
class UpdateToolActions(Resource):
@api.expect(
api.model(
"UpdateToolActionsModel",
{
"id": fields.String(required=True, description="Tool ID"),
"actions": fields.List(
fields.Raw,
required=True,
description="Actions the tool can perform",
),
},
)
)
@api.doc(description="Update the actions of a tool")
def post(self):
data = request.get_json()
required_fields = ["id", "actions"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"])},
{"$set": {"actions": data["actions"]}},
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/update_tool_status")
class UpdateToolStatus(Resource):
@api.expect(
api.model(
"UpdateToolStatusModel",
{
"id": fields.String(required=True, description="Tool ID"),
"status": fields.Boolean(
required=True, description="Status of the tool"
),
},
)
)
@api.doc(description="Update the status of a tool")
def post(self):
data = request.get_json()
required_fields = ["id", "status"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
user_tools_collection.update_one(
{"_id": ObjectId(data["id"])},
{"$set": {"status": data["status"]}},
)
except Exception as err:
return make_response(jsonify({"success": False, "error": str(err)}), 400)
return make_response(jsonify({"success": True}), 200)
@user_ns.route("/api/delete_tool")
class DeleteTool(Resource):
@api.expect(
api.model(
"DeleteToolModel",
{"id": fields.String(required=True, description="Tool ID")},
)
)
@api.doc(description="Delete a tool by ID")
def post(self):
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
result = user_tools_collection.delete_one({"_id": ObjectId(data["id"])})
if result.deleted_count == 0:
return {"success": False, "message": "Tool not found"}, 404
except Exception as err:
return {"success": False, "error": str(err)}, 400
return {"success": True}, 200

View File

@@ -1,8 +1,10 @@
import redis
import time
import json
import logging
import time
from threading import Lock
import redis
from application.core.settings import settings
from application.utils import get_hash
@@ -11,41 +13,47 @@ logger = logging.getLogger(__name__)
_redis_instance = None
_instance_lock = Lock()
def get_redis_instance():
global _redis_instance
if _redis_instance is None:
with _instance_lock:
if _redis_instance is None:
try:
_redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2)
_redis_instance = redis.Redis.from_url(
settings.CACHE_REDIS_URL, socket_connect_timeout=2
)
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
_redis_instance = None
return _redis_instance
def gen_cache_key(*messages, model="docgpt"):
def gen_cache_key(messages, model="docgpt", tools=None):
if not all(isinstance(msg, dict) for msg in messages):
raise ValueError("All messages must be dictionaries.")
messages_str = json.dumps(list(messages), sort_keys=True)
combined = f"{model}_{messages_str}"
messages_str = json.dumps(messages)
tools_str = json.dumps(tools) if tools else ""
combined = f"{model}_{messages_str}_{tools_str}"
cache_key = get_hash(combined)
return cache_key
def gen_cache(func):
def wrapper(self, model, messages, *args, **kwargs):
def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
try:
cache_key = gen_cache_key(*messages)
cache_key = gen_cache_key(messages, model, tools)
redis_client = get_redis_instance()
if redis_client:
try:
cached_response = redis_client.get(cache_key)
if cached_response:
return cached_response.decode('utf-8')
return cached_response.decode("utf-8")
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
result = func(self, model, messages, *args, **kwargs)
if redis_client:
result = func(self, model, messages, stream, tools, *args, **kwargs)
if redis_client and isinstance(result, str):
try:
redis_client.set(cache_key, result, ex=1800)
except redis.ConnectionError as e:
@@ -55,11 +63,13 @@ def gen_cache(func):
except ValueError as e:
logger.error(e)
return "Error: No user message found in the conversation to generate a cache key."
return wrapper
def stream_cache(func):
def wrapper(self, model, messages, stream, *args, **kwargs):
cache_key = gen_cache_key(*messages)
cache_key = gen_cache_key(messages)
logger.info(f"Stream cache key: {cache_key}")
redis_client = get_redis_instance()
@@ -68,7 +78,7 @@ def stream_cache(func):
cached_response = redis_client.get(cache_key)
if cached_response:
logger.info(f"Cache hit for stream key: {cache_key}")
cached_response = json.loads(cached_response.decode('utf-8'))
cached_response = json.loads(cached_response.decode("utf-8"))
for chunk in cached_response:
yield chunk
time.sleep(0.03)

View File

@@ -17,7 +17,7 @@ class AnthropicLLM(BaseLLM):
self.AI_PROMPT = AI_PROMPT
def _raw_gen(
self, baseself, model, messages, stream=False, max_tokens=300, **kwargs
self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs
):
context = messages[0]["content"]
user_question = messages[-1]["content"]
@@ -34,7 +34,7 @@ class AnthropicLLM(BaseLLM):
return completion.completion
def _raw_gen_stream(
self, baseself, model, messages, stream=True, max_tokens=300, **kwargs
self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs
):
context = messages[0]["content"]
user_question = messages[-1]["content"]

View File

@@ -13,12 +13,12 @@ class BaseLLM(ABC):
return method(self, *args, **kwargs)
@abstractmethod
def _raw_gen(self, model, messages, stream, *args, **kwargs):
def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
pass
def gen(self, model, messages, stream=False, *args, **kwargs):
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
@abstractmethod
def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
@@ -27,3 +27,9 @@ class BaseLLM(ABC):
def gen_stream(self, model, messages, stream=True, *args, **kwargs):
decorators = [stream_cache, stream_token_usage]
return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
def supports_tools(self):
return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
def _supports_tools(self):
raise NotImplementedError("Subclass must implement _supports_tools method")

View File

@@ -1,45 +1,32 @@
from application.llm.base import BaseLLM
from openai import OpenAI
class GroqLLM(BaseLLM):
def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
from openai import OpenAI
super().__init__(*args, **kwargs)
self.client = OpenAI(api_key=api_key, base_url="https://api.groq.com/openai/v1")
self.api_key = api_key
self.user_api_key = user_api_key
def _raw_gen(
self,
baseself,
model,
messages,
stream=False,
**kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self,
baseself,
model,
messages,
stream=True,
**kwargs
self, baseself, model, messages, stream=True, tools=None, **kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
# import sys
# print(line.choices[0].delta.content, file=sys.stderr)
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content

View File

@@ -25,14 +25,20 @@ class OpenAILLM(BaseLLM):
model,
messages,
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
):
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
if tools:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, tools=tools, **kwargs
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
self,
@@ -40,6 +46,7 @@ class OpenAILLM(BaseLLM):
model,
messages,
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
**kwargs
):
@@ -53,6 +60,9 @@ class OpenAILLM(BaseLLM):
if line.choices[0].delta.content is not None:
yield line.choices[0].delta.content
def _supports_tools(self):
return True
class AzureOpenAILLM(OpenAILLM):

View File

@@ -76,7 +76,7 @@ class SagemakerAPILLM(BaseLLM):
self.endpoint = settings.SAGEMAKER_ENDPOINT
self.runtime = runtime
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
@@ -105,7 +105,7 @@ class SagemakerAPILLM(BaseLLM):
print(result[0]["generated_text"], file=sys.stderr)
return result[0]["generated_text"][len(prompt) :]
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kwargs):
context = messages[0]["content"]
user_question = messages[-1]["content"]
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"

View File

@@ -43,7 +43,7 @@ multidict==6.1.0
mypy-extensions==1.0.0
networkx==3.3
numpy==1.26.4
openai==1.55.3
openai==1.57.0
openapi-schema-validator==0.6.2
openapi-spec-validator==0.6.0
openapi3-parser==1.1.18

View File

@@ -1,7 +1,8 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.retriever.base import BaseRetriever
from application.tools.agent import Agent
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
@@ -19,7 +20,7 @@ class ClassicRAG(BaseRetriever):
user_api_key=None,
):
self.question = question
self.vectorstore = source['active_docs'] if 'active_docs' in source else None
self.vectorstore = source["active_docs"] if "active_docs" in source else None
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
@@ -81,11 +82,17 @@ class ClassicRAG(BaseRetriever):
{"role": "system", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(
settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
# llm = LLMCreator.create_llm(
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
# )
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
completion = agent.gen(messages_combine)
for line in completion:
yield {"answer": str(line)}
@@ -101,5 +108,5 @@ class ClassicRAG(BaseRetriever):
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
"user_api_key": self.user_api_key
"user_api_key": self.user_api_key,
}

149
application/tools/agent.py Normal file
View File

@@ -0,0 +1,149 @@
import json
from application.core.mongo_db import MongoDB
from application.llm.llm_creator import LLMCreator
from application.tools.tool_manager import ToolManager
class Agent:
def __init__(self, llm_name, gpt_model, api_key, user_api_key=None):
# Initialize the LLM with the provided parameters
self.llm = LLMCreator.create_llm(
llm_name, api_key=api_key, user_api_key=user_api_key
)
self.gpt_model = gpt_model
# Static tool configuration (to be replaced later)
self.tools = []
self.tool_config = {}
def _get_user_tools(self, user="local"):
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
return tools_by_id
def _prepare_tools(self, tools_dict):
self.tools = [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": {
**action["parameters"],
"properties": {
k: {
key: value
for key, value in v.items()
if key != "filled_by_llm" and key != "value"
}
for k, v in action["parameters"]["properties"].items()
if v.get("filled_by_llm", False)
},
"required": [
key
for key in action["parameters"]["required"]
if key in action["parameters"]["properties"]
and action["parameters"]["properties"][key].get(
"filled_by_llm", False
)
],
},
},
}
for tool_id, tool in tools_dict.items()
for action in tool["actions"]
if action["active"]
]
def _execute_tool_action(self, tools_dict, call):
call_id = call.id
call_args = json.loads(call.function.arguments)
tool_id = call.function.name.split("_")[-1]
action_name = call.function.name.rsplit("_", 1)[0]
tool_data = tools_dict[tool_id]
action_data = next(
action for action in tool_data["actions"] if action["name"] == action_name
)
for param, details in action_data["parameters"]["properties"].items():
if param not in call_args and "value" in details:
call_args[param] = details["value"]
tm = ToolManager(config={})
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
print(f"Executing tool: {action_name} with args: {call_args}")
return tool.execute_action(action_name, **call_args), call_id
def _simple_tool_agent(self, messages):
tools_dict = self._get_user_tools()
self._prepare_tools(tools_dict)
resp = self.llm.gen(model=self.gpt_model, messages=messages, tools=self.tools)
if isinstance(resp, str):
yield resp
return
if resp.message.content:
yield resp.message.content
return
while resp.finish_reason == "tool_calls":
message = json.loads(resp.model_dump_json())["message"]
keys_to_remove = {"audio", "function_call", "refusal"}
filtered_data = {
k: v for k, v in message.items() if k not in keys_to_remove
}
messages.append(filtered_data)
tool_calls = resp.message.tool_calls
for call in tool_calls:
try:
tool_response, call_id = self._execute_tool_action(tools_dict, call)
messages.append(
{
"role": "tool",
"content": str(tool_response),
"tool_call_id": call_id,
}
)
except Exception as e:
messages.append(
{
"role": "tool",
"content": f"Error executing tool: {str(e)}",
"tool_call_id": call.id,
}
)
# Generate a new response from the LLM after processing tools
resp = self.llm.gen(
model=self.gpt_model, messages=messages, tools=self.tools
)
# If no tool calls are needed, generate the final response
if isinstance(resp, str):
yield resp
elif resp.message.content:
yield resp.message.content
else:
completion = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=self.tools
)
for line in completion:
yield line
return
def gen(self, messages):
# Generate initial response from the LLM
if self.llm.supports_tools():
resp = self._simple_tool_agent(messages)
for line in resp:
yield line
else:
resp = self.llm.gen_stream(model=self.gpt_model, messages=messages)
for line in resp:
yield line

21
application/tools/base.py Normal file
View File

@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod
class Tool(ABC):
@abstractmethod
def execute_action(self, action_name: str, **kwargs):
pass
@abstractmethod
def get_actions_metadata(self):
"""
Returns a list of JSON objects describing the actions supported by the tool.
"""
pass
@abstractmethod
def get_config_requirements(self):
"""
Returns a dictionary describing the configuration requirements for the tool.
"""
pass

View File

@@ -0,0 +1,77 @@
import requests
from application.tools.base import Tool
class CryptoPriceTool(Tool):
"""
CryptoPrice
A tool for retrieving cryptocurrency prices using the CryptoCompare public API
"""
def __init__(self, config):
self.config = config
def execute_action(self, action_name, **kwargs):
actions = {"cryptoprice_get": self._get_price}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _get_price(self, symbol, currency):
"""
Fetches the current price of a given cryptocurrency symbol in the specified currency.
Example:
symbol = "BTC"
currency = "USD"
returns price in USD.
"""
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
# data will be like {"USD": <price>} if the call is successful
if currency.upper() in data:
return {
"status_code": response.status_code,
"price": data[currency.upper()],
"message": f"Price of {symbol.upper()} in {currency.upper()} retrieved successfully.",
}
else:
return {
"status_code": response.status_code,
"message": f"Couldn't find price for {symbol.upper()} in {currency.upper()}.",
}
else:
return {
"status_code": response.status_code,
"message": "Failed to retrieve price.",
}
def get_actions_metadata(self):
return [
{
"name": "cryptoprice_get",
"description": "Retrieve the price of a specified cryptocurrency in a given currency",
"parameters": {
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The cryptocurrency symbol (e.g. BTC)",
},
"currency": {
"type": "string",
"description": "The currency in which you want the price (e.g. USD)",
},
},
"required": ["symbol", "currency"],
"additionalProperties": False,
},
}
]
def get_config_requirements(self):
# No specific configuration needed for this tool as it just queries a public endpoint
return {}

View File

@@ -0,0 +1,86 @@
import requests
from application.tools.base import Tool
class TelegramTool(Tool):
"""
Telegram Bot
A flexible Telegram tool for performing various actions (e.g., sending messages, images).
Requires a bot token and chat ID for configuration
"""
def __init__(self, config):
self.config = config
self.token = config.get("token", "")
def execute_action(self, action_name, **kwargs):
actions = {
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _send_message(self, text, chat_id):
print(f"Sending message: {text}")
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def _send_image(self, image_url, chat_id):
print(f"Sending image: {image_url}")
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Image sent"}
def get_actions_metadata(self):
return [
{
"name": "telegram_send_message",
"description": "Send a notification to Telegram chat",
"parameters": {
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "Text to send in the notification",
},
"chat_id": {
"type": "string",
"description": "Chat ID to send the notification to",
},
},
"required": ["text"],
"additionalProperties": False,
},
},
{
"name": "telegram_send_image",
"description": "Send an image to the Telegram chat",
"parameters": {
"type": "object",
"properties": {
"image_url": {
"type": "string",
"description": "URL of the image to send",
},
"chat_id": {
"type": "string",
"description": "Chat ID to send the image to",
},
},
"required": ["image_url"],
"additionalProperties": False,
},
},
]
def get_config_requirements(self):
return {
"token": {"type": "string", "description": "Bot token for authentication"},
}

View File

@@ -0,0 +1,46 @@
import importlib
import inspect
import os
import pkgutil
from application.tools.base import Tool
class ToolManager:
def __init__(self, config):
self.config = config
self.tools = {}
self.load_tools()
def load_tools(self):
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
for finder, name, ispkg in pkgutil.iter_modules([tools_dir]):
if name == "base" or name.startswith("__"):
continue
module = importlib.import_module(
f"application.tools.implementations.{name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config
module = importlib.import_module(
f"application.tools.implementations.{tool_name}"
)
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
return obj(tool_config)
def execute_action(self, tool_name, action_name, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):
metadata = []
for tool in self.tools.values():
metadata.extend(tool.get_actions_metadata())
return metadata

View File

@@ -1,7 +1,7 @@
import sys
from datetime import datetime
from application.core.mongo_db import MongoDB
from application.utils import num_tokens_from_string
from application.utils import num_tokens_from_string, num_tokens_from_object_or_list
mongo = MongoDB.get_client()
db = mongo["docsgpt"]
@@ -21,11 +21,16 @@ def update_token_usage(user_api_key, token_usage):
def gen_token_usage(func):
def wrapper(self, model, messages, stream, **kwargs):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
result = func(self, model, messages, stream, **kwargs)
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
if message["content"]:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
result = func(self, model, messages, stream, tools, **kwargs)
# check if result is a string
if isinstance(result, str):
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
else:
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(result)
update_token_usage(self.user_api_key, self.token_usage)
return result
@@ -33,11 +38,11 @@ def gen_token_usage(func):
def stream_token_usage(func):
def wrapper(self, model, messages, stream, **kwargs):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(message["content"])
batch = []
result = func(self, model, messages, stream, **kwargs)
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r

View File

@@ -15,9 +15,21 @@ def get_encoding():
def num_tokens_from_string(string: str) -> int:
encoding = get_encoding()
num_tokens = len(encoding.encode(string))
return num_tokens
if isinstance(string, str):
num_tokens = len(encoding.encode(string))
return num_tokens
else:
return 0
def num_tokens_from_object_or_list(thing):
if isinstance(thing, list):
return sum([num_tokens_from_object_or_list(x) for x in thing])
elif isinstance(thing, dict):
return sum([num_tokens_from_object_or_list(x) for x in thing.values()])
elif isinstance(thing, str):
return num_tokens_from_string(thing)
else:
return 0
def count_tokens_docs(docs):
docs_content = ""

View File

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 122.88 122.88"><path d="M17.89 0h88.9c8.85 0 16.1 7.24 16.1 16.1v90.68c0 8.85-7.24 16.1-16.1 16.1H16.1c-8.85 0-16.1-7.24-16.1-16.1v-88.9C0 8.05 8.05 0 17.89 0zm57.04 66.96l16.46 4.96c-1.1 4.61-2.84 8.47-5.23 11.56-2.38 3.1-5.32 5.43-8.85 7-3.52 1.57-8.01 2.36-13.45 2.36-6.62 0-12.01-.96-16.21-2.87-4.19-1.92-7.79-5.3-10.83-10.13-3.04-4.82-4.57-11.02-4.57-18.54 0-10.04 2.67-17.76 8.02-23.17 5.36-5.39 12.93-8.09 22.71-8.09 7.65 0 13.68 1.54 18.06 4.64 4.37 3.1 7.64 7.85 9.76 14.27l-16.55 3.66c-.58-1.84-1.19-3.18-1.82-4.03-1.06-1.43-2.35-2.53-3.86-3.3-1.53-.78-3.22-1.16-5.11-1.16-4.27 0-7.54 1.71-9.8 5.12-1.71 2.53-2.57 6.52-2.57 11.94 0 6.73 1.02 11.33 3.07 13.83 2.05 2.49 4.92 3.73 8.63 3.73 3.59 0 6.31-1 8.15-3.03 1.83-1.99 3.16-4.92 3.99-8.75z" fill-rule="evenodd" clip-rule="evenodd"/></svg>

After

Width:  |  Height:  |  Size: 855 B

View File

@@ -0,0 +1,10 @@
<svg width="24" height="25" viewBox="0 0 24 25" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M12 0.5C8.81812 0.5 5.76375 1.76506 3.51562 4.01469C1.2652 6.26522 0.000643966 9.31734 0 12.5C0 15.6813 1.26562 18.7357 3.51562 20.9853C5.76375 23.2349 8.81812 24.5 12 24.5C15.1819 24.5 18.2362 23.2349 20.4844 20.9853C22.7344 18.7357 24 15.6813 24 12.5C24 9.31869 22.7344 6.26431 20.4844 4.01469C18.2362 1.76506 15.1819 0.5 12 0.5Z" fill="url(#paint0_linear_5586_9958)"/>
<path d="M5.43282 12.373C8.93157 10.849 11.2641 9.8443 12.4303 9.3588C15.7641 7.97261 16.4559 7.73186 16.9078 7.7237C17.0072 7.72211 17.2284 7.74667 17.3728 7.86339C17.4928 7.96183 17.5266 8.09495 17.5434 8.18842C17.5584 8.2818 17.5791 8.49461 17.5622 8.66074C17.3822 10.5582 16.6003 15.1629 16.2028 17.2882C16.0359 18.1874 15.7041 18.4889 15.3834 18.5184C14.6859 18.5825 14.1572 18.0579 13.4822 17.6155C12.4266 16.9231 11.8303 16.4922 10.8047 15.8167C9.6197 15.0359 10.3884 14.6067 11.0634 13.9055C11.2397 13.7219 14.3109 10.9291 14.3691 10.6758C14.3766 10.6441 14.3841 10.526 14.3128 10.4637C14.2434 10.4013 14.1403 10.4227 14.0653 10.4395C13.9584 10.4635 12.2728 11.5788 9.00282 13.7851C8.52469 14.114 8.09157 14.2743 7.70157 14.2659C7.27407 14.2567 6.44907 14.0236 5.83595 13.8245C5.08595 13.5802 4.48782 13.451 4.54032 13.036C4.56657 12.82 4.8647 12.599 5.43282 12.373Z" fill="white"/>
<defs>
<linearGradient id="paint0_linear_5586_9958" x1="1200" y1="0.5" x2="1200" y2="2400.5" gradientUnits="userSpaceOnUse">
<stop stop-color="#2AABEE"/>
<stop offset="1" stop-color="#229ED9"/>
</linearGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

@@ -18,6 +18,11 @@ const endpoints = {
FEEDBACK_ANALYTICS: '/api/get_feedback_analytics',
LOGS: `/api/get_user_logs`,
MANAGE_SYNC: '/api/manage_sync',
GET_AVAILABLE_TOOLS: '/api/available_tools',
GET_USER_TOOLS: '/api/get_tools',
CREATE_TOOL: '/api/create_tool',
UPDATE_TOOL_STATUS: '/api/update_tool_status',
UPDATE_TOOL: '/api/update_tool',
},
CONVERSATION: {
ANSWER: '/api/answer',

View File

@@ -35,6 +35,16 @@ const userService = {
apiClient.post(endpoints.USER.LOGS, data),
manageSync: (data: any): Promise<any> =>
apiClient.post(endpoints.USER.MANAGE_SYNC, data),
getAvailableTools: (): Promise<any> =>
apiClient.get(endpoints.USER.GET_AVAILABLE_TOOLS),
getUserTools: (): Promise<any> =>
apiClient.get(endpoints.USER.GET_USER_TOOLS),
createTool: (data: any): Promise<any> =>
apiClient.post(endpoints.USER.CREATE_TOOL, data),
updateToolStatus: (data: any): Promise<any> =>
apiClient.post(endpoints.USER.UPDATE_TOOL_STATUS, data),
updateTool: (data: any): Promise<any> =>
apiClient.post(endpoints.USER.UPDATE_TOOL, data),
};
export default userService;

View File

@@ -0,0 +1,3 @@
<svg width="16" height="17" viewBox="0 0 16 17" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M8.00182 11.2502C7.23838 11.2502 6.50621 10.9552 5.96637 10.4301C5.42654 9.90499 5.12327 9.1928 5.12327 8.4502C5.12327 7.70759 5.42654 6.9954 5.96637 6.4703C6.50621 5.9452 7.23838 5.6502 8.00182 5.6502C8.76525 5.6502 9.49743 5.9452 10.0373 6.4703C10.5771 6.9954 10.8804 7.70759 10.8804 8.4502C10.8804 9.1928 10.5771 9.90499 10.0373 10.4301C9.49743 10.9552 8.76525 11.2502 8.00182 11.2502ZM14.1126 9.2262C14.1455 8.9702 14.1701 8.7142 14.1701 8.4502C14.1701 8.1862 14.1455 7.9222 14.1126 7.6502L15.8479 6.3462C16.0042 6.2262 16.0453 6.0102 15.9466 5.8342L14.3017 3.0662C14.203 2.8902 13.981 2.8182 13.8 2.8902L11.7522 3.6902C11.3245 3.3782 10.8804 3.1062 10.3622 2.9062L10.0579 0.786197C10.0412 0.69197 9.99076 0.606538 9.91549 0.545038C9.84022 0.483538 9.745 0.449939 9.6467 0.450197H6.35693C6.15132 0.450197 5.97861 0.594197 5.94571 0.786197L5.64141 2.9062C5.12327 3.1062 4.67915 3.3782 4.25148 3.6902L2.2036 2.8902C2.02266 2.8182 1.8006 2.8902 1.70191 3.0662L0.0570212 5.8342C-0.0498963 6.0102 -0.00054964 6.2262 0.155714 6.3462L1.89107 7.6502C1.85817 7.9222 1.8335 8.1862 1.8335 8.4502C1.8335 8.7142 1.85817 8.9702 1.89107 9.2262L0.155714 10.5542C-0.00054964 10.6742 -0.0498963 10.8902 0.0570212 11.0662L1.70191 13.8342C1.8006 14.0102 2.02266 14.0742 2.2036 14.0102L4.25148 13.2022C4.67915 13.5222 5.12327 13.7942 5.64141 13.9942L5.94571 16.1142C5.97861 16.3062 6.15132 16.4502 6.35693 16.4502H9.6467C9.85231 16.4502 10.025 16.3062 10.0579 16.1142L10.3622 13.9942C10.8804 13.7862 11.3245 13.5222 11.7522 13.2022L13.8 14.0102C13.981 14.0742 14.203 14.0102 14.3017 13.8342L15.9466 11.0662C16.0453 10.8902 16.0042 10.6742 15.8479 10.5542L14.1126 9.2262Z" fill="#747474"/>
</svg>

After

Width:  |  Height:  |  Size: 1.7 KiB

View File

@@ -13,6 +13,7 @@ const useTabs = () => {
t('settings.apiKeys.label'),
t('settings.analytics.label'),
t('settings.logs.label'),
t('settings.tools.label'),
];
return tabs;
};

View File

@@ -72,6 +72,15 @@ body.dark {
.table-default td:last-child {
@apply border-r-0; /* Ensure no right border on the last column */
}
.table-default th,
.table-default td {
min-width: 150px;
max-width: 320px;
overflow: auto;
scrollbar-width: thin;
scrollbar-color: grey transparent;
}
}
/*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */

View File

@@ -73,6 +73,9 @@
},
"logs": {
"label": "Logs"
},
"tools": {
"label": "Tools"
}
},
"modals": {

View File

@@ -0,0 +1,136 @@
import React from 'react';
import userService from '../api/services/userService';
import Exit from '../assets/exit.svg';
import { ActiveState } from '../models/misc';
import { AvailableTool } from './types';
import ConfigToolModal from './ConfigToolModal';
export default function AddToolModal({
message,
modalState,
setModalState,
getUserTools,
}: {
message: string;
modalState: ActiveState;
setModalState: (state: ActiveState) => void;
getUserTools: () => void;
}) {
const [availableTools, setAvailableTools] = React.useState<AvailableTool[]>(
[],
);
const [selectedTool, setSelectedTool] = React.useState<AvailableTool | null>(
null,
);
const [configModalState, setConfigModalState] =
React.useState<ActiveState>('INACTIVE');
const getAvailableTools = () => {
userService
.getAvailableTools()
.then((res) => {
return res.json();
})
.then((data) => {
setAvailableTools(data.data);
});
};
const handleAddTool = (tool: AvailableTool) => {
if (Object.keys(tool.configRequirements).length === 0) {
userService
.createTool({
name: tool.name,
displayName: tool.displayName,
description: tool.description,
config: {},
actions: tool.actions,
status: true,
})
.then((res) => {
if (res.status === 200) {
getUserTools();
setModalState('INACTIVE');
}
});
} else {
setModalState('INACTIVE');
setConfigModalState('ACTIVE');
}
};
React.useEffect(() => {
if (modalState === 'ACTIVE') getAvailableTools();
}, [modalState]);
return (
<>
<div
className={`${
modalState === 'ACTIVE' ? 'visible' : 'hidden'
} fixed top-0 left-0 z-30 h-screen w-screen bg-gray-alpha flex items-center justify-center`}
>
<article className="flex h-[85vh] w-[90vw] md:w-[75vw] flex-col gap-4 rounded-2xl bg-[#FBFBFB] shadow-lg dark:bg-[#26272E]">
<div className="relative">
<button
className="absolute top-3 right-4 m-2 w-3"
onClick={() => {
setModalState('INACTIVE');
}}
>
<img className="filter dark:invert" src={Exit} />
</button>
<div className="p-6">
<h2 className="font-semibold text-xl text-jet dark:text-bright-gray px-3">
Select a tool to set up
</h2>
<div className="mt-5 grid grid-cols-3 gap-4 h-[73vh] overflow-auto px-3 py-px">
{availableTools.map((tool, index) => (
<div
role="button"
tabIndex={0}
key={index}
className="h-52 w-full p-6 border rounded-2xl border-silver dark:border-[#4D4E58] flex flex-col justify-between dark:bg-[#32333B] cursor-pointer"
onClick={() => {
setSelectedTool(tool);
handleAddTool(tool);
}}
onKeyDown={(e) => {
if (e.key === 'Enter' || e.key === ' ') {
setSelectedTool(tool);
handleAddTool(tool);
}
}}
>
<div className="w-full">
<div className="px-1 w-full flex items-center justify-between">
<img
src={`/toolIcons/tool_${tool.name}.svg`}
className="h-8 w-8"
/>
</div>
<div className="mt-[9px]">
<p className="px-1 text-sm font-semibold text-eerie-black dark:text-white leading-relaxed capitalize">
{tool.displayName}
</p>
<p className="mt-1 px-1 h-24 overflow-auto text-sm text-gray-600 dark:text-[#8a8a8c] leading-relaxed">
{tool.description}
</p>
</div>
</div>
</div>
))}
</div>
</div>
</div>
</article>
</div>
<ConfigToolModal
modalState={configModalState}
setModalState={setConfigModalState}
tool={selectedTool}
getUserTools={getUserTools}
/>
</>
);
}

View File

@@ -0,0 +1,95 @@
import React from 'react';
import Exit from '../assets/exit.svg';
import Input from '../components/Input';
import { ActiveState } from '../models/misc';
import { AvailableTool } from './types';
import userService from '../api/services/userService';
export default function ConfigToolModal({
modalState,
setModalState,
tool,
getUserTools,
}: {
modalState: ActiveState;
setModalState: (state: ActiveState) => void;
tool: AvailableTool | null;
getUserTools: () => void;
}) {
const [authKey, setAuthKey] = React.useState<string>('');
const handleAddTool = (tool: AvailableTool) => {
userService
.createTool({
name: tool.name,
displayName: tool.displayName,
description: tool.description,
config: { token: authKey },
actions: tool.actions,
status: true,
})
.then(() => {
setModalState('INACTIVE');
getUserTools();
});
};
return (
<div
className={`${
modalState === 'ACTIVE' ? 'visible' : 'hidden'
} fixed top-0 left-0 z-30 h-screen w-screen bg-gray-alpha flex items-center justify-center`}
>
<article className="flex w-11/12 sm:w-[512px] flex-col gap-4 rounded-2xl bg-white shadow-lg dark:bg-[#26272E]">
<div className="relative">
<button
className="absolute top-3 right-4 m-2 w-3"
onClick={() => {
setModalState('INACTIVE');
}}
>
<img className="filter dark:invert" src={Exit} />
</button>
<div className="p-6">
<h2 className="font-semibold text-xl text-jet dark:text-bright-gray px-3">
Tool Config
</h2>
<p className="mt-5 text-sm text-gray-600 dark:text-gray-400 px-3">
Type: <span className="font-semibold">{tool?.name} </span>
</p>
<div className="mt-6 relative px-3">
<span className="absolute left-5 -top-2 bg-white px-2 text-xs text-gray-4000 dark:bg-[#26272E] dark:text-silver">
API Key / Oauth
</span>
<Input
type="text"
value={authKey}
onChange={(e) => setAuthKey(e.target.value)}
borderVariant="thin"
placeholder="Enter API Key / Oauth"
></Input>
</div>
<div className="mt-8 flex flex-row-reverse gap-1 px-3">
<button
onClick={() => {
handleAddTool(tool as AvailableTool);
}}
className="rounded-3xl bg-purple-30 px-5 py-2 text-sm text-white transition-all hover:bg-[#6F3FD1]"
>
Add Tool
</button>
<button
onClick={() => {
setModalState('INACTIVE');
}}
className="cursor-pointer rounded-3xl px-5 py-2 text-sm font-medium hover:bg-gray-100 dark:bg-transparent dark:text-light-gray dark:hover:bg-[#767183]/50"
>
Close
</button>
</div>
</div>
</div>
</article>
</div>
);
}

View File

@@ -1,3 +1,15 @@
export type AvailableTool = {
name: string;
displayName: string;
description: string;
configRequirements: object;
actions: {
name: string;
description: string;
parameters: object;
}[];
};
export type WrapperModalProps = {
children?: React.ReactNode;
isPerformingTask?: boolean;

View File

@@ -181,7 +181,7 @@ const Documents: React.FC<DocumentsProps> = ({
{loading ? (
<SkeletonLoader count={1} />
) : (
<div className="flex flex-col">
<div className="flex flex-col">
<div className="flex-grow">
<div className="dark:border-silver/40 border-silver rounded-md border overflow-auto">
<table className="min-w-full divide-y divide-silver dark:divide-silver/40 text-xs sm:text-sm ">

View File

@@ -0,0 +1,293 @@
import React from 'react';
import userService from '../api/services/userService';
import ArrowLeft from '../assets/arrow-left.svg';
import Input from '../components/Input';
import { UserTool } from './types';
export default function ToolConfig({
tool,
setTool,
handleGoBack,
}: {
tool: UserTool;
setTool: (tool: UserTool) => void;
handleGoBack: () => void;
}) {
const [authKey, setAuthKey] = React.useState<string>(
tool.config?.token || '',
);
const handleCheckboxChange = (actionIndex: number, property: string) => {
setTool({
...tool,
actions: tool.actions.map((action, index) => {
if (index === actionIndex) {
return {
...action,
parameters: {
...action.parameters,
properties: {
...action.parameters.properties,
[property]: {
...action.parameters.properties[property],
filled_by_llm:
!action.parameters.properties[property].filled_by_llm,
},
},
},
};
}
return action;
}),
});
};
const handleSaveChanges = () => {
userService
.updateTool({
id: tool.id,
name: tool.name,
displayName: tool.displayName,
description: tool.description,
config: { token: authKey },
actions: tool.actions,
status: tool.status,
})
.then(() => {
handleGoBack();
});
};
return (
<div className="mt-8 flex flex-col gap-4">
<div className="mb-4 flex items-center gap-3 text-eerie-black dark:text-bright-gray text-sm">
<button
className="text-sm text-gray-400 dark:text-gray-500 border dark:border-0 dark:bg-[#28292D] dark:hover:bg-[#2E2F34] p-3 rounded-full"
onClick={handleGoBack}
>
<img src={ArrowLeft} alt="left-arrow" className="w-3 h-3" />
</button>
<p className="mt-px">Back to all tools</p>
</div>
<div>
<p className="text-sm font-semibold text-eerie-black dark:text-bright-gray">
Type
</p>
<p className="mt-1 text-base font-normal text-eerie-black dark:text-bright-gray font-sans">
{tool.name}
</p>
</div>
<div className="mt-1">
{Object.keys(tool?.config).length !== 0 && (
<p className="text-sm font-semibold text-eerie-black dark:text-bright-gray">
Authentication
</p>
)}
<div className="mt-4 flex items-center gap-2">
{Object.keys(tool?.config).length !== 0 && (
<div className="relative w-96">
<span className="absolute left-5 -top-2 bg-white px-2 text-xs text-gray-4000 dark:bg-[#26272E] dark:text-silver">
API Key / Oauth
</span>
<Input
type="text"
value={authKey}
onChange={(e) => setAuthKey(e.target.value)}
borderVariant="thin"
placeholder="Enter API Key / Oauth"
></Input>
</div>
)}
<button
className="rounded-full h-10 w-36 bg-purple-30 text-white hover:bg-[#6F3FD1] text-nowrap text-sm"
onClick={handleSaveChanges}
>
Save changes
</button>
</div>
</div>
<div className="flex flex-col gap-4">
<div className="mx-1 my-2 h-[0.8px] w-full rounded-full bg-[#C4C4C4]/40 lg:w-[95%] "></div>
<p className="text-base font-semibold text-eerie-black dark:text-bright-gray">
Actions
</p>
<div className="flex flex-col gap-10">
{tool.actions.map((action, actionIndex) => {
return (
<div
key={actionIndex}
className="w-full border border-silver dark:border-silver/40 rounded-xl"
>
<div className="h-10 bg-[#F9F9F9] dark:bg-[#28292D] rounded-t-xl border-b border-silver dark:border-silver/40 flex items-center justify-between px-5">
<p className="font-semibold text-eerie-black dark:text-bright-gray">
{action.name}
</p>
<label
htmlFor={`actionToggle-${actionIndex}`}
className="relative inline-block h-6 w-10 cursor-pointer rounded-full bg-gray-300 dark:bg-[#D2D5DA33]/20 transition [-webkit-tap-highlight-color:_transparent] has-[:checked]:bg-[#0C9D35CC] has-[:checked]:dark:bg-[#0C9D35CC]"
>
<input
type="checkbox"
id={`actionToggle-${actionIndex}`}
className="peer sr-only"
checked={action.active}
onChange={() => {
setTool({
...tool,
actions: tool.actions.map((act, index) => {
if (index === actionIndex) {
return { ...act, active: !act.active };
}
return act;
}),
});
}}
/>
<span className="absolute inset-y-0 start-0 m-[3px] size-[18px] rounded-full bg-white transition-all peer-checked:start-4"></span>
</label>
</div>
<div className="mt-5 relative px-5 w-96">
<Input
type="text"
placeholder="Enter description"
value={action.description}
onChange={(e) => {
setTool({
...tool,
actions: tool.actions.map((act, index) => {
if (index === actionIndex) {
return {
...act,
description: e.target.value,
};
}
return act;
}),
});
}}
borderVariant="thin"
></Input>
</div>
<div className="px-5 py-4">
<table className="table-default">
<thead>
<tr>
<th>Field Name</th>
<th>Field Type</th>
<th>Filled by LLM</th>
<th>FIeld description</th>
<th>Value</th>
</tr>
</thead>
<tbody>
{Object.entries(action.parameters?.properties).map(
(param, index) => {
const uniqueKey = `${actionIndex}-${param[0]}`;
return (
<tr key={index} className="text-nowrap font-normal">
<td>{param[0]}</td>
<td>{param[1].type}</td>
<td>
<label
htmlFor={uniqueKey}
className="ml-[10px] flex cursor-pointer items-start gap-4"
>
<div className="flex items-center">
&#8203;
<input
checked={param[1].filled_by_llm}
id={uniqueKey}
type="checkbox"
className="size-4 rounded border-gray-300 bg-transparent"
onChange={() =>
handleCheckboxChange(
actionIndex,
param[0],
)
}
/>
</div>
</label>
</td>
<td className="w-10">
<input
key={uniqueKey}
value={param[1].description}
className="bg-transparent border border-silver dark:border-silver/40 outline-none px-2 py-1 rounded-lg text-sm"
onChange={(e) => {
setTool({
...tool,
actions: tool.actions.map(
(act, index) => {
if (index === actionIndex) {
return {
...act,
parameters: {
...act.parameters,
properties: {
...act.parameters.properties,
[param[0]]: {
...act.parameters
.properties[param[0]],
description: e.target.value,
},
},
},
};
}
return act;
},
),
});
}}
></input>
</td>
<td>
<input
value={param[1].value}
key={uniqueKey}
disabled={param[1].filled_by_llm}
className={`bg-transparent border border-silver dark:border-silver/40 outline-none px-2 py-1 rounded-lg text-sm ${param[1].filled_by_llm ? 'opacity-50' : ''}`}
onChange={(e) => {
setTool({
...tool,
actions: tool.actions.map(
(act, index) => {
if (index === actionIndex) {
return {
...act,
parameters: {
...act.parameters,
properties: {
...act.parameters.properties,
[param[0]]: {
...act.parameters
.properties[param[0]],
value: e.target.value,
},
},
},
};
}
return act;
},
),
});
}}
></input>
</td>
</tr>
);
},
)}
</tbody>
</table>
</div>
</div>
);
})}
</div>
</div>
</div>
);
}

View File

@@ -0,0 +1,157 @@
import React from 'react';
import userService from '../api/services/userService';
import CogwheelIcon from '../assets/cogwheel.svg';
import Input from '../components/Input';
import AddToolModal from '../modals/AddToolModal';
import { ActiveState } from '../models/misc';
import { UserTool } from './types';
import ToolConfig from './ToolConfig';
export default function Tools() {
const [searchTerm, setSearchTerm] = React.useState('');
const [addToolModalState, setAddToolModalState] =
React.useState<ActiveState>('INACTIVE');
const [userTools, setUserTools] = React.useState<UserTool[]>([]);
const [selectedTool, setSelectedTool] = React.useState<UserTool | null>(null);
const getUserTools = () => {
userService
.getUserTools()
.then((res) => {
return res.json();
})
.then((data) => {
setUserTools(data.tools);
});
};
const updateToolStatus = (toolId: string, newStatus: boolean) => {
userService
.updateToolStatus({ id: toolId, status: newStatus })
.then(() => {
setUserTools((prevTools) =>
prevTools.map((tool) =>
tool.id === toolId ? { ...tool, status: newStatus } : tool,
),
);
})
.catch((error) => {
console.error('Failed to update tool status:', error);
});
};
const handleSettingsClick = (tool: UserTool) => {
setSelectedTool(tool);
};
const handleGoBack = () => {
setSelectedTool(null);
getUserTools();
};
React.useEffect(() => {
getUserTools();
}, []);
return (
<div>
{selectedTool ? (
<ToolConfig
tool={selectedTool}
setTool={setSelectedTool}
handleGoBack={handleGoBack}
/>
) : (
<div className="mt-8">
<div className="flex flex-col relative">
<div className="my-3 flex justify-between items-center gap-1">
<div className="p-1">
<Input
maxLength={256}
placeholder="Search..."
name="Document-search-input"
type="text"
id="document-search-input"
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
/>
</div>
<button
className="rounded-full w-40 bg-purple-30 px-4 py-3 text-white hover:bg-[#6F3FD1] text-nowrap"
onClick={() => {
setAddToolModalState('ACTIVE');
}}
>
Add Tool
</button>
</div>
<div className="grid grid-cols-2 lg:grid-cols-3 gap-6">
{userTools
.filter((tool) =>
tool.displayName
.toLowerCase()
.includes(searchTerm.toLowerCase()),
)
.map((tool, index) => (
<div
key={index}
className="relative h-56 w-full p-6 border rounded-2xl border-silver dark:border-silver/40 flex flex-col justify-between"
>
<div className="w-full">
<div className="w-full flex items-center justify-between">
<img
src={`/toolIcons/tool_${tool.name}.svg`}
className="h-8 w-8"
/>
<button
className="absolute top-3 right-3 cursor-pointer"
onClick={() => handleSettingsClick(tool)}
>
<img
src={CogwheelIcon}
alt="settings"
className="h-[19px] w-[19px]"
/>
</button>
</div>
<div className="mt-[9px]">
<p className="text-sm font-semibold text-eerie-black dark:text-[#EEEEEE] leading-relaxed">
{tool.displayName}
</p>
<p className="mt-1 h-16 overflow-auto text-[13px] text-gray-600 dark:text-gray-400 leading-relaxed pr-1">
{tool.description}
</p>
</div>
</div>
<div className="absolute bottom-3 right-3">
<label
htmlFor={`toolToggle-${index}`}
className="relative inline-block h-6 w-10 cursor-pointer rounded-full bg-gray-300 dark:bg-[#D2D5DA33]/20 transition [-webkit-tap-highlight-color:_transparent] has-[:checked]:bg-[#0C9D35CC] has-[:checked]:dark:bg-[#0C9D35CC]"
>
<input
type="checkbox"
id={`toolToggle-${index}`}
className="peer sr-only"
checked={tool.status}
onChange={() =>
updateToolStatus(tool.id, !tool.status)
}
/>
<span className="absolute inset-y-0 start-0 m-[3px] size-[18px] rounded-full bg-white transition-all peer-checked:start-4"></span>
</label>
</div>
</div>
))}
</div>
</div>
<AddToolModal
message="Select a tool to set up"
modalState={addToolModalState}
setModalState={setAddToolModalState}
getUserTools={getUserTools}
/>
</div>
)}
</div>
);
}

View File

@@ -7,8 +7,8 @@ import SettingsBar from '../components/SettingsBar';
import i18n from '../locale/i18n';
import { Doc } from '../models/misc';
import {
selectSourceDocs,
selectPaginatedDocuments,
selectSourceDocs,
setPaginatedDocuments,
setSourceDocs,
} from '../preferences/preferenceSlice';
@@ -17,6 +17,7 @@ import APIKeys from './APIKeys';
import Documents from './Documents';
import General from './General';
import Logs from './Logs';
import Tools from './Tools';
import Widgets from './Widgets';
export default function Settings() {
@@ -100,6 +101,8 @@ export default function Settings() {
return <Analytics />;
case t('settings.logs.label'):
return <Logs />;
case t('settings.tools.label'):
return <Tools />;
default:
return null;
}

View File

@@ -18,3 +18,32 @@ export type LogData = {
retriever_params: Record<string, any>;
timestamp: string;
};
export type UserTool = {
id: string;
name: string;
displayName: string;
description: string;
status: boolean;
config: {
[key: string]: string;
};
actions: {
name: string;
description: string;
parameters: {
properties: {
[key: string]: {
type: string;
description: string;
filled_by_llm: boolean;
value: string;
};
};
additionalProperties: boolean;
required: string[];
type: string;
};
active: boolean;
}[];
};

View File

@@ -46,6 +46,7 @@ class TestAnthropicLLM(unittest.TestCase):
{"content": "question"}
]
mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")]
mock_tools = Mock()
with patch("application.cache.get_redis_instance") as mock_make_redis:
mock_redis_instance = mock_make_redis.return_value
@@ -53,7 +54,7 @@ class TestAnthropicLLM(unittest.TestCase):
mock_redis_instance.set = Mock()
with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
responses = list(self.llm.gen_stream("test_model", messages))
responses = list(self.llm.gen_stream("test_model", messages, tools=mock_tools))
self.assertListEqual(responses, ["response_1", "response_2"])
prompt_expected = "### Context \n context \n ### Question \n question"

View File

@@ -76,7 +76,7 @@ class TestSagemakerAPILLM(unittest.TestCase):
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
return_value=self.response) as mock_invoke_endpoint:
output = list(self.sagemaker.gen_stream(None, self.messages))
output = list(self.sagemaker.gen_stream(None, self.messages, tools=None))
mock_invoke_endpoint.assert_called_once_with(
EndpointName=self.sagemaker.endpoint,
ContentType='application/json',

View File

@@ -12,18 +12,21 @@ def test_make_gen_cache_key():
{'role': 'system', 'content': 'test_system_message'},
]
model = "test_docgpt"
tools = None
# Manually calculate the expected hash
expected_combined = f"{model}_{json.dumps(messages, sort_keys=True)}"
messages_str = json.dumps(messages)
tools_str = json.dumps(tools) if tools else ""
expected_combined = f"{model}_{messages_str}_{tools_str}"
expected_hash = get_hash(expected_combined)
cache_key = gen_cache_key(*messages, model=model)
cache_key = gen_cache_key(messages, model=model, tools=None)
assert cache_key == expected_hash
def test_gen_cache_key_invalid_message_format():
# Test when messages is not a list
with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context:
gen_cache_key("This is not a list", model="docgpt")
gen_cache_key("This is not a list", model="docgpt", tools=None)
assert str(context.exception) == "All messages must be dictionaries."
# Test for gen_cache decorator
@@ -35,14 +38,14 @@ def test_gen_cache_hit(mock_make_redis):
mock_redis_instance.get.return_value = b"cached_result" # Simulate a cache hit
@gen_cache
def mock_function(self, model, messages):
def mock_function(self, model, messages, stream, tools):
return "new_result"
messages = [{'role': 'user', 'content': 'test_user_message'}]
model = "test_docgpt"
# Act
result = mock_function(None, model, messages)
result = mock_function(None, model, messages, stream=False, tools=None)
# Assert
assert result == "cached_result" # Should return cached result
@@ -58,7 +61,7 @@ def test_gen_cache_miss(mock_make_redis):
mock_redis_instance.get.return_value = None # Simulate a cache miss
@gen_cache
def mock_function(self, model, messages):
def mock_function(self, model, messages, steam, tools):
return "new_result"
messages = [
@@ -67,7 +70,7 @@ def test_gen_cache_miss(mock_make_redis):
]
model = "test_docgpt"
# Act
result = mock_function(None, model, messages)
result = mock_function(None, model, messages, stream=False, tools=None)
# Assert
assert result == "new_result"
@@ -83,14 +86,14 @@ def test_stream_cache_hit(mock_make_redis):
mock_redis_instance.get.return_value = cached_chunk
@stream_cache
def mock_function(self, model, messages, stream):
def mock_function(self, model, messages, stream, tools):
yield "new_chunk"
messages = [{'role': 'user', 'content': 'test_user_message'}]
model = "test_docgpt"
# Act
result = list(mock_function(None, model, messages, stream=True))
result = list(mock_function(None, model, messages, stream=True, tools=None))
# Assert
assert result == ["chunk1", "chunk2"] # Should return cached chunks
@@ -106,7 +109,7 @@ def test_stream_cache_miss(mock_make_redis):
mock_redis_instance.get.return_value = None # Simulate a cache miss
@stream_cache
def mock_function(self, model, messages, stream):
def mock_function(self, model, messages, stream, tools):
yield "new_chunk"
messages = [
@@ -117,7 +120,7 @@ def test_stream_cache_miss(mock_make_redis):
model = "test_docgpt"
# Act
result = list(mock_function(None, model, messages, stream=True))
result = list(mock_function(None, model, messages, stream=True, tools=None))
# Assert
assert result == ["new_chunk"]