mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
Merge branch 'main' of https://github.com/arc53/DocsGPT
This commit is contained in:
@@ -149,7 +149,7 @@ class BaseAgent(ABC):
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": getattr(call, 'name', 'unknown'),
|
||||
"action_name": getattr(call, "name", "unknown"),
|
||||
"arguments": call_args or {},
|
||||
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
||||
}
|
||||
@@ -225,6 +225,7 @@ class BaseAgent(ABC):
|
||||
if tool_data["name"] == "api_tool"
|
||||
else tool_data["config"]
|
||||
),
|
||||
user_id=self.user, # Pass user ID for MCP tools credential decryption
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
print(
|
||||
|
||||
405
application/agents/tools/mcp_tool.py
Normal file
405
application/agents/tools/mcp_tool.py
Normal file
@@ -0,0 +1,405 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
|
||||
_mcp_session_cache = {}
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""
|
||||
MCP Tool
|
||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize the MCP Tool with configuration.
|
||||
|
||||
Args:
|
||||
config: Dictionary containing MCP server configuration:
|
||||
- server_url: URL of the remote MCP server
|
||||
- auth_type: Type of authentication (api_key, bearer, basic, none)
|
||||
- encrypted_credentials: Encrypted credentials (if available)
|
||||
- timeout: Request timeout in seconds (default: 30)
|
||||
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
|
||||
"""
|
||||
self.config = config
|
||||
self.server_url = config.get("server_url", "")
|
||||
self.auth_type = config.get("auth_type", "none")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
|
||||
self.auth_credentials = {}
|
||||
if config.get("encrypted_credentials") and user_id:
|
||||
self.auth_credentials = decrypt_credentials(
|
||||
config["encrypted_credentials"], user_id
|
||||
)
|
||||
else:
|
||||
self.auth_credentials = config.get("auth_credentials", {})
|
||||
self.available_tools = []
|
||||
self._session = requests.Session()
|
||||
self._mcp_session_id = None
|
||||
self._setup_authentication()
|
||||
self._cache_key = self._generate_cache_key()
|
||||
|
||||
def _setup_authentication(self):
|
||||
"""Setup authentication for the MCP server connection."""
|
||||
if self.auth_type == "api_key":
|
||||
api_key = self.auth_credentials.get("api_key", "")
|
||||
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
|
||||
if api_key:
|
||||
self._session.headers.update({header_name: api_key})
|
||||
elif self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get("bearer_token", "")
|
||||
if token:
|
||||
self._session.headers.update({"Authorization": f"Bearer {token}"})
|
||||
elif self.auth_type == "basic":
|
||||
username = self.auth_credentials.get("username", "")
|
||||
password = self.auth_credentials.get("password", "")
|
||||
if username and password:
|
||||
self._session.auth = (username, password)
|
||||
|
||||
def _generate_cache_key(self) -> str:
|
||||
"""Generate a unique cache key for this MCP server configuration."""
|
||||
auth_key = ""
|
||||
if self.auth_type == "bearer":
|
||||
token = self.auth_credentials.get("bearer_token", "")
|
||||
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
|
||||
elif self.auth_type == "api_key":
|
||||
api_key = self.auth_credentials.get("api_key", "")
|
||||
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
|
||||
elif self.auth_type == "basic":
|
||||
username = self.auth_credentials.get("username", "")
|
||||
auth_key = f"basic:{username}"
|
||||
else:
|
||||
auth_key = "none"
|
||||
return f"{self.server_url}#{auth_key}"
|
||||
|
||||
def _get_cached_session(self) -> Optional[str]:
|
||||
"""Get cached session ID if available and not expired."""
|
||||
global _mcp_session_cache
|
||||
|
||||
if self._cache_key in _mcp_session_cache:
|
||||
session_data = _mcp_session_cache[self._cache_key]
|
||||
if time.time() - session_data["created_at"] < 1800:
|
||||
return session_data["session_id"]
|
||||
else:
|
||||
del _mcp_session_cache[self._cache_key]
|
||||
return None
|
||||
|
||||
def _cache_session(self, session_id: str):
|
||||
"""Cache the session ID for reuse."""
|
||||
global _mcp_session_cache
|
||||
_mcp_session_cache[self._cache_key] = {
|
||||
"session_id": session_id,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
|
||||
def _initialize_mcp_connection(self) -> Dict:
|
||||
"""
|
||||
Initialize MCP connection with the server, using cached session if available.
|
||||
|
||||
Returns:
|
||||
Server capabilities and information
|
||||
"""
|
||||
cached_session = self._get_cached_session()
|
||||
if cached_session:
|
||||
self._mcp_session_id = cached_session
|
||||
return {"cached": True}
|
||||
try:
|
||||
init_params = {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"roots": {"listChanged": True}, "sampling": {}},
|
||||
"clientInfo": {"name": "DocsGPT", "version": "1.0.0"},
|
||||
}
|
||||
response = self._make_mcp_request("initialize", init_params)
|
||||
self._make_mcp_request("notifications/initialized")
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
return {"error": str(e), "fallback": True}
|
||||
|
||||
def _ensure_valid_session(self):
|
||||
"""Ensure we have a valid MCP session, reinitializing if needed."""
|
||||
if not self._mcp_session_id:
|
||||
self._initialize_mcp_connection()
|
||||
|
||||
def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict:
|
||||
"""
|
||||
Make an MCP protocol request to the server with automatic session recovery.
|
||||
|
||||
Args:
|
||||
method: MCP method name (e.g., "tools/list", "tools/call")
|
||||
params: Parameters for the MCP method
|
||||
|
||||
Returns:
|
||||
Response data as dictionary
|
||||
|
||||
Raises:
|
||||
Exception: If request fails after retry
|
||||
"""
|
||||
mcp_message = {"jsonrpc": "2.0", "method": method}
|
||||
|
||||
if not method.startswith("notifications/"):
|
||||
mcp_message["id"] = 1
|
||||
if params:
|
||||
mcp_message["params"] = params
|
||||
return self._execute_mcp_request(mcp_message, method)
|
||||
|
||||
def _execute_mcp_request(
|
||||
self, mcp_message: Dict, method: str, is_retry: bool = False
|
||||
) -> Dict:
|
||||
"""Execute MCP request with optional retry on session failure."""
|
||||
try:
|
||||
final_headers = self._session.headers.copy()
|
||||
final_headers.update(
|
||||
{
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
}
|
||||
)
|
||||
|
||||
if self._mcp_session_id:
|
||||
final_headers["Mcp-Session-Id"] = self._mcp_session_id
|
||||
response = self._session.post(
|
||||
self.server_url.rstrip("/"),
|
||||
json=mcp_message,
|
||||
headers=final_headers,
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
if "mcp-session-id" in response.headers:
|
||||
self._mcp_session_id = response.headers["mcp-session-id"]
|
||||
self._cache_session(self._mcp_session_id)
|
||||
response.raise_for_status()
|
||||
|
||||
if method.startswith("notifications/"):
|
||||
return {}
|
||||
response_text = response.text.strip()
|
||||
if response_text.startswith("event:") and "data:" in response_text:
|
||||
lines = response_text.split("\n")
|
||||
data_line = None
|
||||
for line in lines:
|
||||
if line.startswith("data:"):
|
||||
data_line = line[5:].strip()
|
||||
break
|
||||
if data_line:
|
||||
try:
|
||||
result = json.loads(data_line)
|
||||
except json.JSONDecodeError:
|
||||
raise Exception(f"Invalid JSON in SSE data: {data_line}")
|
||||
else:
|
||||
raise Exception(f"No data found in SSE response: {response_text}")
|
||||
else:
|
||||
try:
|
||||
result = response.json()
|
||||
except json.JSONDecodeError:
|
||||
raise Exception(f"Invalid JSON response: {response.text}")
|
||||
if "error" in result:
|
||||
error_msg = result["error"]
|
||||
if isinstance(error_msg, dict):
|
||||
error_msg = error_msg.get("message", str(error_msg))
|
||||
raise Exception(f"MCP server error: {error_msg}")
|
||||
return result.get("result", result)
|
||||
except requests.exceptions.RequestException as e:
|
||||
if not is_retry and self._should_retry_with_new_session(e):
|
||||
self._invalidate_and_refresh_session()
|
||||
return self._execute_mcp_request(mcp_message, method, is_retry=True)
|
||||
raise Exception(f"MCP server request failed: {str(e)}")
|
||||
|
||||
def _should_retry_with_new_session(self, error: Exception) -> bool:
|
||||
"""Check if error indicates session invalidation and retry is warranted."""
|
||||
error_str = str(error).lower()
|
||||
return (
|
||||
any(
|
||||
indicator in error_str
|
||||
for indicator in [
|
||||
"invalid session",
|
||||
"session expired",
|
||||
"unauthorized",
|
||||
"401",
|
||||
"403",
|
||||
]
|
||||
)
|
||||
and self._mcp_session_id is not None
|
||||
)
|
||||
|
||||
def _invalidate_and_refresh_session(self) -> None:
|
||||
"""Invalidate current session and create a new one."""
|
||||
global _mcp_session_cache
|
||||
if self._cache_key in _mcp_session_cache:
|
||||
del _mcp_session_cache[self._cache_key]
|
||||
self._mcp_session_id = None
|
||||
self._initialize_mcp_connection()
|
||||
|
||||
def discover_tools(self) -> List[Dict]:
|
||||
"""
|
||||
Discover available tools from the MCP server using MCP protocol.
|
||||
|
||||
Returns:
|
||||
List of tool definitions from the server
|
||||
"""
|
||||
try:
|
||||
self._ensure_valid_session()
|
||||
|
||||
response = self._make_mcp_request("tools/list")
|
||||
|
||||
# Handle both formats: response with 'tools' key or response that IS the tools list
|
||||
|
||||
if isinstance(response, dict):
|
||||
if "tools" in response:
|
||||
self.available_tools = response["tools"]
|
||||
elif (
|
||||
"result" in response
|
||||
and isinstance(response["result"], dict)
|
||||
and "tools" in response["result"]
|
||||
):
|
||||
self.available_tools = response["result"]["tools"]
|
||||
else:
|
||||
self.available_tools = [response] if response else []
|
||||
elif isinstance(response, list):
|
||||
self.available_tools = response
|
||||
else:
|
||||
self.available_tools = []
|
||||
return self.available_tools
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs) -> Any:
|
||||
"""
|
||||
Execute an action on the remote MCP server using MCP protocol.
|
||||
|
||||
Args:
|
||||
action_name: Name of the action to execute
|
||||
**kwargs: Parameters for the action
|
||||
|
||||
Returns:
|
||||
Result from the MCP server
|
||||
"""
|
||||
self._ensure_valid_session()
|
||||
|
||||
# Skipping empty/None values - letting the server use defaults
|
||||
|
||||
cleaned_kwargs = {}
|
||||
for key, value in kwargs.items():
|
||||
if value == "" or value is None:
|
||||
continue
|
||||
cleaned_kwargs[key] = value
|
||||
call_params = {"name": action_name, "arguments": cleaned_kwargs}
|
||||
try:
|
||||
result = self._make_mcp_request("tools/call", call_params)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict]:
|
||||
"""
|
||||
Get metadata for all available actions.
|
||||
|
||||
Returns:
|
||||
List of action metadata dictionaries
|
||||
"""
|
||||
actions = []
|
||||
for tool in self.available_tools:
|
||||
input_schema = (
|
||||
tool.get("inputSchema")
|
||||
or tool.get("input_schema")
|
||||
or tool.get("schema")
|
||||
or tool.get("parameters")
|
||||
)
|
||||
|
||||
parameters_schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
if input_schema:
|
||||
if isinstance(input_schema, dict):
|
||||
if "properties" in input_schema:
|
||||
parameters_schema = {
|
||||
"type": input_schema.get("type", "object"),
|
||||
"properties": input_schema.get("properties", {}),
|
||||
"required": input_schema.get("required", []),
|
||||
}
|
||||
|
||||
for key in ["additionalProperties", "description"]:
|
||||
if key in input_schema:
|
||||
parameters_schema[key] = input_schema[key]
|
||||
else:
|
||||
parameters_schema["properties"] = input_schema
|
||||
action = {
|
||||
"name": tool.get("name", ""),
|
||||
"description": tool.get("description", ""),
|
||||
"parameters": parameters_schema,
|
||||
}
|
||||
actions.append(action)
|
||||
return actions
|
||||
|
||||
def test_connection(self) -> Dict:
|
||||
"""
|
||||
Test the connection to the MCP server and validate functionality.
|
||||
|
||||
Returns:
|
||||
Dictionary with connection test results including tool count
|
||||
"""
|
||||
try:
|
||||
self._mcp_session_id = None
|
||||
|
||||
init_result = self._initialize_mcp_connection()
|
||||
|
||||
tools = self.discover_tools()
|
||||
|
||||
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
|
||||
if init_result.get("cached"):
|
||||
message += " (Using cached session)"
|
||||
elif init_result.get("fallback"):
|
||||
message += " (No formal initialization required)"
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"tools_count": len(tools),
|
||||
"session_id": self._mcp_session_id,
|
||||
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Connection failed: {str(e)}",
|
||||
"tools_count": 0,
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
def get_config_requirements(self) -> Dict:
|
||||
return {
|
||||
"server_url": {
|
||||
"type": "string",
|
||||
"description": "URL of the remote MCP server (e.g., https://api.example.com)",
|
||||
"required": True,
|
||||
},
|
||||
"auth_type": {
|
||||
"type": "string",
|
||||
"description": "Authentication type",
|
||||
"enum": ["none", "api_key", "bearer", "basic"],
|
||||
"default": "none",
|
||||
"required": True,
|
||||
},
|
||||
"auth_credentials": {
|
||||
"type": "object",
|
||||
"description": "Authentication credentials (varies by auth_type)",
|
||||
"required": False,
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Request timeout in seconds",
|
||||
"default": 30,
|
||||
"minimum": 1,
|
||||
"maximum": 300,
|
||||
"required": False,
|
||||
},
|
||||
}
|
||||
@@ -23,16 +23,23 @@ class ToolManager:
|
||||
tool_config = self.config.get(name, {})
|
||||
self.tools[name] = obj(tool_config)
|
||||
|
||||
def load_tool(self, tool_name, tool_config):
|
||||
def load_tool(self, tool_name, tool_config, user_id=None):
|
||||
self.config[tool_name] = tool_config
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
|
||||
def execute_action(self, tool_name, action_name, **kwargs):
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
if tool_name == "mcp_tool" and user_id:
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||
|
||||
def get_all_actions_metadata(self):
|
||||
|
||||
@@ -69,11 +69,8 @@ class StreamProcessor:
|
||||
self.decoded_token.get("sub") if self.decoded_token is not None else None
|
||||
)
|
||||
self.conversation_id = self.data.get("conversation_id")
|
||||
self.source = (
|
||||
{"active_docs": self.data["active_docs"]}
|
||||
if "active_docs" in self.data
|
||||
else {}
|
||||
)
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
self.attachments = []
|
||||
self.history = []
|
||||
self.agent_config = {}
|
||||
@@ -85,6 +82,8 @@ class StreamProcessor:
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
self._configure_agent()
|
||||
self._configure_source()
|
||||
self._configure_retriever()
|
||||
self._configure_agent()
|
||||
self._load_conversation_history()
|
||||
@@ -171,13 +170,77 @@ class StreamProcessor:
|
||||
source = data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = self.db.dereference(source)
|
||||
if source_doc:
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
else:
|
||||
data["source"] = None
|
||||
elif source == "default":
|
||||
data["source"] = "default"
|
||||
else:
|
||||
data["source"] = None
|
||||
# Handle multiple sources
|
||||
|
||||
sources = data.get("sources", [])
|
||||
if sources and isinstance(sources, list):
|
||||
sources_list = []
|
||||
for i, source_ref in enumerate(sources):
|
||||
if source_ref == "default":
|
||||
processed_source = {
|
||||
"id": "default",
|
||||
"retriever": "classic",
|
||||
"chunks": data.get("chunks", "2"),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
elif isinstance(source_ref, DBRef):
|
||||
source_doc = self.db.dereference(source_ref)
|
||||
if source_doc:
|
||||
processed_source = {
|
||||
"id": str(source_doc["_id"]),
|
||||
"retriever": source_doc.get("retriever", "classic"),
|
||||
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
data["sources"] = sources_list
|
||||
else:
|
||||
data["sources"] = []
|
||||
return data
|
||||
|
||||
def _configure_source(self):
|
||||
"""Configure the source based on agent data"""
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
|
||||
if api_key:
|
||||
agent_data = self._get_data_from_api_key(api_key)
|
||||
|
||||
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
||||
source_ids = [
|
||||
source["id"] for source in agent_data["sources"] if source.get("id")
|
||||
]
|
||||
if source_ids:
|
||||
self.source = {"active_docs": source_ids}
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = agent_data["sources"]
|
||||
elif agent_data.get("source"):
|
||||
self.source = {"active_docs": agent_data["source"]}
|
||||
self.all_sources = [
|
||||
{
|
||||
"id": agent_data["source"],
|
||||
"retriever": agent_data.get("retriever", "classic"),
|
||||
}
|
||||
]
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
return
|
||||
if "active_docs" in self.data:
|
||||
self.source = {"active_docs": self.data["active_docs"]}
|
||||
return
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
|
||||
def _configure_agent(self):
|
||||
"""Configure the agent based on request data"""
|
||||
agent_id = self.data.get("agent_id")
|
||||
@@ -203,7 +266,13 @@ class StreamProcessor:
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
self.retriever_config["chunks"] = data_key["chunks"]
|
||||
try:
|
||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
elif self.agent_key:
|
||||
data_key = self._get_data_from_api_key(self.agent_key)
|
||||
self.agent_config.update(
|
||||
@@ -224,7 +293,13 @@ class StreamProcessor:
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
self.retriever_config["chunks"] = data_key["chunks"]
|
||||
try:
|
||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
else:
|
||||
self.agent_config.update(
|
||||
{
|
||||
@@ -243,7 +318,8 @@ class StreamProcessor:
|
||||
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
|
||||
}
|
||||
|
||||
if "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
self.retriever_config["chunks"] = 0
|
||||
|
||||
def create_agent(self):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ class Settings(BaseSettings):
|
||||
"gpt-4o-mini": 128000,
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"claude-2": 1e5,
|
||||
"gemini-2.0-flash-exp": 1e6,
|
||||
"gemini-2.5-flash": 1e6,
|
||||
}
|
||||
UPLOAD_FOLDER: str = "inputs"
|
||||
PARSE_PDF_AS_IMAGE: bool = False
|
||||
@@ -116,6 +116,9 @@ class Settings(BaseSettings):
|
||||
|
||||
JWT_SECRET_KEY: str = ""
|
||||
|
||||
# Encryption settings
|
||||
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
||||
|
||||
|
||||
path = Path(__file__).parent.parent.absolute()
|
||||
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
||||
|
||||
@@ -143,6 +143,7 @@ class GoogleLLM(BaseLLM):
|
||||
raise
|
||||
|
||||
def _clean_messages_google(self, messages):
|
||||
"""Convert OpenAI format messages to Google AI format."""
|
||||
cleaned_messages = []
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
@@ -150,6 +151,8 @@ class GoogleLLM(BaseLLM):
|
||||
|
||||
if role == "assistant":
|
||||
role = "model"
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if role and content is not None:
|
||||
@@ -188,11 +191,63 @@ class GoogleLLM(BaseLLM):
|
||||
else:
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
|
||||
if parts:
|
||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||
|
||||
return cleaned_messages
|
||||
|
||||
def _clean_schema(self, schema_obj):
|
||||
"""
|
||||
Recursively remove unsupported fields from schema objects
|
||||
and validate required properties.
|
||||
"""
|
||||
if not isinstance(schema_obj, dict):
|
||||
return schema_obj
|
||||
allowed_fields = {
|
||||
"type",
|
||||
"description",
|
||||
"items",
|
||||
"properties",
|
||||
"required",
|
||||
"enum",
|
||||
"pattern",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"nullable",
|
||||
"default",
|
||||
}
|
||||
|
||||
cleaned = {}
|
||||
for key, value in schema_obj.items():
|
||||
if key not in allowed_fields:
|
||||
continue
|
||||
elif key == "type" and isinstance(value, str):
|
||||
cleaned[key] = value.upper()
|
||||
elif isinstance(value, dict):
|
||||
cleaned[key] = self._clean_schema(value)
|
||||
elif isinstance(value, list):
|
||||
cleaned[key] = [self._clean_schema(item) for item in value]
|
||||
else:
|
||||
cleaned[key] = value
|
||||
|
||||
# Validate that required properties actually exist in properties
|
||||
if "required" in cleaned and "properties" in cleaned:
|
||||
valid_required = []
|
||||
properties_keys = set(cleaned["properties"].keys())
|
||||
for required_prop in cleaned["required"]:
|
||||
if required_prop in properties_keys:
|
||||
valid_required.append(required_prop)
|
||||
if valid_required:
|
||||
cleaned["required"] = valid_required
|
||||
else:
|
||||
cleaned.pop("required", None)
|
||||
elif "required" in cleaned and "properties" not in cleaned:
|
||||
cleaned.pop("required", None)
|
||||
|
||||
return cleaned
|
||||
|
||||
def _clean_tools_format(self, tools_list):
|
||||
"""Convert OpenAI format tools to Google AI format."""
|
||||
genai_tools = []
|
||||
for tool_data in tools_list:
|
||||
if tool_data["type"] == "function":
|
||||
@@ -201,18 +256,16 @@ class GoogleLLM(BaseLLM):
|
||||
properties = parameters.get("properties", {})
|
||||
|
||||
if properties:
|
||||
cleaned_properties = {}
|
||||
for k, v in properties.items():
|
||||
cleaned_properties[k] = self._clean_schema(v)
|
||||
|
||||
genai_function = dict(
|
||||
name=function["name"],
|
||||
description=function["description"],
|
||||
parameters={
|
||||
"type": "OBJECT",
|
||||
"properties": {
|
||||
k: {
|
||||
**v,
|
||||
"type": v["type"].upper() if v["type"] else None,
|
||||
}
|
||||
for k, v in properties.items()
|
||||
},
|
||||
"properties": cleaned_properties,
|
||||
"required": (
|
||||
parameters["required"]
|
||||
if "required" in parameters
|
||||
@@ -242,6 +295,7 @@ class GoogleLLM(BaseLLM):
|
||||
response_schema=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate content using Google AI API without streaming."""
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
if formatting == "openai":
|
||||
messages = self._clean_messages_google(messages)
|
||||
@@ -281,6 +335,7 @@ class GoogleLLM(BaseLLM):
|
||||
response_schema=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Generate content using Google AI API with streaming."""
|
||||
client = genai.Client(api_key=self.api_key)
|
||||
if formatting == "openai":
|
||||
messages = self._clean_messages_google(messages)
|
||||
@@ -331,12 +386,15 @@ class GoogleLLM(BaseLLM):
|
||||
yield chunk.text
|
||||
|
||||
def _supports_tools(self):
|
||||
"""Return whether this LLM supports function calling."""
|
||||
return True
|
||||
|
||||
def _supports_structured_output(self):
|
||||
"""Return whether this LLM supports structured JSON output."""
|
||||
return True
|
||||
|
||||
def prepare_structured_output_format(self, json_schema):
|
||||
"""Convert JSON schema to Google AI structured output format."""
|
||||
if not json_schema:
|
||||
return None
|
||||
|
||||
|
||||
@@ -205,7 +205,6 @@ class LLMHandler(ABC):
|
||||
except StopIteration as e:
|
||||
tool_response, call_id = e.value
|
||||
break
|
||||
|
||||
updated_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
@@ -222,17 +221,36 @@ class LLMHandler(ABC):
|
||||
)
|
||||
|
||||
updated_messages.append(self.create_tool_message(call, tool_response))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||
updated_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": f"Error executing tool: {str(e)}",
|
||||
"tool_call_id": call.id,
|
||||
}
|
||||
error_call = ToolCall(
|
||||
id=call.id, name=call.name, arguments=call.arguments
|
||||
)
|
||||
error_response = f"Error executing tool: {str(e)}"
|
||||
error_message = self.create_tool_message(error_call, error_response)
|
||||
updated_messages.append(error_message)
|
||||
|
||||
call_parts = call.name.split("_")
|
||||
if len(call_parts) >= 2:
|
||||
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
|
||||
action_name = "_".join(call_parts[:-1])
|
||||
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
|
||||
full_action_name = f"{action_name}_{tool_id}"
|
||||
else:
|
||||
tool_name = "unknown_tool"
|
||||
action_name = call.name
|
||||
full_action_name = call.name
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": tool_name,
|
||||
"call_id": call.id,
|
||||
"action_name": full_action_name,
|
||||
"arguments": call.arguments,
|
||||
"error": error_response,
|
||||
"status": "error",
|
||||
},
|
||||
}
|
||||
return updated_messages
|
||||
|
||||
def handle_non_streaming(
|
||||
@@ -263,13 +281,11 @@ class LLMHandler(ABC):
|
||||
except StopIteration as e:
|
||||
messages = e.value
|
||||
break
|
||||
|
||||
response = agent.llm.gen(
|
||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||
)
|
||||
parsed = self.parse_response(response)
|
||||
self.llm_calls.append(build_stack_data(agent.llm))
|
||||
|
||||
return parsed.content
|
||||
|
||||
def handle_streaming(
|
||||
|
||||
@@ -17,7 +17,6 @@ class GoogleLLMHandler(LLMHandler):
|
||||
finish_reason="stop",
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
if hasattr(response, "candidates"):
|
||||
parts = response.candidates[0].content.parts if response.candidates else []
|
||||
tool_calls = [
|
||||
@@ -41,7 +40,6 @@ class GoogleLLMHandler(LLMHandler):
|
||||
finish_reason="tool_calls" if tool_calls else "stop",
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
else:
|
||||
tool_calls = []
|
||||
if hasattr(response, "function_call"):
|
||||
@@ -61,14 +59,16 @@ class GoogleLLMHandler(LLMHandler):
|
||||
|
||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||
"""Create Google-style tool message."""
|
||||
from google.genai import types
|
||||
|
||||
return {
|
||||
"role": "tool",
|
||||
"role": "model",
|
||||
"content": [
|
||||
types.Part.from_function_response(
|
||||
name=tool_call.name, response={"result": result}
|
||||
).to_json_dict()
|
||||
{
|
||||
"function_response": {
|
||||
"name": tool_call.name,
|
||||
"response": {"result": result},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ anthropic==0.49.0
|
||||
boto3==1.38.18
|
||||
beautifulsoup4==4.13.4
|
||||
celery==5.4.0
|
||||
cryptography==42.0.8
|
||||
dataclasses-json==0.6.7
|
||||
docx2txt==0.8
|
||||
duckduckgo-search==7.5.2
|
||||
|
||||
@@ -5,10 +5,6 @@ class BaseRetriever(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def gen(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.retriever.base import BaseRetriever
|
||||
@@ -20,9 +21,19 @@ class ClassicRAG(BaseRetriever):
|
||||
api_key=settings.API_KEY,
|
||||
decoded_token=None,
|
||||
):
|
||||
self.original_question = ""
|
||||
"""Initialize ClassicRAG retriever with vectorstore sources and LLM configuration"""
|
||||
self.original_question = source.get("question", "")
|
||||
self.chat_history = chat_history if chat_history is not None else []
|
||||
self.prompt = prompt
|
||||
if isinstance(chunks, str):
|
||||
try:
|
||||
self.chunks = int(chunks)
|
||||
except ValueError:
|
||||
logging.warning(
|
||||
f"Invalid chunks value '{chunks}', using default value 2"
|
||||
)
|
||||
self.chunks = 2
|
||||
else:
|
||||
self.chunks = chunks
|
||||
self.gpt_model = gpt_model
|
||||
self.token_limit = (
|
||||
@@ -44,25 +55,52 @@ class ClassicRAG(BaseRetriever):
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
self.vectorstore = source["active_docs"] if "active_docs" in source else None
|
||||
|
||||
if "active_docs" in source and source["active_docs"] is not None:
|
||||
if isinstance(source["active_docs"], list):
|
||||
self.vectorstores = source["active_docs"]
|
||||
else:
|
||||
self.vectorstores = [source["active_docs"]]
|
||||
else:
|
||||
self.vectorstores = []
|
||||
self.question = self._rephrase_query()
|
||||
self.decoded_token = decoded_token
|
||||
self._validate_vectorstore_config()
|
||||
|
||||
def _validate_vectorstore_config(self):
|
||||
"""Validate vectorstore IDs and remove any empty/invalid entries"""
|
||||
if not self.vectorstores:
|
||||
logging.warning("No vectorstores configured for retrieval")
|
||||
return
|
||||
invalid_ids = [
|
||||
vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
|
||||
]
|
||||
if invalid_ids:
|
||||
logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
|
||||
self.vectorstores = [
|
||||
vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
|
||||
]
|
||||
|
||||
def _rephrase_query(self):
|
||||
"""Rephrase user query with chat history context for better retrieval"""
|
||||
if (
|
||||
not self.original_question
|
||||
or not self.chat_history
|
||||
or self.chat_history == []
|
||||
or self.chunks == 0
|
||||
or self.vectorstore is None
|
||||
or not self.vectorstores
|
||||
):
|
||||
return self.original_question
|
||||
|
||||
prompt = f"""Given the following conversation history:
|
||||
|
||||
{self.chat_history}
|
||||
|
||||
|
||||
|
||||
Rephrase the following user question to be a standalone search query
|
||||
|
||||
that captures all relevant context from the conversation:
|
||||
|
||||
"""
|
||||
|
||||
messages = [
|
||||
@@ -79,44 +117,62 @@ class ClassicRAG(BaseRetriever):
|
||||
return self.original_question
|
||||
|
||||
def _get_data(self):
|
||||
if self.chunks == 0 or self.vectorstore is None:
|
||||
docs = []
|
||||
else:
|
||||
"""Retrieve relevant documents from configured vectorstores"""
|
||||
if self.chunks == 0 or not self.vectorstores:
|
||||
return []
|
||||
all_docs = []
|
||||
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
|
||||
|
||||
for vectorstore_id in self.vectorstores:
|
||||
if vectorstore_id:
|
||||
try:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
|
||||
settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
|
||||
)
|
||||
docs_temp = docsearch.search(self.question, k=self.chunks)
|
||||
docs = [
|
||||
docs_temp = docsearch.search(self.question, k=chunks_per_source)
|
||||
|
||||
for doc in docs_temp:
|
||||
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
||||
page_content = doc.page_content
|
||||
metadata = doc.metadata
|
||||
else:
|
||||
page_content = doc.get("text", doc.get("page_content", ""))
|
||||
metadata = doc.get("metadata", {})
|
||||
title = metadata.get(
|
||||
"title", metadata.get("post_title", page_content)
|
||||
)
|
||||
if isinstance(title, str):
|
||||
title = title.split("/")[-1]
|
||||
else:
|
||||
title = str(title).split("/")[-1]
|
||||
all_docs.append(
|
||||
{
|
||||
"title": i.metadata.get(
|
||||
"title", i.metadata.get("post_title", i.page_content)
|
||||
).split("/")[-1],
|
||||
"text": i.page_content,
|
||||
"source": (
|
||||
i.metadata.get("source")
|
||||
if i.metadata.get("source")
|
||||
else "local"
|
||||
),
|
||||
"title": title,
|
||||
"text": page_content,
|
||||
"source": metadata.get("source") or vectorstore_id,
|
||||
}
|
||||
for i in docs_temp
|
||||
]
|
||||
|
||||
return docs
|
||||
|
||||
def gen():
|
||||
pass
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error searching vectorstore {vectorstore_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
return all_docs
|
||||
|
||||
def search(self, query: str = ""):
|
||||
"""Search for documents using optional query override"""
|
||||
if query:
|
||||
self.original_question = query
|
||||
self.question = self._rephrase_query()
|
||||
return self._get_data()
|
||||
|
||||
def get_params(self):
|
||||
"""Return current retriever configuration parameters"""
|
||||
return {
|
||||
"question": self.original_question,
|
||||
"rephrased_question": self.question,
|
||||
"source": self.vectorstore,
|
||||
"sources": self.vectorstores,
|
||||
"chunks": self.chunks,
|
||||
"token_limit": self.token_limit,
|
||||
"gpt_model": self.gpt_model,
|
||||
|
||||
0
application/security/__init__.py
Normal file
0
application/security/__init__.py
Normal file
85
application/security/encryption.py
Normal file
85
application/security/encryption.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
def _derive_key(user_id: str, salt: bytes) -> bytes:
|
||||
app_secret = settings.ENCRYPTION_SECRET_KEY
|
||||
|
||||
password = f"{app_secret}#{user_id}".encode()
|
||||
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
backend=default_backend(),
|
||||
)
|
||||
|
||||
return kdf.derive(password)
|
||||
|
||||
|
||||
def encrypt_credentials(credentials: dict, user_id: str) -> str:
|
||||
if not credentials:
|
||||
return ""
|
||||
try:
|
||||
salt = os.urandom(16)
|
||||
iv = os.urandom(16)
|
||||
key = _derive_key(user_id, salt)
|
||||
|
||||
json_str = json.dumps(credentials)
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
|
||||
padded_data = _pad_data(json_str.encode())
|
||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
result = salt + iv + encrypted_data
|
||||
return base64.b64encode(result).decode()
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to encrypt credentials: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
|
||||
if not encrypted_data:
|
||||
return {}
|
||||
try:
|
||||
data = base64.b64decode(encrypted_data.encode())
|
||||
|
||||
salt = data[:16]
|
||||
iv = data[16:32]
|
||||
encrypted_content = data[32:]
|
||||
|
||||
key = _derive_key(user_id, salt)
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
|
||||
decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
|
||||
decrypted_data = _unpad_data(decrypted_padded)
|
||||
|
||||
return json.loads(decrypted_data.decode())
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to decrypt credentials: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def _pad_data(data: bytes) -> bytes:
|
||||
block_size = 16
|
||||
padding_len = block_size - (len(data) % block_size)
|
||||
padding = bytes([padding_len]) * padding_len
|
||||
return data + padding
|
||||
|
||||
|
||||
def _unpad_data(data: bytes) -> bytes:
|
||||
padding_len = data[-1]
|
||||
return data[:-padding_len]
|
||||
@@ -1,12 +1,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class EmbeddingsWrapper:
|
||||
def __init__(self, model_name, *args, **kwargs):
|
||||
self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
|
||||
self.model = SentenceTransformer(
|
||||
model_name,
|
||||
config_kwargs={"allow_dangerous_deserialization": True},
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||
|
||||
def embed_query(self, query: str):
|
||||
@@ -24,15 +32,14 @@ class EmbeddingsWrapper:
|
||||
raise ValueError("Input must be a string or a list of strings")
|
||||
|
||||
|
||||
|
||||
class EmbeddingsSingleton:
|
||||
_instances = {}
|
||||
|
||||
@staticmethod
|
||||
def get_instance(embeddings_name, *args, **kwargs):
|
||||
if embeddings_name not in EmbeddingsSingleton._instances:
|
||||
EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(
|
||||
embeddings_name, *args, **kwargs
|
||||
EmbeddingsSingleton._instances[embeddings_name] = (
|
||||
EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
|
||||
)
|
||||
return EmbeddingsSingleton._instances[embeddings_name]
|
||||
|
||||
@@ -40,9 +47,15 @@ class EmbeddingsSingleton:
|
||||
def _create_instance(embeddings_name, *args, **kwargs):
|
||||
embeddings_factory = {
|
||||
"openai_text-embedding-ada-002": OpenAIEmbeddings,
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
|
||||
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
|
||||
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
|
||||
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
|
||||
"sentence-transformers/all-mpnet-base-v2"
|
||||
),
|
||||
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper(
|
||||
"sentence-transformers/all-mpnet-base-v2"
|
||||
),
|
||||
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper(
|
||||
"hkunlp/instructor-large"
|
||||
),
|
||||
}
|
||||
|
||||
if embeddings_name in embeddings_factory:
|
||||
@@ -50,29 +63,58 @@ class EmbeddingsSingleton:
|
||||
else:
|
||||
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
|
||||
|
||||
|
||||
class BaseVectorStore(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, *args, **kwargs):
|
||||
"""Search for similar documents/chunks in the vectorstore"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_texts(self, texts, metadatas=None, *args, **kwargs):
|
||||
"""Add texts with their embeddings to the vectorstore"""
|
||||
pass
|
||||
|
||||
def delete_index(self, *args, **kwargs):
|
||||
"""Delete the entire index/collection"""
|
||||
pass
|
||||
|
||||
def save_local(self, *args, **kwargs):
|
||||
"""Save vectorstore to local storage"""
|
||||
pass
|
||||
|
||||
def get_chunks(self, *args, **kwargs):
|
||||
"""Get all chunks from the vectorstore"""
|
||||
pass
|
||||
|
||||
def add_chunk(self, text, metadata=None, *args, **kwargs):
|
||||
"""Add a single chunk to the vectorstore"""
|
||||
pass
|
||||
|
||||
def delete_chunk(self, chunk_id, *args, **kwargs):
|
||||
"""Delete a specific chunk from the vectorstore"""
|
||||
pass
|
||||
|
||||
def is_azure_configured(self):
|
||||
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
|
||||
return (
|
||||
settings.OPENAI_API_BASE
|
||||
and settings.OPENAI_API_VERSION
|
||||
and settings.AZURE_DEPLOYMENT_NAME
|
||||
)
|
||||
|
||||
def _get_embeddings(self, embeddings_name, embeddings_key=None):
|
||||
if embeddings_name == "openai_text-embedding-ada-002":
|
||||
if self.is_azure_configured():
|
||||
os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name,
|
||||
model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
|
||||
embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
|
||||
)
|
||||
else:
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||
embeddings_name,
|
||||
openai_api_key=embeddings_key
|
||||
embeddings_name, openai_api_key=embeddings_key
|
||||
)
|
||||
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||
if os.path.exists("./models/all-mpnet-base-v2"):
|
||||
@@ -87,4 +129,3 @@ class BaseVectorStore(ABC):
|
||||
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
|
||||
|
||||
return embedding_instance
|
||||
|
||||
|
||||
4
frontend/public/toolIcons/tool_mcp_tool.svg
Normal file
4
frontend/public/toolIcons/tool_mcp_tool.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="64" height="64" color="#000000" fill="none">
|
||||
<path d="M3.49994 11.7501L11.6717 3.57855C12.7762 2.47398 14.5672 2.47398 15.6717 3.57855C16.7762 4.68312 16.7762 6.47398 15.6717 7.57855M15.6717 7.57855L9.49994 13.7501M15.6717 7.57855C16.7762 6.47398 18.5672 6.47398 19.6717 7.57855C20.7762 8.68312 20.7762 10.474 19.6717 11.5785L12.7072 18.543C12.3167 18.9335 12.3167 19.5667 12.7072 19.9572L13.9999 21.2499" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
|
||||
<path d="M17.4999 9.74921L11.3282 15.921C10.2237 17.0255 8.43272 17.0255 7.32823 15.921C6.22373 14.8164 6.22373 13.0255 7.32823 11.921L13.4999 5.74939" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 831 B |
@@ -45,6 +45,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
description: '',
|
||||
image: '',
|
||||
source: '',
|
||||
sources: [],
|
||||
chunks: '',
|
||||
retriever: '',
|
||||
prompt_id: 'default',
|
||||
@@ -150,7 +151,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
const formData = new FormData();
|
||||
formData.append('name', agent.name);
|
||||
formData.append('description', agent.description);
|
||||
formData.append('source', agent.source);
|
||||
|
||||
if (selectedSourceIds.size > 1) {
|
||||
const sourcesArray = Array.from(selectedSourceIds)
|
||||
.map((id) => {
|
||||
const sourceDoc = sourceDocs?.find(
|
||||
(source) =>
|
||||
source.id === id || source.retriever === id || source.name === id,
|
||||
);
|
||||
if (sourceDoc?.name === 'Default' && !sourceDoc?.id) {
|
||||
return 'default';
|
||||
}
|
||||
return sourceDoc?.id || id;
|
||||
})
|
||||
.filter(Boolean);
|
||||
formData.append('sources', JSON.stringify(sourcesArray));
|
||||
formData.append('source', '');
|
||||
} else if (selectedSourceIds.size === 1) {
|
||||
const singleSourceId = Array.from(selectedSourceIds)[0];
|
||||
const sourceDoc = sourceDocs?.find(
|
||||
(source) =>
|
||||
source.id === singleSourceId ||
|
||||
source.retriever === singleSourceId ||
|
||||
source.name === singleSourceId,
|
||||
);
|
||||
let finalSourceId;
|
||||
if (sourceDoc?.name === 'Default' && !sourceDoc?.id)
|
||||
finalSourceId = 'default';
|
||||
else finalSourceId = sourceDoc?.id || singleSourceId;
|
||||
formData.append('source', String(finalSourceId));
|
||||
formData.append('sources', JSON.stringify([]));
|
||||
} else {
|
||||
formData.append('source', '');
|
||||
formData.append('sources', JSON.stringify([]));
|
||||
}
|
||||
|
||||
formData.append('chunks', agent.chunks);
|
||||
formData.append('retriever', agent.retriever);
|
||||
formData.append('prompt_id', agent.prompt_id);
|
||||
@@ -196,7 +231,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
const formData = new FormData();
|
||||
formData.append('name', agent.name);
|
||||
formData.append('description', agent.description);
|
||||
formData.append('source', agent.source);
|
||||
|
||||
if (selectedSourceIds.size > 1) {
|
||||
const sourcesArray = Array.from(selectedSourceIds)
|
||||
.map((id) => {
|
||||
const sourceDoc = sourceDocs?.find(
|
||||
(source) =>
|
||||
source.id === id || source.retriever === id || source.name === id,
|
||||
);
|
||||
if (sourceDoc?.name === 'Default' && !sourceDoc?.id) {
|
||||
return 'default';
|
||||
}
|
||||
return sourceDoc?.id || id;
|
||||
})
|
||||
.filter(Boolean);
|
||||
formData.append('sources', JSON.stringify(sourcesArray));
|
||||
formData.append('source', '');
|
||||
} else if (selectedSourceIds.size === 1) {
|
||||
const singleSourceId = Array.from(selectedSourceIds)[0];
|
||||
const sourceDoc = sourceDocs?.find(
|
||||
(source) =>
|
||||
source.id === singleSourceId ||
|
||||
source.retriever === singleSourceId ||
|
||||
source.name === singleSourceId,
|
||||
);
|
||||
let finalSourceId;
|
||||
if (sourceDoc?.name === 'Default' && !sourceDoc?.id)
|
||||
finalSourceId = 'default';
|
||||
else finalSourceId = sourceDoc?.id || singleSourceId;
|
||||
formData.append('source', String(finalSourceId));
|
||||
formData.append('sources', JSON.stringify([]));
|
||||
} else {
|
||||
formData.append('source', '');
|
||||
formData.append('sources', JSON.stringify([]));
|
||||
}
|
||||
|
||||
formData.append('chunks', agent.chunks);
|
||||
formData.append('retriever', agent.retriever);
|
||||
formData.append('prompt_id', agent.prompt_id);
|
||||
@@ -293,9 +362,33 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
throw new Error('Failed to fetch agent');
|
||||
}
|
||||
const data = await response.json();
|
||||
if (data.source) setSelectedSourceIds(new Set([data.source]));
|
||||
else if (data.retriever)
|
||||
|
||||
if (data.sources && data.sources.length > 0) {
|
||||
const mappedSources = data.sources.map((sourceId: string) => {
|
||||
if (sourceId === 'default') {
|
||||
const defaultSource = sourceDocs?.find(
|
||||
(source) => source.name === 'Default',
|
||||
);
|
||||
return defaultSource?.retriever || 'classic';
|
||||
}
|
||||
return sourceId;
|
||||
});
|
||||
setSelectedSourceIds(new Set(mappedSources));
|
||||
} else if (data.source) {
|
||||
if (data.source === 'default') {
|
||||
const defaultSource = sourceDocs?.find(
|
||||
(source) => source.name === 'Default',
|
||||
);
|
||||
setSelectedSourceIds(
|
||||
new Set([defaultSource?.retriever || 'classic']),
|
||||
);
|
||||
} else {
|
||||
setSelectedSourceIds(new Set([data.source]));
|
||||
}
|
||||
} else if (data.retriever) {
|
||||
setSelectedSourceIds(new Set([data.retriever]));
|
||||
}
|
||||
|
||||
if (data.tools) setSelectedToolIds(new Set(data.tools));
|
||||
if (data.status === 'draft') setEffectiveMode('draft');
|
||||
if (data.json_schema) {
|
||||
@@ -311,24 +404,56 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
}, [agentId, mode, token]);
|
||||
|
||||
useEffect(() => {
|
||||
const selectedSource = Array.from(selectedSourceIds).map((id) =>
|
||||
const selectedSources = Array.from(selectedSourceIds)
|
||||
.map((id) =>
|
||||
sourceDocs?.find(
|
||||
(source) =>
|
||||
source.id === id || source.retriever === id || source.name === id,
|
||||
),
|
||||
);
|
||||
if (selectedSource[0]?.model === embeddingsName) {
|
||||
if (selectedSource[0] && 'id' in selectedSource[0]) {
|
||||
)
|
||||
.filter(Boolean);
|
||||
|
||||
if (selectedSources.length > 0) {
|
||||
// Handle multiple sources
|
||||
if (selectedSources.length > 1) {
|
||||
// Multiple sources selected - store in sources array
|
||||
const sourceIds = selectedSources
|
||||
.map((source) => source?.id)
|
||||
.filter((id): id is string => Boolean(id));
|
||||
setAgent((prev) => ({
|
||||
...prev,
|
||||
source: selectedSource[0]?.id || 'default',
|
||||
sources: sourceIds,
|
||||
source: '', // Clear single source for multiple sources
|
||||
retriever: '',
|
||||
}));
|
||||
} else
|
||||
} else {
|
||||
// Single source selected - maintain backward compatibility
|
||||
const selectedSource = selectedSources[0];
|
||||
if (selectedSource?.model === embeddingsName) {
|
||||
if (selectedSource && 'id' in selectedSource) {
|
||||
setAgent((prev) => ({
|
||||
...prev,
|
||||
source: selectedSource?.id || 'default',
|
||||
sources: [], // Clear sources array for single source
|
||||
retriever: '',
|
||||
}));
|
||||
} else {
|
||||
setAgent((prev) => ({
|
||||
...prev,
|
||||
source: '',
|
||||
retriever: selectedSource[0]?.retriever || 'classic',
|
||||
sources: [], // Clear sources array
|
||||
retriever: selectedSource?.retriever || 'classic',
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No sources selected
|
||||
setAgent((prev) => ({
|
||||
...prev,
|
||||
source: '',
|
||||
sources: [],
|
||||
retriever: '',
|
||||
}));
|
||||
}
|
||||
}, [selectedSourceIds]);
|
||||
@@ -510,7 +635,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
)
|
||||
.filter(Boolean)
|
||||
.join(', ')
|
||||
: 'Select source'}
|
||||
: 'Select sources'}
|
||||
</button>
|
||||
<MultiSelectPopup
|
||||
isOpen={isSourcePopupOpen}
|
||||
@@ -526,12 +651,10 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
selectedIds={selectedSourceIds}
|
||||
onSelectionChange={(newSelectedIds: Set<string | number>) => {
|
||||
setSelectedSourceIds(newSelectedIds);
|
||||
setIsSourcePopupOpen(false);
|
||||
}}
|
||||
title="Select Source"
|
||||
title="Select Sources"
|
||||
searchPlaceholder="Search sources..."
|
||||
noOptionsMessage="No source available"
|
||||
singleSelect={true}
|
||||
noOptionsMessage="No sources available"
|
||||
/>
|
||||
</div>
|
||||
<div className="mt-3">
|
||||
|
||||
@@ -10,6 +10,7 @@ export type Agent = {
|
||||
description: string;
|
||||
image: string;
|
||||
source: string;
|
||||
sources?: string[];
|
||||
chunks: string;
|
||||
retriever: string;
|
||||
prompt_id: string;
|
||||
|
||||
@@ -57,6 +57,8 @@ const endpoints = {
|
||||
DIRECTORY_STRUCTURE: (docId: string) =>
|
||||
`/api/directory_structure?id=${docId}`,
|
||||
MANAGE_SOURCE_FILES: '/api/manage_source_files',
|
||||
MCP_TEST_CONNECTION: '/api/mcp_server/test',
|
||||
MCP_SAVE_SERVER: '/api/mcp_server/save',
|
||||
},
|
||||
CONVERSATION: {
|
||||
ANSWER: '/api/answer',
|
||||
|
||||
@@ -90,7 +90,10 @@ const userService = {
|
||||
path?: string,
|
||||
search?: string,
|
||||
): Promise<any> =>
|
||||
apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search), token),
|
||||
apiClient.get(
|
||||
endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search),
|
||||
token,
|
||||
),
|
||||
addChunk: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.ADD_CHUNK, data, token),
|
||||
deleteChunk: (
|
||||
@@ -105,16 +108,24 @@ const userService = {
|
||||
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
|
||||
manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
|
||||
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
|
||||
syncConnector: (docId: string, provider: string, token: string | null): Promise<any> => {
|
||||
testMCPConnection: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
|
||||
saveMCPServer: (data: any, token: string | null): Promise<any> =>
|
||||
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
|
||||
syncConnector: (
|
||||
docId: string,
|
||||
provider: string,
|
||||
token: string | null,
|
||||
): Promise<any> => {
|
||||
const sessionToken = getSessionToken(provider);
|
||||
return apiClient.post(
|
||||
endpoints.USER.SYNC_CONNECTOR,
|
||||
{
|
||||
source_id: docId,
|
||||
session_token: sessionToken,
|
||||
provider: provider
|
||||
provider: provider,
|
||||
},
|
||||
token
|
||||
token,
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
0
frontend/src/assets/server.svg
Normal file
0
frontend/src/assets/server.svg
Normal file
@@ -16,7 +16,12 @@ const providerLabel = (provider: string) => {
|
||||
return map[provider] || provider.replace(/_/g, ' ');
|
||||
};
|
||||
|
||||
const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onError, label }) => {
|
||||
const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
||||
provider,
|
||||
onSuccess,
|
||||
onError,
|
||||
label,
|
||||
}) => {
|
||||
const token = useSelector(selectToken);
|
||||
const completedRef = useRef(false);
|
||||
const intervalRef = useRef<number | null>(null);
|
||||
@@ -31,8 +36,12 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onEr
|
||||
|
||||
const handleAuthMessage = (event: MessageEvent) => {
|
||||
const successGeneric = event.data?.type === 'connector_auth_success';
|
||||
const successProvider = event.data?.type === `${provider}_auth_success` || event.data?.type === 'google_drive_auth_success';
|
||||
const errorProvider = event.data?.type === `${provider}_auth_error` || event.data?.type === 'google_drive_auth_error';
|
||||
const successProvider =
|
||||
event.data?.type === `${provider}_auth_success` ||
|
||||
event.data?.type === 'google_drive_auth_success';
|
||||
const errorProvider =
|
||||
event.data?.type === `${provider}_auth_error` ||
|
||||
event.data?.type === 'google_drive_auth_error';
|
||||
|
||||
if (successGeneric || successProvider) {
|
||||
completedRef.current = true;
|
||||
@@ -54,12 +63,17 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onEr
|
||||
cleanup();
|
||||
|
||||
const apiHost = import.meta.env.VITE_API_HOST;
|
||||
const authResponse = await fetch(`${apiHost}/api/connectors/auth?provider=${provider}`, {
|
||||
const authResponse = await fetch(
|
||||
`${apiHost}/api/connectors/auth?provider=${provider}`,
|
||||
{
|
||||
headers: { Authorization: `Bearer ${token}` },
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
if (!authResponse.ok) {
|
||||
throw new Error(`Failed to get authorization URL: ${authResponse.status}`);
|
||||
throw new Error(
|
||||
`Failed to get authorization URL: ${authResponse.status}`,
|
||||
);
|
||||
}
|
||||
|
||||
const authData = await authResponse.json();
|
||||
@@ -70,10 +84,12 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onEr
|
||||
const authWindow = window.open(
|
||||
authData.authorization_url,
|
||||
`${provider}-auth`,
|
||||
'width=500,height=600,scrollbars=yes,resizable=yes'
|
||||
'width=500,height=600,scrollbars=yes,resizable=yes',
|
||||
);
|
||||
if (!authWindow) {
|
||||
throw new Error('Failed to open authentication window. Please allow popups.');
|
||||
throw new Error(
|
||||
'Failed to open authentication window. Please allow popups.',
|
||||
);
|
||||
}
|
||||
|
||||
window.addEventListener('message', handleAuthMessage as any);
|
||||
@@ -98,10 +114,13 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onEr
|
||||
return (
|
||||
<button
|
||||
onClick={handleAuth}
|
||||
className="w-full flex items-center justify-center gap-2 rounded-lg bg-blue-500 px-4 py-3 text-white hover:bg-blue-600 transition-colors"
|
||||
className="flex w-full items-center justify-center gap-2 rounded-lg bg-blue-500 px-4 py-3 text-white transition-colors hover:bg-blue-600"
|
||||
>
|
||||
<svg className="h-5 w-5" viewBox="0 0 24 24">
|
||||
<path fill="currentColor" d="M6.28 3l5.72 10H24l-5.72-10H6.28zm11.44 0L12 13l5.72 10H24L18.28 3h-.56zM0 13l5.72 10h5.72L5.72 13H0z"/>
|
||||
<path
|
||||
fill="currentColor"
|
||||
d="M6.28 3l5.72 10H24l-5.72-10H6.28zm11.44 0L12 13l5.72 10H24L18.28 3h-.56zM0 13l5.72 10h5.72L5.72 13H0z"
|
||||
/>
|
||||
</svg>
|
||||
{buttonLabel}
|
||||
</button>
|
||||
@@ -109,4 +128,3 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onEr
|
||||
};
|
||||
|
||||
export default ConnectorAuth;
|
||||
|
||||
|
||||
@@ -240,8 +240,6 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
|
||||
return current;
|
||||
};
|
||||
|
||||
|
||||
|
||||
const getMenuRef = (id: string) => {
|
||||
if (!menuRefs.current[id]) {
|
||||
menuRefs.current[id] = React.createRef();
|
||||
|
||||
@@ -136,8 +136,6 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
||||
}
|
||||
}, [docId, token]);
|
||||
|
||||
|
||||
|
||||
const navigateToDirectory = (dirName: string) => {
|
||||
setCurrentPath((prev) => [...prev, dirName]);
|
||||
};
|
||||
@@ -445,18 +443,18 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
||||
|
||||
const renderPathNavigation = () => {
|
||||
return (
|
||||
<div className="mb-0 min-h-[38px] flex flex-col gap-2 text-base sm:flex-row sm:items-center sm:justify-between">
|
||||
<div className="mb-0 flex min-h-[38px] flex-col gap-2 text-base sm:flex-row sm:items-center sm:justify-between">
|
||||
{/* Left side with path navigation */}
|
||||
<div className="flex w-full items-center sm:w-auto">
|
||||
<button
|
||||
className="mr-3 flex h-[29px] w-[29px] items-center justify-center rounded-full border p-2 text-sm text-gray-400 dark:border-0 dark:bg-[#28292D] dark:text-gray-500 dark:hover:bg-[#2E2F34] font-medium"
|
||||
className="mr-3 flex h-[29px] w-[29px] items-center justify-center rounded-full border p-2 text-sm font-medium text-gray-400 dark:border-0 dark:bg-[#28292D] dark:text-gray-500 dark:hover:bg-[#2E2F34]"
|
||||
onClick={handleBackNavigation}
|
||||
>
|
||||
<img src={ArrowLeft} alt="left-arrow" className="h-3 w-3" />
|
||||
</button>
|
||||
|
||||
<div className="flex flex-wrap items-center">
|
||||
<span className="text-[#7D54D1] font-semibold break-words">
|
||||
<span className="font-semibold break-words text-[#7D54D1]">
|
||||
{sourceName}
|
||||
</span>
|
||||
{currentPath.length > 0 && (
|
||||
@@ -487,8 +485,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex relative flex-row flex-nowrap items-center gap-2 w-full sm:w-auto justify-end mt-2 sm:mt-0">
|
||||
|
||||
<div className="relative mt-2 flex w-full flex-row flex-nowrap items-center justify-end gap-2 sm:mt-0 sm:w-auto">
|
||||
{processingRef.current && (
|
||||
<div className="text-sm text-gray-500">
|
||||
{currentOpRef.current === 'add'
|
||||
@@ -503,7 +500,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
||||
{!processingRef.current && (
|
||||
<button
|
||||
onClick={handleAddFile}
|
||||
className="bg-purple-30 hover:bg-violets-are-blue flex h-[38px] min-w-[108px] items-center justify-center rounded-full px-4 text-[14px] whitespace-nowrap text-white font-medium"
|
||||
className="bg-purple-30 hover:bg-violets-are-blue flex h-[38px] min-w-[108px] items-center justify-center rounded-full px-4 text-[14px] font-medium whitespace-nowrap text-white"
|
||||
title={t('settings.sources.addFile')}
|
||||
>
|
||||
{t('settings.sources.addFile')}
|
||||
@@ -609,7 +606,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
||||
<div ref={menuRef} className="relative">
|
||||
<button
|
||||
onClick={(e) => handleMenuClick(e, itemId)}
|
||||
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E] font-medium"
|
||||
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md font-medium transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E]"
|
||||
aria-label={t('settings.sources.menuAlt')}
|
||||
>
|
||||
<img
|
||||
@@ -664,7 +661,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
||||
<div ref={menuRef} className="relative">
|
||||
<button
|
||||
onClick={(e) => handleMenuClick(e, itemId)}
|
||||
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E] font-medium"
|
||||
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md font-medium transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E]"
|
||||
aria-label={t('settings.sources.menuAlt')}
|
||||
>
|
||||
<img
|
||||
@@ -756,14 +753,12 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
||||
}
|
||||
}}
|
||||
placeholder={t('settings.sources.searchFiles')}
|
||||
className={`w-full h-[38px] border border-[#D1D9E0] px-4 py-2 dark:border-[#6A6A6A]
|
||||
${searchQuery ? 'rounded-t-[24px]' : 'rounded-[24px]'}
|
||||
bg-transparent focus:outline-none dark:text-[#E0E0E0]`}
|
||||
className={`h-[38px] w-full border border-[#D1D9E0] px-4 py-2 dark:border-[#6A6A6A] ${searchQuery ? 'rounded-t-[24px]' : 'rounded-[24px]'} bg-transparent focus:outline-none dark:text-[#E0E0E0]`}
|
||||
/>
|
||||
|
||||
{searchQuery && (
|
||||
<div className="absolute top-full left-0 right-0 z-10 max-h-[calc(100vh-200px)] w-full overflow-hidden rounded-b-[12px] border border-t-0 border-[#D1D9E0] bg-white shadow-lg dark:border-[#6A6A6A] dark:bg-[#1F2023] transition-all duration-200">
|
||||
<div className="max-h-[calc(100vh-200px)] overflow-y-auto overflow-x-hidden overscroll-contain">
|
||||
<div className="absolute top-full right-0 left-0 z-10 max-h-[calc(100vh-200px)] w-full overflow-hidden rounded-b-[12px] border border-t-0 border-[#D1D9E0] bg-white shadow-lg transition-all duration-200 dark:border-[#6A6A6A] dark:bg-[#1F2023]">
|
||||
<div className="max-h-[calc(100vh-200px)] overflow-x-hidden overflow-y-auto overscroll-contain">
|
||||
{searchResults.length === 0 ? (
|
||||
<div className="py-2 text-center text-sm text-gray-500 dark:text-gray-400">
|
||||
{t('settings.sources.noResults')}
|
||||
@@ -774,7 +769,8 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
||||
key={index}
|
||||
onClick={() => handleSearchSelect(result)}
|
||||
title={result.path}
|
||||
className={`flex min-w-0 cursor-pointer items-center px-3 py-2 hover:bg-[#ECEEEF] dark:hover:bg-[#27282D] ${index !== searchResults.length - 1
|
||||
className={`flex min-w-0 cursor-pointer items-center px-3 py-2 hover:bg-[#ECEEEF] dark:hover:bg-[#27282D] ${
|
||||
index !== searchResults.length - 1
|
||||
? 'border-b border-[#D1D9E0] dark:border-[#6A6A6A]'
|
||||
: ''
|
||||
}`}
|
||||
@@ -788,7 +784,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
||||
}
|
||||
className="mr-2 h-4 w-4 flex-shrink-0"
|
||||
/>
|
||||
<span className="text-sm dark:text-[#E0E0E0] truncate flex-1">
|
||||
<span className="flex-1 truncate text-sm dark:text-[#E0E0E0]">
|
||||
{result.path.split('/').pop() || result.path}
|
||||
</span>
|
||||
</div>
|
||||
@@ -870,7 +866,9 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
||||
message={
|
||||
itemToDelete?.isFile
|
||||
? t('settings.sources.confirmDelete')
|
||||
: t('settings.sources.deleteDirectoryWarning', { name: itemToDelete?.name })
|
||||
: t('settings.sources.deleteDirectoryWarning', {
|
||||
name: itemToDelete?.name,
|
||||
})
|
||||
}
|
||||
modalState={deleteModalState}
|
||||
setModalState={setDeleteModalState}
|
||||
|
||||
@@ -368,8 +368,8 @@ export default function MessageInput({
|
||||
className="xs:px-3 xs:py-1.5 dark:border-purple-taupe flex max-w-[130px] items-center rounded-[32px] border border-[#AAAAAA] px-2 py-1 transition-colors hover:bg-gray-100 sm:max-w-[150px] dark:hover:bg-[#2C2E3C]"
|
||||
onClick={() => setIsSourcesPopupOpen(!isSourcesPopupOpen)}
|
||||
title={
|
||||
selectedDocs
|
||||
? selectedDocs.name
|
||||
selectedDocs && selectedDocs.length > 0
|
||||
? selectedDocs.map((doc) => doc.name).join(', ')
|
||||
: t('conversation.sources.title')
|
||||
}
|
||||
>
|
||||
@@ -379,8 +379,10 @@ export default function MessageInput({
|
||||
className="mr-1 h-3.5 w-3.5 shrink-0 sm:mr-1.5 sm:h-4"
|
||||
/>
|
||||
<span className="xs:text-[12px] dark:text-bright-gray truncate overflow-hidden text-[10px] font-medium text-[#5D5D5D] sm:text-[14px]">
|
||||
{selectedDocs
|
||||
? selectedDocs.name
|
||||
{selectedDocs && selectedDocs.length > 0
|
||||
? selectedDocs.length === 1
|
||||
? selectedDocs[0].name
|
||||
: `${selectedDocs.length} sources selected`
|
||||
: t('conversation.sources.title')}
|
||||
</span>
|
||||
{!isTouch && (
|
||||
|
||||
@@ -17,7 +17,7 @@ type SourcesPopupProps = {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
anchorRef: React.RefObject<HTMLButtonElement | null>;
|
||||
handlePostDocumentSelect: (doc: Doc | null) => void;
|
||||
handlePostDocumentSelect: (doc: Doc[] | null) => void;
|
||||
setUploadModalState: React.Dispatch<React.SetStateAction<ActiveState>>;
|
||||
};
|
||||
|
||||
@@ -149,9 +149,13 @@ export default function SourcesPopup({
|
||||
if (option.model === embeddingsName) {
|
||||
const isSelected =
|
||||
selectedDocs &&
|
||||
(option.id
|
||||
? selectedDocs.id === option.id
|
||||
: selectedDocs.date === option.date);
|
||||
Array.isArray(selectedDocs) &&
|
||||
selectedDocs.length > 0 &&
|
||||
selectedDocs.some((doc) =>
|
||||
option.id
|
||||
? doc.id === option.id
|
||||
: doc.date === option.date,
|
||||
);
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -159,11 +163,29 @@ export default function SourcesPopup({
|
||||
className="border-opacity-80 dark:border-dim-gray flex cursor-pointer items-center border-b border-[#D9D9D9] p-3 transition-colors hover:bg-gray-100 dark:text-[14px] dark:hover:bg-[#2C2E3C]"
|
||||
onClick={() => {
|
||||
if (isSelected) {
|
||||
dispatch(setSelectedDocs(null));
|
||||
handlePostDocumentSelect(null);
|
||||
const updatedDocs =
|
||||
selectedDocs && Array.isArray(selectedDocs)
|
||||
? selectedDocs.filter((doc) =>
|
||||
option.id
|
||||
? doc.id !== option.id
|
||||
: doc.date !== option.date,
|
||||
)
|
||||
: [];
|
||||
dispatch(
|
||||
setSelectedDocs(
|
||||
updatedDocs.length > 0 ? updatedDocs : null,
|
||||
),
|
||||
);
|
||||
handlePostDocumentSelect(
|
||||
updatedDocs.length > 0 ? updatedDocs : null,
|
||||
);
|
||||
} else {
|
||||
dispatch(setSelectedDocs(option));
|
||||
handlePostDocumentSelect(option);
|
||||
const updatedDocs =
|
||||
selectedDocs && Array.isArray(selectedDocs)
|
||||
? [...selectedDocs, option]
|
||||
: [option];
|
||||
dispatch(setSelectedDocs(updatedDocs));
|
||||
handlePostDocumentSelect(updatedDocs);
|
||||
}
|
||||
}}
|
||||
>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import 'katex/dist/katex.min.css';
|
||||
|
||||
import { forwardRef, Fragment, useRef, useState, useEffect } from 'react';
|
||||
import { forwardRef, Fragment, useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ReactMarkdown from 'react-markdown';
|
||||
import { useSelector } from 'react-redux';
|
||||
@@ -12,12 +12,13 @@ import {
|
||||
import rehypeKatex from 'rehype-katex';
|
||||
import remarkGfm from 'remark-gfm';
|
||||
import remarkMath from 'remark-math';
|
||||
import DocumentationDark from '../assets/documentation-dark.svg';
|
||||
|
||||
import ChevronDown from '../assets/chevron-down.svg';
|
||||
import Cloud from '../assets/cloud.svg';
|
||||
import DocsGPT3 from '../assets/cute_docsgpt3.svg';
|
||||
import Dislike from '../assets/dislike.svg?react';
|
||||
import Document from '../assets/document.svg';
|
||||
import DocumentationDark from '../assets/documentation-dark.svg';
|
||||
import Edit from '../assets/edit.svg';
|
||||
import Like from '../assets/like.svg?react';
|
||||
import Link from '../assets/link.svg';
|
||||
@@ -761,7 +762,11 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
|
||||
Response
|
||||
</span>{' '}
|
||||
<CopyButton
|
||||
textToCopy={JSON.stringify(toolCall.result, null, 2)}
|
||||
textToCopy={
|
||||
toolCall.status === 'error'
|
||||
? toolCall.error || 'Unknown error'
|
||||
: JSON.stringify(toolCall.result, null, 2)
|
||||
}
|
||||
/>
|
||||
</p>
|
||||
{toolCall.status === 'pending' && (
|
||||
@@ -779,6 +784,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
{toolCall.status === 'error' && (
|
||||
<p className="dark:bg-raisin-black rounded-b-2xl p-2 font-mono text-sm break-words">
|
||||
<span
|
||||
className="leading-[23px] text-red-500 dark:text-red-400"
|
||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
||||
>
|
||||
{toolCall.error}
|
||||
</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Accordion>
|
||||
|
||||
@@ -7,7 +7,7 @@ export function handleFetchAnswer(
|
||||
question: string,
|
||||
signal: AbortSignal,
|
||||
token: string | null,
|
||||
selectedDocs: Doc | null,
|
||||
selectedDocs: Doc[] | null,
|
||||
conversationId: string | null,
|
||||
promptId: string | null,
|
||||
chunks: string,
|
||||
@@ -52,10 +52,17 @@ export function handleFetchAnswer(
|
||||
payload.attachments = attachments;
|
||||
}
|
||||
|
||||
if (selectedDocs && 'id' in selectedDocs) {
|
||||
payload.active_docs = selectedDocs.id as string;
|
||||
if (selectedDocs && Array.isArray(selectedDocs)) {
|
||||
if (selectedDocs.length > 1) {
|
||||
// Handle multiple documents
|
||||
payload.active_docs = selectedDocs.map((doc) => doc.id!);
|
||||
payload.retriever = selectedDocs[0]?.retriever as string;
|
||||
} else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) {
|
||||
// Handle single document (backward compatibility)
|
||||
payload.active_docs = selectedDocs[0].id as string;
|
||||
payload.retriever = selectedDocs[0].retriever as string;
|
||||
}
|
||||
}
|
||||
payload.retriever = selectedDocs?.retriever as string;
|
||||
return conversationService
|
||||
.answer(payload, token, signal)
|
||||
.then((response) => {
|
||||
@@ -84,7 +91,7 @@ export function handleFetchAnswerSteaming(
|
||||
question: string,
|
||||
signal: AbortSignal,
|
||||
token: string | null,
|
||||
selectedDocs: Doc | null,
|
||||
selectedDocs: Doc[] | null,
|
||||
conversationId: string | null,
|
||||
promptId: string | null,
|
||||
chunks: string,
|
||||
@@ -112,10 +119,17 @@ export function handleFetchAnswerSteaming(
|
||||
payload.attachments = attachments;
|
||||
}
|
||||
|
||||
if (selectedDocs && 'id' in selectedDocs) {
|
||||
payload.active_docs = selectedDocs.id as string;
|
||||
if (selectedDocs && Array.isArray(selectedDocs)) {
|
||||
if (selectedDocs.length > 1) {
|
||||
// Handle multiple documents
|
||||
payload.active_docs = selectedDocs.map((doc) => doc.id!);
|
||||
payload.retriever = selectedDocs[0]?.retriever as string;
|
||||
} else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) {
|
||||
// Handle single document (backward compatibility)
|
||||
payload.active_docs = selectedDocs[0].id as string;
|
||||
payload.retriever = selectedDocs[0].retriever as string;
|
||||
}
|
||||
}
|
||||
payload.retriever = selectedDocs?.retriever as string;
|
||||
|
||||
return new Promise<Answer>((resolve, reject) => {
|
||||
conversationService
|
||||
@@ -171,7 +185,7 @@ export function handleFetchAnswerSteaming(
|
||||
export function handleSearch(
|
||||
question: string,
|
||||
token: string | null,
|
||||
selectedDocs: Doc | null,
|
||||
selectedDocs: Doc[] | null,
|
||||
conversation_id: string | null,
|
||||
chunks: string,
|
||||
token_limit: number,
|
||||
@@ -183,9 +197,17 @@ export function handleSearch(
|
||||
token_limit: token_limit,
|
||||
isNoneDoc: selectedDocs === null,
|
||||
};
|
||||
if (selectedDocs && 'id' in selectedDocs)
|
||||
payload.active_docs = selectedDocs.id as string;
|
||||
payload.retriever = selectedDocs?.retriever as string;
|
||||
if (selectedDocs && Array.isArray(selectedDocs)) {
|
||||
if (selectedDocs.length > 1) {
|
||||
// Handle multiple documents
|
||||
payload.active_docs = selectedDocs.map((doc) => doc.id!);
|
||||
payload.retriever = selectedDocs[0]?.retriever as string;
|
||||
} else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) {
|
||||
// Handle single document (backward compatibility)
|
||||
payload.active_docs = selectedDocs[0].id as string;
|
||||
payload.retriever = selectedDocs[0].retriever as string;
|
||||
}
|
||||
}
|
||||
return conversationService
|
||||
.search(payload, token)
|
||||
.then((response) => response.json())
|
||||
|
||||
@@ -54,7 +54,7 @@ export interface Query {
|
||||
|
||||
export interface RetrievalPayload {
|
||||
question: string;
|
||||
active_docs?: string;
|
||||
active_docs?: string | string[];
|
||||
retriever?: string;
|
||||
conversation_id: string | null;
|
||||
prompt_id?: string | null;
|
||||
|
||||
@@ -4,5 +4,6 @@ export type ToolCallsType = {
|
||||
call_id: string;
|
||||
arguments: Record<string, any>;
|
||||
result?: Record<string, any>;
|
||||
status?: 'pending' | 'completed';
|
||||
error?: string;
|
||||
status?: 'pending' | 'completed' | 'error';
|
||||
};
|
||||
|
||||
@@ -18,11 +18,14 @@ export default function useDefaultDocument() {
|
||||
const fetchDocs = () => {
|
||||
getDocs(token).then((data) => {
|
||||
dispatch(setSourceDocs(data));
|
||||
if (!selectedDoc)
|
||||
if (
|
||||
!selectedDoc ||
|
||||
(Array.isArray(selectedDoc) && selectedDoc.length === 0)
|
||||
)
|
||||
Array.isArray(data) &&
|
||||
data?.forEach((doc: Doc) => {
|
||||
if (doc.model && doc.name === 'default') {
|
||||
dispatch(setSelectedDocs(doc));
|
||||
dispatch(setSelectedDocs([doc]));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -185,7 +185,39 @@
|
||||
"cancel": "Cancel",
|
||||
"addNew": "Add New",
|
||||
"name": "Name",
|
||||
"type": "Type"
|
||||
"type": "Type",
|
||||
"mcp": {
|
||||
"addServer": "Add MCP Server",
|
||||
"editServer": "Edit Server",
|
||||
"serverName": "Server Name",
|
||||
"serverUrl": "Server URL",
|
||||
"headerName": "Header Name",
|
||||
"timeout": "Timeout (seconds)",
|
||||
"testConnection": "Test Connection",
|
||||
"testing": "Testing...",
|
||||
"saving": "Saving...",
|
||||
"save": "Save",
|
||||
"cancel": "Cancel",
|
||||
"noAuth": "No Authentication",
|
||||
"placeholders": {
|
||||
"serverUrl": "https://api.example.com",
|
||||
"apiKey": "Your secret API key",
|
||||
"bearerToken": "Your secret token",
|
||||
"username": "Your username",
|
||||
"password": "Your password"
|
||||
},
|
||||
"errors": {
|
||||
"nameRequired": "Server name is required",
|
||||
"urlRequired": "Server URL is required",
|
||||
"invalidUrl": "Please enter a valid URL",
|
||||
"apiKeyRequired": "API key is required",
|
||||
"tokenRequired": "Bearer token is required",
|
||||
"usernameRequired": "Username is required",
|
||||
"passwordRequired": "Password is required",
|
||||
"testFailed": "Connection test failed",
|
||||
"saveFailed": "Failed to save MCP server"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"modals": {
|
||||
|
||||
@@ -8,6 +8,7 @@ import { useOutsideAlerter } from '../hooks';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import ConfigToolModal from './ConfigToolModal';
|
||||
import MCPServerModal from './MCPServerModal';
|
||||
import { AvailableToolType } from './types';
|
||||
import WrapperComponent from './WrapperModal';
|
||||
|
||||
@@ -34,6 +35,8 @@ export default function AddToolModal({
|
||||
React.useState<AvailableToolType | null>(null);
|
||||
const [configModalState, setConfigModalState] =
|
||||
React.useState<ActiveState>('INACTIVE');
|
||||
const [mcpModalState, setMcpModalState] =
|
||||
React.useState<ActiveState>('INACTIVE');
|
||||
const [loading, setLoading] = React.useState(false);
|
||||
|
||||
useOutsideAlerter(modalRef, () => {
|
||||
@@ -86,6 +89,9 @@ export default function AddToolModal({
|
||||
.catch((error) => {
|
||||
console.error('Failed to create tool:', error);
|
||||
});
|
||||
} else if (tool.name === 'mcp_tool') {
|
||||
setModalState('INACTIVE');
|
||||
setMcpModalState('ACTIVE');
|
||||
} else {
|
||||
setModalState('INACTIVE');
|
||||
setConfigModalState('ACTIVE');
|
||||
@@ -95,6 +101,12 @@ export default function AddToolModal({
|
||||
React.useEffect(() => {
|
||||
if (modalState === 'ACTIVE') getAvailableTools();
|
||||
}, [modalState]);
|
||||
|
||||
const handleMcpServerAdded = () => {
|
||||
getUserTools();
|
||||
setMcpModalState('INACTIVE');
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{modalState === 'ACTIVE' && (
|
||||
@@ -166,6 +178,11 @@ export default function AddToolModal({
|
||||
tool={selectedTool}
|
||||
getUserTools={getUserTools}
|
||||
/>
|
||||
<MCPServerModal
|
||||
modalState={mcpModalState}
|
||||
setModalState={setMcpModalState}
|
||||
onServerSaved={handleMcpServerAdded}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
482
frontend/src/modals/MCPServerModal.tsx
Normal file
482
frontend/src/modals/MCPServerModal.tsx
Normal file
@@ -0,0 +1,482 @@
|
||||
import { useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSelector } from 'react-redux';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import Dropdown from '../components/Dropdown';
|
||||
import Input from '../components/Input';
|
||||
import Spinner from '../components/Spinner';
|
||||
import { useOutsideAlerter } from '../hooks';
|
||||
import { ActiveState } from '../models/misc';
|
||||
import { selectToken } from '../preferences/preferenceSlice';
|
||||
import WrapperComponent from './WrapperModal';
|
||||
|
||||
interface MCPServerModalProps {
|
||||
modalState: ActiveState;
|
||||
setModalState: (state: ActiveState) => void;
|
||||
server?: any;
|
||||
onServerSaved: () => void;
|
||||
}
|
||||
|
||||
const authTypes = [
|
||||
{ label: 'No Authentication', value: 'none' },
|
||||
{ label: 'API Key', value: 'api_key' },
|
||||
{ label: 'Bearer Token', value: 'bearer' },
|
||||
// { label: 'Basic Authentication', value: 'basic' },
|
||||
];
|
||||
|
||||
export default function MCPServerModal({
|
||||
modalState,
|
||||
setModalState,
|
||||
server,
|
||||
onServerSaved,
|
||||
}: MCPServerModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const token = useSelector(selectToken);
|
||||
const modalRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const [formData, setFormData] = useState({
|
||||
name: server?.displayName || 'My MCP Server',
|
||||
server_url: server?.server_url || '',
|
||||
auth_type: server?.auth_type || 'none',
|
||||
api_key: '',
|
||||
header_name: 'X-API-Key',
|
||||
bearer_token: '',
|
||||
username: '',
|
||||
password: '',
|
||||
timeout: server?.timeout || 30,
|
||||
});
|
||||
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [testing, setTesting] = useState(false);
|
||||
const [testResult, setTestResult] = useState<{
|
||||
success: boolean;
|
||||
message: string;
|
||||
} | null>(null);
|
||||
const [errors, setErrors] = useState<{ [key: string]: string }>({});
|
||||
|
||||
useOutsideAlerter(modalRef, () => {
|
||||
if (modalState === 'ACTIVE') {
|
||||
setModalState('INACTIVE');
|
||||
resetForm();
|
||||
}
|
||||
}, [modalState]);
|
||||
|
||||
const resetForm = () => {
|
||||
setFormData({
|
||||
name: 'My MCP Server',
|
||||
server_url: '',
|
||||
auth_type: 'none',
|
||||
api_key: '',
|
||||
header_name: 'X-API-Key',
|
||||
bearer_token: '',
|
||||
username: '',
|
||||
password: '',
|
||||
timeout: 30,
|
||||
});
|
||||
setErrors({});
|
||||
setTestResult(null);
|
||||
};
|
||||
|
||||
const validateForm = () => {
|
||||
const requiredFields: { [key: string]: boolean } = {
|
||||
name: !formData.name.trim(),
|
||||
server_url: !formData.server_url.trim(),
|
||||
};
|
||||
|
||||
const authFieldChecks: { [key: string]: () => void } = {
|
||||
api_key: () => {
|
||||
if (!formData.api_key.trim())
|
||||
newErrors.api_key = t('settings.tools.mcp.errors.apiKeyRequired');
|
||||
},
|
||||
bearer: () => {
|
||||
if (!formData.bearer_token.trim())
|
||||
newErrors.bearer_token = t('settings.tools.mcp.errors.tokenRequired');
|
||||
},
|
||||
basic: () => {
|
||||
if (!formData.username.trim())
|
||||
newErrors.username = t('settings.tools.mcp.errors.usernameRequired');
|
||||
if (!formData.password.trim())
|
||||
newErrors.password = t('settings.tools.mcp.errors.passwordRequired');
|
||||
},
|
||||
};
|
||||
|
||||
const newErrors: { [key: string]: string } = {};
|
||||
Object.entries(requiredFields).forEach(([field, isEmpty]) => {
|
||||
if (isEmpty)
|
||||
newErrors[field] = t(
|
||||
`settings.tools.mcp.errors.${field === 'name' ? 'nameRequired' : 'urlRequired'}`,
|
||||
);
|
||||
});
|
||||
|
||||
if (formData.server_url.trim()) {
|
||||
try {
|
||||
new URL(formData.server_url);
|
||||
} catch {
|
||||
newErrors.server_url = t('settings.tools.mcp.errors.invalidUrl');
|
||||
}
|
||||
}
|
||||
|
||||
const timeoutValue = formData.timeout === '' ? 30 : formData.timeout;
|
||||
if (
|
||||
typeof timeoutValue === 'number' &&
|
||||
(timeoutValue < 1 || timeoutValue > 300)
|
||||
)
|
||||
newErrors.timeout = 'Timeout must be between 1 and 300 seconds';
|
||||
|
||||
if (authFieldChecks[formData.auth_type])
|
||||
authFieldChecks[formData.auth_type]();
|
||||
|
||||
setErrors(newErrors);
|
||||
return Object.keys(newErrors).length === 0;
|
||||
};
|
||||
|
||||
const handleInputChange = (name: string, value: string | number) => {
|
||||
setFormData((prev) => ({ ...prev, [name]: value }));
|
||||
if (errors[name]) {
|
||||
setErrors((prev) => ({ ...prev, [name]: '' }));
|
||||
}
|
||||
setTestResult(null);
|
||||
};
|
||||
|
||||
const buildToolConfig = () => {
|
||||
const config: any = {
|
||||
server_url: formData.server_url.trim(),
|
||||
auth_type: formData.auth_type,
|
||||
timeout: formData.timeout === '' ? 30 : formData.timeout,
|
||||
};
|
||||
|
||||
if (formData.auth_type === 'api_key') {
|
||||
config.api_key = formData.api_key.trim();
|
||||
config.api_key_header = formData.header_name.trim() || 'X-API-Key';
|
||||
} else if (formData.auth_type === 'bearer') {
|
||||
config.bearer_token = formData.bearer_token.trim();
|
||||
} else if (formData.auth_type === 'basic') {
|
||||
config.username = formData.username.trim();
|
||||
config.password = formData.password.trim();
|
||||
}
|
||||
return config;
|
||||
};
|
||||
|
||||
const testConnection = async () => {
|
||||
if (!validateForm()) return;
|
||||
setTesting(true);
|
||||
setTestResult(null);
|
||||
try {
|
||||
const config = buildToolConfig();
|
||||
const response = await userService.testMCPConnection({ config }, token);
|
||||
const result = await response.json();
|
||||
|
||||
setTestResult(result);
|
||||
} catch (error) {
|
||||
setTestResult({
|
||||
success: false,
|
||||
message: t('settings.tools.mcp.errors.testFailed'),
|
||||
});
|
||||
} finally {
|
||||
setTesting(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
if (!validateForm()) return;
|
||||
setLoading(true);
|
||||
try {
|
||||
const config = buildToolConfig();
|
||||
const serverData = {
|
||||
displayName: formData.name,
|
||||
config,
|
||||
status: true,
|
||||
...(server?.id && { id: server.id }),
|
||||
};
|
||||
|
||||
const response = await userService.saveMCPServer(serverData, token);
|
||||
const result = await response.json();
|
||||
|
||||
if (response.ok && result.success) {
|
||||
setTestResult({
|
||||
success: true,
|
||||
message: result.message,
|
||||
});
|
||||
onServerSaved();
|
||||
setModalState('INACTIVE');
|
||||
resetForm();
|
||||
} else {
|
||||
setErrors({
|
||||
general: result.error || t('settings.tools.mcp.errors.saveFailed'),
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error saving MCP server:', error);
|
||||
setErrors({ general: t('settings.tools.mcp.errors.saveFailed') });
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const renderAuthFields = () => {
|
||||
switch (formData.auth_type) {
|
||||
case 'api_key':
|
||||
return (
|
||||
<div className="mb-10">
|
||||
<div className="mt-6">
|
||||
<Input
|
||||
name="api_key"
|
||||
type="text"
|
||||
className="rounded-md"
|
||||
value={formData.api_key}
|
||||
onChange={(e) => handleInputChange('api_key', e.target.value)}
|
||||
placeholder={t('settings.tools.mcp.placeholders.apiKey')}
|
||||
borderVariant="thin"
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
/>
|
||||
{errors.api_key && (
|
||||
<p className="mt-1 text-sm text-red-600">{errors.api_key}</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-5">
|
||||
<Input
|
||||
name="header_name"
|
||||
type="text"
|
||||
className="rounded-md"
|
||||
value={formData.header_name}
|
||||
onChange={(e) =>
|
||||
handleInputChange('header_name', e.target.value)
|
||||
}
|
||||
placeholder={t('settings.tools.mcp.headerName')}
|
||||
borderVariant="thin"
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
case 'bearer':
|
||||
return (
|
||||
<div className="mb-10">
|
||||
<Input
|
||||
name="bearer_token"
|
||||
type="text"
|
||||
className="rounded-md"
|
||||
value={formData.bearer_token}
|
||||
onChange={(e) =>
|
||||
handleInputChange('bearer_token', e.target.value)
|
||||
}
|
||||
placeholder={t('settings.tools.mcp.placeholders.bearerToken')}
|
||||
borderVariant="thin"
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
/>
|
||||
{errors.bearer_token && (
|
||||
<p className="mt-1 text-sm text-red-600">{errors.bearer_token}</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
case 'basic':
|
||||
return (
|
||||
<div className="mb-10">
|
||||
<div className="mt-6">
|
||||
<Input
|
||||
name="username"
|
||||
type="text"
|
||||
className="rounded-md"
|
||||
value={formData.username}
|
||||
onChange={(e) => handleInputChange('username', e.target.value)}
|
||||
placeholder={t('settings.tools.mcp.username')}
|
||||
borderVariant="thin"
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
/>
|
||||
{errors.username && (
|
||||
<p className="mt-1 text-sm text-red-600">{errors.username}</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-5">
|
||||
<Input
|
||||
name="password"
|
||||
type="text"
|
||||
className="rounded-md"
|
||||
value={formData.password}
|
||||
onChange={(e) => handleInputChange('password', e.target.value)}
|
||||
placeholder={t('settings.tools.mcp.password')}
|
||||
borderVariant="thin"
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
/>
|
||||
{errors.password && (
|
||||
<p className="mt-1 text-sm text-red-600">{errors.password}</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
modalState === 'ACTIVE' && (
|
||||
<WrapperComponent
|
||||
close={() => {
|
||||
setModalState('INACTIVE');
|
||||
resetForm();
|
||||
}}
|
||||
className="max-w-[600px] md:w-[80vw] lg:w-[60vw]"
|
||||
>
|
||||
<div className="flex h-full flex-col">
|
||||
<div className="px-6 py-4">
|
||||
<h2 className="text-jet dark:text-bright-gray text-xl font-semibold">
|
||||
{server
|
||||
? t('settings.tools.mcp.editServer')
|
||||
: t('settings.tools.mcp.addServer')}
|
||||
</h2>
|
||||
</div>
|
||||
<div className="flex-1 px-6">
|
||||
<div className="space-y-6 py-6">
|
||||
<div>
|
||||
<Input
|
||||
name="name"
|
||||
type="text"
|
||||
className="rounded-md"
|
||||
value={formData.name}
|
||||
onChange={(e) => handleInputChange('name', e.target.value)}
|
||||
borderVariant="thin"
|
||||
placeholder={t('settings.tools.mcp.serverName')}
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
/>
|
||||
{errors.name && (
|
||||
<p className="mt-1 text-sm text-red-600">{errors.name}</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Input
|
||||
name="server_url"
|
||||
type="text"
|
||||
className="rounded-md"
|
||||
value={formData.server_url}
|
||||
onChange={(e) =>
|
||||
handleInputChange('server_url', e.target.value)
|
||||
}
|
||||
placeholder={t('settings.tools.mcp.serverUrl')}
|
||||
borderVariant="thin"
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
/>
|
||||
{errors.server_url && (
|
||||
<p className="mt-1 text-sm text-red-600">
|
||||
{errors.server_url}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Dropdown
|
||||
placeholder={t('settings.tools.mcp.authType')}
|
||||
selectedValue={
|
||||
authTypes.find((type) => type.value === formData.auth_type)
|
||||
?.label || null
|
||||
}
|
||||
onSelect={(selection: { label: string; value: string }) => {
|
||||
handleInputChange('auth_type', selection.value);
|
||||
}}
|
||||
options={authTypes}
|
||||
size="w-full"
|
||||
rounded="3xl"
|
||||
border="border"
|
||||
/>
|
||||
|
||||
{renderAuthFields()}
|
||||
|
||||
<div>
|
||||
<Input
|
||||
name="timeout"
|
||||
type="number"
|
||||
className="rounded-md"
|
||||
value={formData.timeout}
|
||||
onChange={(e) => {
|
||||
const value = e.target.value;
|
||||
if (value === '') {
|
||||
handleInputChange('timeout', '');
|
||||
} else {
|
||||
const numValue = parseInt(value);
|
||||
if (!isNaN(numValue) && numValue >= 1) {
|
||||
handleInputChange('timeout', numValue);
|
||||
}
|
||||
}
|
||||
}}
|
||||
placeholder={t('settings.tools.mcp.timeout')}
|
||||
borderVariant="thin"
|
||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
||||
/>
|
||||
{errors.timeout && (
|
||||
<p className="mt-2 text-sm text-red-600">{errors.timeout}</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{testResult && (
|
||||
<div
|
||||
className={`rounded-md p-5 ${
|
||||
testResult.success
|
||||
? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
|
||||
: 'bg-red-50 text-red-700 dark:bg-red-900 dark:text-red-300'
|
||||
}`}
|
||||
>
|
||||
{testResult.message}
|
||||
</div>
|
||||
)}
|
||||
{errors.general && (
|
||||
<div className="rounded-2xl bg-red-50 p-5 text-red-700 dark:bg-red-900 dark:text-red-300">
|
||||
{errors.general}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="px-6 py-2">
|
||||
<div className="flex flex-col gap-4 sm:flex-row sm:justify-between">
|
||||
<button
|
||||
onClick={testConnection}
|
||||
disabled={testing}
|
||||
className="border-silver dark:border-dim-gray dark:text-light-gray w-full rounded-3xl border px-6 py-2 text-sm font-medium transition-all hover:bg-gray-100 disabled:opacity-50 sm:w-auto dark:hover:bg-[#767183]/50"
|
||||
>
|
||||
{testing ? (
|
||||
<div className="flex items-center justify-center">
|
||||
<Spinner size="small" />
|
||||
<span className="ml-2">
|
||||
{t('settings.tools.mcp.testing')}
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
t('settings.tools.mcp.testConnection')
|
||||
)}
|
||||
</button>
|
||||
|
||||
<div className="flex flex-col-reverse gap-3 sm:flex-row sm:gap-3">
|
||||
<button
|
||||
onClick={() => {
|
||||
setModalState('INACTIVE');
|
||||
resetForm();
|
||||
}}
|
||||
className="dark:text-light-gray w-full cursor-pointer rounded-3xl px-6 py-2 text-sm font-medium hover:bg-gray-100 sm:w-auto dark:bg-transparent dark:hover:bg-[#767183]/50"
|
||||
>
|
||||
{t('settings.tools.mcp.cancel')}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleSave}
|
||||
disabled={loading}
|
||||
className="bg-purple-30 hover:bg-violets-are-blue w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
|
||||
>
|
||||
{loading ? (
|
||||
<div className="flex items-center justify-center">
|
||||
<Spinner size="small" />
|
||||
<span className="ml-2">
|
||||
{t('settings.tools.mcp.saving')}
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
t('settings.tools.mcp.save')
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</WrapperComponent>
|
||||
)
|
||||
);
|
||||
}
|
||||
@@ -60,7 +60,7 @@ export const ShareConversationModal = ({
|
||||
const [sourcePath, setSourcePath] = useState<{
|
||||
label: string;
|
||||
value: string;
|
||||
} | null>(preSelectedDoc ? extractDocPaths([preSelectedDoc])[0] : null);
|
||||
} | null>(preSelectedDoc ? extractDocPaths(preSelectedDoc)[0] : null);
|
||||
|
||||
const handleCopyKey = (url: string) => {
|
||||
navigator.clipboard.writeText(url);
|
||||
@@ -105,14 +105,14 @@ export const ShareConversationModal = ({
|
||||
return (
|
||||
<WrapperModal close={close}>
|
||||
<div className="flex flex-col gap-2">
|
||||
<h2 className="text-xl font-medium text-eerie-black dark:text-chinese-white">
|
||||
<h2 className="text-eerie-black dark:text-chinese-white text-xl font-medium">
|
||||
{t('modals.shareConv.label')}
|
||||
</h2>
|
||||
<p className="text-sm text-eerie-black dark:text-silver/60">
|
||||
<p className="text-eerie-black dark:text-silver/60 text-sm">
|
||||
{t('modals.shareConv.note')}
|
||||
</p>
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-lg text-eerie-black dark:text-white">
|
||||
<span className="text-eerie-black text-lg dark:text-white">
|
||||
{t('modals.shareConv.option')}
|
||||
</span>
|
||||
<ToggleSwitch
|
||||
@@ -136,19 +136,19 @@ export const ShareConversationModal = ({
|
||||
</div>
|
||||
)}
|
||||
<div className="flex items-baseline justify-between gap-2">
|
||||
<span className="no-scrollbar w-full overflow-x-auto whitespace-nowrap rounded-full border-2 border-silver px-4 py-3 text-eerie-black dark:border-silver/40 dark:text-white">
|
||||
<span className="no-scrollbar border-silver text-eerie-black dark:border-silver/40 w-full overflow-x-auto rounded-full border-2 px-4 py-3 whitespace-nowrap dark:text-white">
|
||||
{`${domain}/share/${identifier ?? '....'}`}
|
||||
</span>
|
||||
{status === 'fetched' ? (
|
||||
<button
|
||||
className="my-1 h-10 w-28 rounded-full bg-purple-30 p-2 text-sm text-white hover:bg-violets-are-blue"
|
||||
className="bg-purple-30 hover:bg-violets-are-blue my-1 h-10 w-28 rounded-full p-2 text-sm text-white"
|
||||
onClick={() => handleCopyKey(`${domain}/share/${identifier}`)}
|
||||
>
|
||||
{isCopied ? t('modals.saveKey.copied') : t('modals.saveKey.copy')}
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
className="my-1 flex h-10 w-28 items-center justify-evenly rounded-full bg-purple-30 p-2 text-center text-sm font-normal text-white hover:bg-violets-are-blue"
|
||||
className="bg-purple-30 hover:bg-violets-are-blue my-1 flex h-10 w-28 items-center justify-evenly rounded-full p-2 text-center text-sm font-normal text-white"
|
||||
onClick={() => {
|
||||
shareCoversationPublicly(allowPrompt);
|
||||
}}
|
||||
|
||||
@@ -90,9 +90,9 @@ export function getLocalApiKey(): string | null {
|
||||
return key;
|
||||
}
|
||||
|
||||
export function getLocalRecentDocs(): string | null {
|
||||
const doc = localStorage.getItem('DocsGPTRecentDocs');
|
||||
return doc;
|
||||
export function getLocalRecentDocs(): Doc[] | null {
|
||||
const docs = localStorage.getItem('DocsGPTRecentDocs');
|
||||
return docs ? (JSON.parse(docs) as Doc[]) : null;
|
||||
}
|
||||
|
||||
export function getLocalPrompt(): string | null {
|
||||
@@ -108,19 +108,20 @@ export function setLocalPrompt(prompt: string): void {
|
||||
localStorage.setItem('DocsGPTPrompt', prompt);
|
||||
}
|
||||
|
||||
export function setLocalRecentDocs(doc: Doc | null): void {
|
||||
localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(doc));
|
||||
export function setLocalRecentDocs(docs: Doc[] | null): void {
|
||||
if (docs && docs.length > 0) {
|
||||
localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(docs));
|
||||
|
||||
docs.forEach((doc) => {
|
||||
let docPath = 'default';
|
||||
if (doc?.type === 'local') {
|
||||
if (doc.type === 'local') {
|
||||
docPath = 'local' + '/' + doc.name + '/';
|
||||
}
|
||||
userService
|
||||
.checkDocs(
|
||||
{
|
||||
docs: docPath,
|
||||
},
|
||||
null,
|
||||
)
|
||||
.checkDocs({ docs: docPath }, null)
|
||||
.then((response) => response.json());
|
||||
});
|
||||
} else {
|
||||
localStorage.removeItem('DocsGPTRecentDocs');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ export interface Preference {
|
||||
prompt: { name: string; id: string; type: string };
|
||||
chunks: string;
|
||||
token_limit: number;
|
||||
selectedDocs: Doc | null;
|
||||
selectedDocs: Doc[] | null;
|
||||
sourceDocs: Doc[] | null;
|
||||
conversations: {
|
||||
data: { name: string; id: string }[] | null;
|
||||
@@ -34,15 +34,16 @@ const initialState: Preference = {
|
||||
prompt: { name: 'default', id: 'default', type: 'public' },
|
||||
chunks: '2',
|
||||
token_limit: 2000,
|
||||
selectedDocs: {
|
||||
selectedDocs: [
|
||||
{
|
||||
id: 'default',
|
||||
name: 'default',
|
||||
type: 'remote',
|
||||
date: 'default',
|
||||
docLink: 'default',
|
||||
model: 'openai_text-embedding-ada-002',
|
||||
retriever: 'classic',
|
||||
} as Doc,
|
||||
},
|
||||
] as Doc[],
|
||||
sourceDocs: null,
|
||||
conversations: {
|
||||
data: null,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import React, { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
@@ -319,7 +318,7 @@ export default function Sources({
|
||||
setSearchTerm(e.target.value);
|
||||
setCurrentPage(1);
|
||||
}}
|
||||
className="w-full h-[32px] rounded-full border border-silver dark:border-silver/40 bg-transparent px-3 text-sm text-jet dark:text-bright-gray placeholder:text-gray-400 dark:placeholder:text-gray-500 outline-none focus:border-silver dark:focus:border-silver/60"
|
||||
className="border-silver dark:border-silver/40 text-jet dark:text-bright-gray focus:border-silver dark:focus:border-silver/60 h-[32px] w-full rounded-full border bg-transparent px-3 text-sm outline-none placeholder:text-gray-400 dark:placeholder:text-gray-500"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@@ -336,7 +335,7 @@ export default function Sources({
|
||||
</div>
|
||||
<div className="relative w-full">
|
||||
{loading ? (
|
||||
<div className="w-full grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6 px-2 py-4">
|
||||
<div className="grid w-full grid-cols-1 gap-6 px-2 py-4 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4">
|
||||
<SkeletonLoader component="sourceCards" count={rowsPerPage} />
|
||||
</div>
|
||||
) : !currentDocuments?.length ? (
|
||||
@@ -351,14 +350,15 @@ export default function Sources({
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="w-full grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6 px-2 py-4">
|
||||
<div className="grid w-full grid-cols-1 gap-6 px-2 py-4 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4">
|
||||
{currentDocuments.map((document, index) => {
|
||||
const docId = document.id ? document.id.toString() : '';
|
||||
|
||||
return (
|
||||
<div key={docId} className="relative">
|
||||
<div
|
||||
className={`flex h-[130px] w-full flex-col rounded-2xl bg-[#F9F9F9] p-3 transition-all duration-200 dark:bg-[#383838] ${activeMenuId === docId || syncMenuState.docId === docId
|
||||
className={`flex h-[130px] w-full flex-col rounded-2xl bg-[#F9F9F9] p-3 transition-all duration-200 dark:bg-[#383838] ${
|
||||
activeMenuId === docId || syncMenuState.docId === docId
|
||||
? 'scale-[1.05]'
|
||||
: 'hover:scale-[1.05]'
|
||||
}`}
|
||||
@@ -426,7 +426,7 @@ export default function Sources({
|
||||
<img
|
||||
src={CalendarIcon}
|
||||
alt=""
|
||||
className="w-[14px] h-[14px]"
|
||||
className="h-[14px] w-[14px]"
|
||||
/>
|
||||
<span className="font-inter text-[12px] leading-[18px] font-[500] text-[#848484] dark:text-[#848484]">
|
||||
{document.date ? formatDate(document.date) : ''}
|
||||
@@ -436,7 +436,7 @@ export default function Sources({
|
||||
<img
|
||||
src={DiscIcon}
|
||||
alt=""
|
||||
className="w-[14px] h-[14px]"
|
||||
className="h-[14px] w-[14px]"
|
||||
/>
|
||||
<span className="font-inter text-[12px] leading-[18px] font-[500] text-[#848484] dark:text-[#848484]">
|
||||
{document.tokens
|
||||
|
||||
@@ -30,9 +30,22 @@ export default function ToolConfig({
|
||||
handleGoBack: () => void;
|
||||
}) {
|
||||
const token = useSelector(selectToken);
|
||||
const [authKey, setAuthKey] = React.useState<string>(
|
||||
'token' in tool.config ? tool.config.token : '',
|
||||
);
|
||||
const [authKey, setAuthKey] = React.useState<string>(() => {
|
||||
if (tool.name === 'mcp_tool') {
|
||||
const config = tool.config as any;
|
||||
if (config.auth_type === 'api_key') {
|
||||
return config.api_key || '';
|
||||
} else if (config.auth_type === 'bearer') {
|
||||
return config.encrypted_token || '';
|
||||
} else if (config.auth_type === 'basic') {
|
||||
return config.password || '';
|
||||
}
|
||||
return '';
|
||||
} else if ('token' in tool.config) {
|
||||
return tool.config.token;
|
||||
}
|
||||
return '';
|
||||
});
|
||||
const [customName, setCustomName] = React.useState<string>(
|
||||
tool.customName || '',
|
||||
);
|
||||
@@ -97,6 +110,26 @@ export default function ToolConfig({
|
||||
};
|
||||
|
||||
const handleSaveChanges = () => {
|
||||
let configToSave;
|
||||
if (tool.name === 'api_tool') {
|
||||
configToSave = tool.config;
|
||||
} else if (tool.name === 'mcp_tool') {
|
||||
configToSave = { ...tool.config } as any;
|
||||
const mcpConfig = tool.config as any;
|
||||
|
||||
if (authKey.trim()) {
|
||||
if (mcpConfig.auth_type === 'api_key') {
|
||||
configToSave.api_key = authKey;
|
||||
} else if (mcpConfig.auth_type === 'bearer') {
|
||||
configToSave.encrypted_token = authKey;
|
||||
} else if (mcpConfig.auth_type === 'basic') {
|
||||
configToSave.password = authKey;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
configToSave = { token: authKey };
|
||||
}
|
||||
|
||||
userService
|
||||
.updateTool(
|
||||
{
|
||||
@@ -105,7 +138,7 @@ export default function ToolConfig({
|
||||
displayName: tool.displayName,
|
||||
customName: customName,
|
||||
description: tool.description,
|
||||
config: tool.name === 'api_tool' ? tool.config : { token: authKey },
|
||||
config: configToSave,
|
||||
actions: 'actions' in tool ? tool.actions : [],
|
||||
status: tool.status,
|
||||
},
|
||||
@@ -196,7 +229,15 @@ export default function ToolConfig({
|
||||
<div className="mt-1">
|
||||
{Object.keys(tool?.config).length !== 0 && tool.name !== 'api_tool' && (
|
||||
<p className="text-eerie-black dark:text-bright-gray text-sm font-semibold">
|
||||
{t('settings.tools.authentication')}
|
||||
{tool.name === 'mcp_tool'
|
||||
? (tool.config as any)?.auth_type === 'bearer'
|
||||
? 'Bearer Token'
|
||||
: (tool.config as any)?.auth_type === 'api_key'
|
||||
? 'API Key'
|
||||
: (tool.config as any)?.auth_type === 'basic'
|
||||
? 'Password'
|
||||
: t('settings.tools.authentication')
|
||||
: t('settings.tools.authentication')}
|
||||
</p>
|
||||
)}
|
||||
<div className="mt-4 flex flex-col items-start gap-2 sm:flex-row sm:items-center">
|
||||
@@ -208,7 +249,17 @@ export default function ToolConfig({
|
||||
value={authKey}
|
||||
onChange={(e) => setAuthKey(e.target.value)}
|
||||
borderVariant="thin"
|
||||
placeholder={t('modals.configTool.apiKeyPlaceholder')}
|
||||
placeholder={
|
||||
tool.name === 'mcp_tool'
|
||||
? (tool.config as any)?.auth_type === 'bearer'
|
||||
? 'Bearer Token'
|
||||
: (tool.config as any)?.auth_type === 'api_key'
|
||||
? 'API Key'
|
||||
: (tool.config as any)?.auth_type === 'basic'
|
||||
? 'Password'
|
||||
: t('modals.configTool.apiKeyPlaceholder')
|
||||
: t('modals.configTool.apiKeyPlaceholder')
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
@@ -450,6 +501,26 @@ export default function ToolConfig({
|
||||
setModalState={(state) => setShowUnsavedModal(state === 'ACTIVE')}
|
||||
submitLabel={t('settings.tools.saveAndLeave')}
|
||||
handleSubmit={() => {
|
||||
let configToSave;
|
||||
if (tool.name === 'api_tool') {
|
||||
configToSave = tool.config;
|
||||
} else if (tool.name === 'mcp_tool') {
|
||||
configToSave = { ...tool.config } as any;
|
||||
const mcpConfig = tool.config as any;
|
||||
|
||||
if (authKey.trim()) {
|
||||
if (mcpConfig.auth_type === 'api_key') {
|
||||
configToSave.api_key = authKey;
|
||||
} else if (mcpConfig.auth_type === 'bearer') {
|
||||
configToSave.encrypted_token = authKey;
|
||||
} else if (mcpConfig.auth_type === 'basic') {
|
||||
configToSave.password = authKey;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
configToSave = { token: authKey };
|
||||
}
|
||||
|
||||
userService
|
||||
.updateTool(
|
||||
{
|
||||
@@ -458,10 +529,7 @@ export default function ToolConfig({
|
||||
displayName: tool.displayName,
|
||||
customName: customName,
|
||||
description: tool.description,
|
||||
config:
|
||||
tool.name === 'api_tool'
|
||||
? tool.config
|
||||
: { token: authKey },
|
||||
config: configToSave,
|
||||
actions: 'actions' in tool ? tool.actions : [],
|
||||
status: tool.status,
|
||||
},
|
||||
|
||||
@@ -4,8 +4,15 @@ import { useTranslation } from 'react-i18next';
|
||||
import { useDispatch, useSelector } from 'react-redux';
|
||||
|
||||
import userService from '../api/services/userService';
|
||||
import { getSessionToken } from '../utils/providerUtils';
|
||||
|
||||
import {
|
||||
getSessionToken,
|
||||
setSessionToken,
|
||||
removeSessionToken,
|
||||
} from '../utils/providerUtils';
|
||||
import { formatDate } from '../utils/dateTimeUtils';
|
||||
import { formatBytes } from '../utils/stringUtils';
|
||||
import FileUpload from '../assets/file_upload.svg';
|
||||
import WebsiteCollect from '../assets/website_collect.svg';
|
||||
import Dropdown from '../components/Dropdown';
|
||||
import Input from '../components/Input';
|
||||
import ToggleSwitch from '../components/ToggleSwitch';
|
||||
@@ -377,7 +384,8 @@ function Upload({
|
||||
data?.find(
|
||||
(d: Doc) => d.type?.toLowerCase() === 'local',
|
||||
),
|
||||
));
|
||||
),
|
||||
);
|
||||
});
|
||||
setProgress(
|
||||
(progress) =>
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
* Follows the convention: {provider}_session_token
|
||||
*/
|
||||
|
||||
|
||||
export const getSessionToken = (provider: string): string | null => {
|
||||
return localStorage.getItem(`${provider}_session_token`);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user