This commit is contained in:
ManishMadan2882
2025-09-16 14:59:18 +05:30
42 changed files with 2463 additions and 414 deletions

View File

@@ -149,7 +149,7 @@ class BaseAgent(ABC):
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, 'name', 'unknown'),
"action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
}
@@ -225,6 +225,7 @@ class BaseAgent(ABC):
if tool_data["name"] == "api_tool"
else tool_data["config"]
),
user_id=self.user, # Pass user ID for MCP tools credential decryption
)
if tool_data["name"] == "api_tool":
print(

View File

@@ -0,0 +1,405 @@
import json
import time
from typing import Any, Dict, List, Optional
import requests
from application.agents.tools.base import Tool
from application.security.encryption import decrypt_credentials
_mcp_session_cache = {}
class MCPTool(Tool):
"""
MCP Tool
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
"""
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
"""
Initialize the MCP Tool with configuration.
Args:
config: Dictionary containing MCP server configuration:
- server_url: URL of the remote MCP server
- auth_type: Type of authentication (api_key, bearer, basic, none)
- encrypted_credentials: Encrypted credentials (if available)
- timeout: Request timeout in seconds (default: 30)
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
self.server_url = config.get("server_url", "")
self.auth_type = config.get("auth_type", "none")
self.timeout = config.get("timeout", 30)
self.auth_credentials = {}
if config.get("encrypted_credentials") and user_id:
self.auth_credentials = decrypt_credentials(
config["encrypted_credentials"], user_id
)
else:
self.auth_credentials = config.get("auth_credentials", {})
self.available_tools = []
self._session = requests.Session()
self._mcp_session_id = None
self._setup_authentication()
self._cache_key = self._generate_cache_key()
def _setup_authentication(self):
"""Setup authentication for the MCP server connection."""
if self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
if api_key:
self._session.headers.update({header_name: api_key})
elif self.auth_type == "bearer":
token = self.auth_credentials.get("bearer_token", "")
if token:
self._session.headers.update({"Authorization": f"Bearer {token}"})
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
password = self.auth_credentials.get("password", "")
if username and password:
self._session.auth = (username, password)
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "bearer":
token = self.auth_credentials.get("bearer_token", "")
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
elif self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
auth_key = f"basic:{username}"
else:
auth_key = "none"
return f"{self.server_url}#{auth_key}"
def _get_cached_session(self) -> Optional[str]:
"""Get cached session ID if available and not expired."""
global _mcp_session_cache
if self._cache_key in _mcp_session_cache:
session_data = _mcp_session_cache[self._cache_key]
if time.time() - session_data["created_at"] < 1800:
return session_data["session_id"]
else:
del _mcp_session_cache[self._cache_key]
return None
def _cache_session(self, session_id: str):
"""Cache the session ID for reuse."""
global _mcp_session_cache
_mcp_session_cache[self._cache_key] = {
"session_id": session_id,
"created_at": time.time(),
}
def _initialize_mcp_connection(self) -> Dict:
"""
Initialize MCP connection with the server, using cached session if available.
Returns:
Server capabilities and information
"""
cached_session = self._get_cached_session()
if cached_session:
self._mcp_session_id = cached_session
return {"cached": True}
try:
init_params = {
"protocolVersion": "2024-11-05",
"capabilities": {"roots": {"listChanged": True}, "sampling": {}},
"clientInfo": {"name": "DocsGPT", "version": "1.0.0"},
}
response = self._make_mcp_request("initialize", init_params)
self._make_mcp_request("notifications/initialized")
return response
except Exception as e:
return {"error": str(e), "fallback": True}
def _ensure_valid_session(self):
"""Ensure we have a valid MCP session, reinitializing if needed."""
if not self._mcp_session_id:
self._initialize_mcp_connection()
def _make_mcp_request(self, method: str, params: Optional[Dict] = None) -> Dict:
"""
Make an MCP protocol request to the server with automatic session recovery.
Args:
method: MCP method name (e.g., "tools/list", "tools/call")
params: Parameters for the MCP method
Returns:
Response data as dictionary
Raises:
Exception: If request fails after retry
"""
mcp_message = {"jsonrpc": "2.0", "method": method}
if not method.startswith("notifications/"):
mcp_message["id"] = 1
if params:
mcp_message["params"] = params
return self._execute_mcp_request(mcp_message, method)
def _execute_mcp_request(
self, mcp_message: Dict, method: str, is_retry: bool = False
) -> Dict:
"""Execute MCP request with optional retry on session failure."""
try:
final_headers = self._session.headers.copy()
final_headers.update(
{
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
}
)
if self._mcp_session_id:
final_headers["Mcp-Session-Id"] = self._mcp_session_id
response = self._session.post(
self.server_url.rstrip("/"),
json=mcp_message,
headers=final_headers,
timeout=self.timeout,
)
if "mcp-session-id" in response.headers:
self._mcp_session_id = response.headers["mcp-session-id"]
self._cache_session(self._mcp_session_id)
response.raise_for_status()
if method.startswith("notifications/"):
return {}
response_text = response.text.strip()
if response_text.startswith("event:") and "data:" in response_text:
lines = response_text.split("\n")
data_line = None
for line in lines:
if line.startswith("data:"):
data_line = line[5:].strip()
break
if data_line:
try:
result = json.loads(data_line)
except json.JSONDecodeError:
raise Exception(f"Invalid JSON in SSE data: {data_line}")
else:
raise Exception(f"No data found in SSE response: {response_text}")
else:
try:
result = response.json()
except json.JSONDecodeError:
raise Exception(f"Invalid JSON response: {response.text}")
if "error" in result:
error_msg = result["error"]
if isinstance(error_msg, dict):
error_msg = error_msg.get("message", str(error_msg))
raise Exception(f"MCP server error: {error_msg}")
return result.get("result", result)
except requests.exceptions.RequestException as e:
if not is_retry and self._should_retry_with_new_session(e):
self._invalidate_and_refresh_session()
return self._execute_mcp_request(mcp_message, method, is_retry=True)
raise Exception(f"MCP server request failed: {str(e)}")
def _should_retry_with_new_session(self, error: Exception) -> bool:
"""Check if error indicates session invalidation and retry is warranted."""
error_str = str(error).lower()
return (
any(
indicator in error_str
for indicator in [
"invalid session",
"session expired",
"unauthorized",
"401",
"403",
]
)
and self._mcp_session_id is not None
)
def _invalidate_and_refresh_session(self) -> None:
"""Invalidate current session and create a new one."""
global _mcp_session_cache
if self._cache_key in _mcp_session_cache:
del _mcp_session_cache[self._cache_key]
self._mcp_session_id = None
self._initialize_mcp_connection()
def discover_tools(self) -> List[Dict]:
"""
Discover available tools from the MCP server using MCP protocol.
Returns:
List of tool definitions from the server
"""
try:
self._ensure_valid_session()
response = self._make_mcp_request("tools/list")
# Handle both formats: response with 'tools' key or response that IS the tools list
if isinstance(response, dict):
if "tools" in response:
self.available_tools = response["tools"]
elif (
"result" in response
and isinstance(response["result"], dict)
and "tools" in response["result"]
):
self.available_tools = response["result"]["tools"]
else:
self.available_tools = [response] if response else []
elif isinstance(response, list):
self.available_tools = response
else:
self.available_tools = []
return self.available_tools
except Exception as e:
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
def execute_action(self, action_name: str, **kwargs) -> Any:
"""
Execute an action on the remote MCP server using MCP protocol.
Args:
action_name: Name of the action to execute
**kwargs: Parameters for the action
Returns:
Result from the MCP server
"""
self._ensure_valid_session()
# Skipping empty/None values - letting the server use defaults
cleaned_kwargs = {}
for key, value in kwargs.items():
if value == "" or value is None:
continue
cleaned_kwargs[key] = value
call_params = {"name": action_name, "arguments": cleaned_kwargs}
try:
result = self._make_mcp_request("tools/call", call_params)
return result
except Exception as e:
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
def get_actions_metadata(self) -> List[Dict]:
"""
Get metadata for all available actions.
Returns:
List of action metadata dictionaries
"""
actions = []
for tool in self.available_tools:
input_schema = (
tool.get("inputSchema")
or tool.get("input_schema")
or tool.get("schema")
or tool.get("parameters")
)
parameters_schema = {
"type": "object",
"properties": {},
"required": [],
}
if input_schema:
if isinstance(input_schema, dict):
if "properties" in input_schema:
parameters_schema = {
"type": input_schema.get("type", "object"),
"properties": input_schema.get("properties", {}),
"required": input_schema.get("required", []),
}
for key in ["additionalProperties", "description"]:
if key in input_schema:
parameters_schema[key] = input_schema[key]
else:
parameters_schema["properties"] = input_schema
action = {
"name": tool.get("name", ""),
"description": tool.get("description", ""),
"parameters": parameters_schema,
}
actions.append(action)
return actions
def test_connection(self) -> Dict:
"""
Test the connection to the MCP server and validate functionality.
Returns:
Dictionary with connection test results including tool count
"""
try:
self._mcp_session_id = None
init_result = self._initialize_mcp_connection()
tools = self.discover_tools()
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if init_result.get("cached"):
message += " (Using cached session)"
elif init_result.get("fallback"):
message += " (No formal initialization required)"
return {
"success": True,
"message": message,
"tools_count": len(tools),
"session_id": self._mcp_session_id,
"tools": [tool.get("name", "unknown") for tool in tools[:5]],
}
except Exception as e:
return {
"success": False,
"message": f"Connection failed: {str(e)}",
"tools_count": 0,
"error_type": type(e).__name__,
}
def get_config_requirements(self) -> Dict:
return {
"server_url": {
"type": "string",
"description": "URL of the remote MCP server (e.g., https://api.example.com)",
"required": True,
},
"auth_type": {
"type": "string",
"description": "Authentication type",
"enum": ["none", "api_key", "bearer", "basic"],
"default": "none",
"required": True,
},
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"required": False,
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
}

View File

@@ -23,16 +23,23 @@ class ToolManager:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
def load_tool(self, tool_name, tool_config):
def load_tool(self, tool_name, tool_config, user_id=None):
self.config[tool_name] = tool_config
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
if tool_name == "mcp_tool" and user_id:
return obj(tool_config, user_id)
else:
return obj(tool_config)
def execute_action(self, tool_name, action_name, **kwargs):
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name == "mcp_tool" and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):

View File

@@ -69,11 +69,8 @@ class StreamProcessor:
self.decoded_token.get("sub") if self.decoded_token is not None else None
)
self.conversation_id = self.data.get("conversation_id")
self.source = (
{"active_docs": self.data["active_docs"]}
if "active_docs" in self.data
else {}
)
self.source = {}
self.all_sources = []
self.attachments = []
self.history = []
self.agent_config = {}
@@ -85,6 +82,8 @@ class StreamProcessor:
def initialize(self):
"""Initialize all required components for processing"""
self._configure_agent()
self._configure_source()
self._configure_retriever()
self._configure_agent()
self._load_conversation_history()
@@ -171,13 +170,77 @@ class StreamProcessor:
source = data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
data["source"] = str(source_doc["_id"])
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
else:
data["source"] = None
elif source == "default":
data["source"] = "default"
else:
data["source"] = None
# Handle multiple sources
sources = data.get("sources", [])
if sources and isinstance(sources, list):
sources_list = []
for i, source_ref in enumerate(sources):
if source_ref == "default":
processed_source = {
"id": "default",
"retriever": "classic",
"chunks": data.get("chunks", "2"),
}
sources_list.append(processed_source)
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
processed_source = {
"id": str(source_doc["_id"]),
"retriever": source_doc.get("retriever", "classic"),
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
}
sources_list.append(processed_source)
data["sources"] = sources_list
else:
data["sources"] = []
return data
def _configure_source(self):
"""Configure the source based on agent data"""
api_key = self.data.get("api_key") or self.agent_key
if api_key:
agent_data = self._get_data_from_api_key(api_key)
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
source_ids = [
source["id"] for source in agent_data["sources"] if source.get("id")
]
if source_ids:
self.source = {"active_docs": source_ids}
else:
self.source = {}
self.all_sources = agent_data["sources"]
elif agent_data.get("source"):
self.source = {"active_docs": agent_data["source"]}
self.all_sources = [
{
"id": agent_data["source"],
"retriever": agent_data.get("retriever", "classic"),
}
]
else:
self.source = {}
self.all_sources = []
return
if "active_docs" in self.data:
self.source = {"active_docs": self.data["active_docs"]}
return
self.source = {}
self.all_sources = []
def _configure_agent(self):
"""Configure the agent based on request data"""
agent_id = self.data.get("agent_id")
@@ -203,7 +266,13 @@ class StreamProcessor:
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
self.retriever_config["chunks"] = data_key["chunks"]
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
self.agent_config.update(
@@ -224,7 +293,13 @@ class StreamProcessor:
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
self.retriever_config["chunks"] = data_key["chunks"]
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
else:
self.agent_config.update(
{
@@ -243,7 +318,8 @@ class StreamProcessor:
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
}
if "isNoneDoc" in self.data and self.data["isNoneDoc"]:
api_key = self.data.get("api_key") or self.agent_key
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
def create_agent(self):

View File

@@ -1,6 +1,5 @@
import datetime
import json
import logging
from bson.objectid import ObjectId

File diff suppressed because it is too large Load Diff

View File

@@ -26,7 +26,7 @@ class Settings(BaseSettings):
"gpt-4o-mini": 128000,
"gpt-3.5-turbo": 4096,
"claude-2": 1e5,
"gemini-2.0-flash-exp": 1e6,
"gemini-2.5-flash": 1e6,
}
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
@@ -116,6 +116,9 @@ class Settings(BaseSettings):
JWT_SECRET_KEY: str = ""
# Encryption settings
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

View File

@@ -143,6 +143,7 @@ class GoogleLLM(BaseLLM):
raise
def _clean_messages_google(self, messages):
"""Convert OpenAI format messages to Google AI format."""
cleaned_messages = []
for message in messages:
role = message.get("role")
@@ -150,6 +151,8 @@ class GoogleLLM(BaseLLM):
if role == "assistant":
role = "model"
elif role == "tool":
role = "model"
parts = []
if role and content is not None:
@@ -188,11 +191,63 @@ class GoogleLLM(BaseLLM):
else:
raise ValueError(f"Unexpected content type: {type(content)}")
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
return cleaned_messages
def _clean_schema(self, schema_obj):
"""
Recursively remove unsupported fields from schema objects
and validate required properties.
"""
if not isinstance(schema_obj, dict):
return schema_obj
allowed_fields = {
"type",
"description",
"items",
"properties",
"required",
"enum",
"pattern",
"minimum",
"maximum",
"nullable",
"default",
}
cleaned = {}
for key, value in schema_obj.items():
if key not in allowed_fields:
continue
elif key == "type" and isinstance(value, str):
cleaned[key] = value.upper()
elif isinstance(value, dict):
cleaned[key] = self._clean_schema(value)
elif isinstance(value, list):
cleaned[key] = [self._clean_schema(item) for item in value]
else:
cleaned[key] = value
# Validate that required properties actually exist in properties
if "required" in cleaned and "properties" in cleaned:
valid_required = []
properties_keys = set(cleaned["properties"].keys())
for required_prop in cleaned["required"]:
if required_prop in properties_keys:
valid_required.append(required_prop)
if valid_required:
cleaned["required"] = valid_required
else:
cleaned.pop("required", None)
elif "required" in cleaned and "properties" not in cleaned:
cleaned.pop("required", None)
return cleaned
def _clean_tools_format(self, tools_list):
"""Convert OpenAI format tools to Google AI format."""
genai_tools = []
for tool_data in tools_list:
if tool_data["type"] == "function":
@@ -201,18 +256,16 @@ class GoogleLLM(BaseLLM):
properties = parameters.get("properties", {})
if properties:
cleaned_properties = {}
for k, v in properties.items():
cleaned_properties[k] = self._clean_schema(v)
genai_function = dict(
name=function["name"],
description=function["description"],
parameters={
"type": "OBJECT",
"properties": {
k: {
**v,
"type": v["type"].upper() if v["type"] else None,
}
for k, v in properties.items()
},
"properties": cleaned_properties,
"required": (
parameters["required"]
if "required" in parameters
@@ -242,6 +295,7 @@ class GoogleLLM(BaseLLM):
response_schema=None,
**kwargs,
):
"""Generate content using Google AI API without streaming."""
client = genai.Client(api_key=self.api_key)
if formatting == "openai":
messages = self._clean_messages_google(messages)
@@ -281,6 +335,7 @@ class GoogleLLM(BaseLLM):
response_schema=None,
**kwargs,
):
"""Generate content using Google AI API with streaming."""
client = genai.Client(api_key=self.api_key)
if formatting == "openai":
messages = self._clean_messages_google(messages)
@@ -331,12 +386,15 @@ class GoogleLLM(BaseLLM):
yield chunk.text
def _supports_tools(self):
"""Return whether this LLM supports function calling."""
return True
def _supports_structured_output(self):
"""Return whether this LLM supports structured JSON output."""
return True
def prepare_structured_output_format(self, json_schema):
"""Convert JSON schema to Google AI structured output format."""
if not json_schema:
return None

View File

@@ -205,7 +205,6 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
updated_messages.append(
{
"role": "assistant",
@@ -222,17 +221,36 @@ class LLMHandler(ABC):
)
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
updated_messages.append(
{
"role": "tool",
"content": f"Error executing tool: {str(e)}",
"tool_call_id": call.id,
}
error_call = ToolCall(
id=call.id, name=call.name, arguments=call.arguments
)
error_response = f"Error executing tool: {str(e)}"
error_message = self.create_tool_message(error_call, error_response)
updated_messages.append(error_message)
call_parts = call.name.split("_")
if len(call_parts) >= 2:
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
action_name = "_".join(call_parts[:-1])
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
full_action_name = f"{action_name}_{tool_id}"
else:
tool_name = "unknown_tool"
action_name = call.name
full_action_name = call.name
yield {
"type": "tool_call",
"data": {
"tool_name": tool_name,
"call_id": call.id,
"action_name": full_action_name,
"arguments": call.arguments,
"error": error_response,
"status": "error",
},
}
return updated_messages
def handle_non_streaming(
@@ -263,13 +281,11 @@ class LLMHandler(ABC):
except StopIteration as e:
messages = e.value
break
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
return parsed.content
def handle_streaming(

View File

@@ -17,7 +17,6 @@ class GoogleLLMHandler(LLMHandler):
finish_reason="stop",
raw_response=response,
)
if hasattr(response, "candidates"):
parts = response.candidates[0].content.parts if response.candidates else []
tool_calls = [
@@ -41,7 +40,6 @@ class GoogleLLMHandler(LLMHandler):
finish_reason="tool_calls" if tool_calls else "stop",
raw_response=response,
)
else:
tool_calls = []
if hasattr(response, "function_call"):
@@ -61,14 +59,16 @@ class GoogleLLMHandler(LLMHandler):
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create Google-style tool message."""
from google.genai import types
return {
"role": "tool",
"role": "model",
"content": [
types.Part.from_function_response(
name=tool_call.name, response={"result": result}
).to_json_dict()
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
}
}
],
}

View File

@@ -2,6 +2,7 @@ anthropic==0.49.0
boto3==1.38.18
beautifulsoup4==4.13.4
celery==5.4.0
cryptography==42.0.8
dataclasses-json==0.6.7
docx2txt==0.8
duckduckgo-search==7.5.2

View File

@@ -5,10 +5,6 @@ class BaseRetriever(ABC):
def __init__(self):
pass
@abstractmethod
def gen(self, *args, **kwargs):
pass
@abstractmethod
def search(self, *args, **kwargs):
pass

View File

@@ -1,4 +1,5 @@
import logging
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever
@@ -20,9 +21,19 @@ class ClassicRAG(BaseRetriever):
api_key=settings.API_KEY,
decoded_token=None,
):
self.original_question = ""
"""Initialize ClassicRAG retriever with vectorstore sources and LLM configuration"""
self.original_question = source.get("question", "")
self.chat_history = chat_history if chat_history is not None else []
self.prompt = prompt
if isinstance(chunks, str):
try:
self.chunks = int(chunks)
except ValueError:
logging.warning(
f"Invalid chunks value '{chunks}', using default value 2"
)
self.chunks = 2
else:
self.chunks = chunks
self.gpt_model = gpt_model
self.token_limit = (
@@ -44,25 +55,52 @@ class ClassicRAG(BaseRetriever):
user_api_key=self.user_api_key,
decoded_token=decoded_token,
)
self.vectorstore = source["active_docs"] if "active_docs" in source else None
if "active_docs" in source and source["active_docs"] is not None:
if isinstance(source["active_docs"], list):
self.vectorstores = source["active_docs"]
else:
self.vectorstores = [source["active_docs"]]
else:
self.vectorstores = []
self.question = self._rephrase_query()
self.decoded_token = decoded_token
self._validate_vectorstore_config()
def _validate_vectorstore_config(self):
"""Validate vectorstore IDs and remove any empty/invalid entries"""
if not self.vectorstores:
logging.warning("No vectorstores configured for retrieval")
return
invalid_ids = [
vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
]
if invalid_ids:
logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
self.vectorstores = [
vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
]
def _rephrase_query(self):
"""Rephrase user query with chat history context for better retrieval"""
if (
not self.original_question
or not self.chat_history
or self.chat_history == []
or self.chunks == 0
or self.vectorstore is None
or not self.vectorstores
):
return self.original_question
prompt = f"""Given the following conversation history:
{self.chat_history}
Rephrase the following user question to be a standalone search query
that captures all relevant context from the conversation:
"""
messages = [
@@ -79,44 +117,62 @@ class ClassicRAG(BaseRetriever):
return self.original_question
def _get_data(self):
if self.chunks == 0 or self.vectorstore is None:
docs = []
else:
"""Retrieve relevant documents from configured vectorstores"""
if self.chunks == 0 or not self.vectorstores:
return []
all_docs = []
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
for vectorstore_id in self.vectorstores:
if vectorstore_id:
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
)
docs_temp = docsearch.search(self.question, k=self.chunks)
docs = [
docs_temp = docsearch.search(self.question, k=chunks_per_source)
for doc in docs_temp:
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
title = metadata.get(
"title", metadata.get("post_title", page_content)
)
if isinstance(title, str):
title = title.split("/")[-1]
else:
title = str(title).split("/")[-1]
all_docs.append(
{
"title": i.metadata.get(
"title", i.metadata.get("post_title", i.page_content)
).split("/")[-1],
"text": i.page_content,
"source": (
i.metadata.get("source")
if i.metadata.get("source")
else "local"
),
"title": title,
"text": page_content,
"source": metadata.get("source") or vectorstore_id,
}
for i in docs_temp
]
return docs
def gen():
pass
)
except Exception as e:
logging.error(
f"Error searching vectorstore {vectorstore_id}: {e}",
exc_info=True,
)
continue
return all_docs
def search(self, query: str = ""):
"""Search for documents using optional query override"""
if query:
self.original_question = query
self.question = self._rephrase_query()
return self._get_data()
def get_params(self):
"""Return current retriever configuration parameters"""
return {
"question": self.original_question,
"rephrased_question": self.question,
"source": self.vectorstore,
"sources": self.vectorstores,
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,

View File

View File

@@ -0,0 +1,85 @@
import base64
import json
import os
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from application.core.settings import settings
def _derive_key(user_id: str, salt: bytes) -> bytes:
app_secret = settings.ENCRYPTION_SECRET_KEY
password = f"{app_secret}#{user_id}".encode()
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
backend=default_backend(),
)
return kdf.derive(password)
def encrypt_credentials(credentials: dict, user_id: str) -> str:
if not credentials:
return ""
try:
salt = os.urandom(16)
iv = os.urandom(16)
key = _derive_key(user_id, salt)
json_str = json.dumps(credentials)
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
encryptor = cipher.encryptor()
padded_data = _pad_data(json_str.encode())
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
result = salt + iv + encrypted_data
return base64.b64encode(result).decode()
except Exception as e:
print(f"Warning: Failed to encrypt credentials: {e}")
return ""
def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
if not encrypted_data:
return {}
try:
data = base64.b64decode(encrypted_data.encode())
salt = data[:16]
iv = data[16:32]
encrypted_content = data[32:]
key = _derive_key(user_id, salt)
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
decryptor = cipher.decryptor()
decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
decrypted_data = _unpad_data(decrypted_padded)
return json.loads(decrypted_data.decode())
except Exception as e:
print(f"Warning: Failed to decrypt credentials: {e}")
return {}
def _pad_data(data: bytes) -> bytes:
block_size = 16
padding_len = block_size - (len(data) % block_size)
padding = bytes([padding_len]) * padding_len
return data + padding
def _unpad_data(data: bytes) -> bytes:
padding_len = data[-1]
return data[:-padding_len]

View File

@@ -1,12 +1,20 @@
from abc import ABC, abstractmethod
import os
from sentence_transformers import SentenceTransformer
from abc import ABC, abstractmethod
from langchain_openai import OpenAIEmbeddings
from sentence_transformers import SentenceTransformer
from application.core.settings import settings
class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs):
self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
self.model = SentenceTransformer(
model_name,
config_kwargs={"allow_dangerous_deserialization": True},
*args,
**kwargs
)
self.dimension = self.model.get_sentence_embedding_dimension()
def embed_query(self, query: str):
@@ -24,15 +32,14 @@ class EmbeddingsWrapper:
raise ValueError("Input must be a string or a list of strings")
class EmbeddingsSingleton:
_instances = {}
@staticmethod
def get_instance(embeddings_name, *args, **kwargs):
if embeddings_name not in EmbeddingsSingleton._instances:
EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(
embeddings_name, *args, **kwargs
EmbeddingsSingleton._instances[embeddings_name] = (
EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
)
return EmbeddingsSingleton._instances[embeddings_name]
@@ -40,9 +47,15 @@ class EmbeddingsSingleton:
def _create_instance(embeddings_name, *args, **kwargs):
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
"sentence-transformers/all-mpnet-base-v2"
),
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper(
"sentence-transformers/all-mpnet-base-v2"
),
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper(
"hkunlp/instructor-large"
),
}
if embeddings_name in embeddings_factory:
@@ -50,29 +63,58 @@ class EmbeddingsSingleton:
else:
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
class BaseVectorStore(ABC):
def __init__(self):
pass
@abstractmethod
def search(self, *args, **kwargs):
"""Search for similar documents/chunks in the vectorstore"""
pass
@abstractmethod
def add_texts(self, texts, metadatas=None, *args, **kwargs):
"""Add texts with their embeddings to the vectorstore"""
pass
def delete_index(self, *args, **kwargs):
"""Delete the entire index/collection"""
pass
def save_local(self, *args, **kwargs):
"""Save vectorstore to local storage"""
pass
def get_chunks(self, *args, **kwargs):
"""Get all chunks from the vectorstore"""
pass
def add_chunk(self, text, metadata=None, *args, **kwargs):
"""Add a single chunk to the vectorstore"""
pass
def delete_chunk(self, chunk_id, *args, **kwargs):
"""Delete a specific chunk from the vectorstore"""
pass
def is_azure_configured(self):
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
return (
settings.OPENAI_API_BASE
and settings.OPENAI_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
)
def _get_embeddings(self, embeddings_name, embeddings_key=None):
if embeddings_name == "openai_text-embedding-ada-002":
if self.is_azure_configured():
os.environ["OPENAI_API_TYPE"] = "azure"
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name,
openai_api_key=embeddings_key
embeddings_name, openai_api_key=embeddings_key
)
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./models/all-mpnet-base-v2"):
@@ -87,4 +129,3 @@ class BaseVectorStore(ABC):
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
return embedding_instance

View File

@@ -0,0 +1,4 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="64" height="64" color="#000000" fill="none">
<path d="M3.49994 11.7501L11.6717 3.57855C12.7762 2.47398 14.5672 2.47398 15.6717 3.57855C16.7762 4.68312 16.7762 6.47398 15.6717 7.57855M15.6717 7.57855L9.49994 13.7501M15.6717 7.57855C16.7762 6.47398 18.5672 6.47398 19.6717 7.57855C20.7762 8.68312 20.7762 10.474 19.6717 11.5785L12.7072 18.543C12.3167 18.9335 12.3167 19.5667 12.7072 19.9572L13.9999 21.2499" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
<path d="M17.4999 9.74921L11.3282 15.921C10.2237 17.0255 8.43272 17.0255 7.32823 15.921C6.22373 14.8164 6.22373 13.0255 7.32823 11.921L13.4999 5.74939" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
</svg>

After

Width:  |  Height:  |  Size: 831 B

View File

@@ -45,6 +45,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
description: '',
image: '',
source: '',
sources: [],
chunks: '',
retriever: '',
prompt_id: 'default',
@@ -150,7 +151,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
const formData = new FormData();
formData.append('name', agent.name);
formData.append('description', agent.description);
formData.append('source', agent.source);
if (selectedSourceIds.size > 1) {
const sourcesArray = Array.from(selectedSourceIds)
.map((id) => {
const sourceDoc = sourceDocs?.find(
(source) =>
source.id === id || source.retriever === id || source.name === id,
);
if (sourceDoc?.name === 'Default' && !sourceDoc?.id) {
return 'default';
}
return sourceDoc?.id || id;
})
.filter(Boolean);
formData.append('sources', JSON.stringify(sourcesArray));
formData.append('source', '');
} else if (selectedSourceIds.size === 1) {
const singleSourceId = Array.from(selectedSourceIds)[0];
const sourceDoc = sourceDocs?.find(
(source) =>
source.id === singleSourceId ||
source.retriever === singleSourceId ||
source.name === singleSourceId,
);
let finalSourceId;
if (sourceDoc?.name === 'Default' && !sourceDoc?.id)
finalSourceId = 'default';
else finalSourceId = sourceDoc?.id || singleSourceId;
formData.append('source', String(finalSourceId));
formData.append('sources', JSON.stringify([]));
} else {
formData.append('source', '');
formData.append('sources', JSON.stringify([]));
}
formData.append('chunks', agent.chunks);
formData.append('retriever', agent.retriever);
formData.append('prompt_id', agent.prompt_id);
@@ -196,7 +231,41 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
const formData = new FormData();
formData.append('name', agent.name);
formData.append('description', agent.description);
formData.append('source', agent.source);
if (selectedSourceIds.size > 1) {
const sourcesArray = Array.from(selectedSourceIds)
.map((id) => {
const sourceDoc = sourceDocs?.find(
(source) =>
source.id === id || source.retriever === id || source.name === id,
);
if (sourceDoc?.name === 'Default' && !sourceDoc?.id) {
return 'default';
}
return sourceDoc?.id || id;
})
.filter(Boolean);
formData.append('sources', JSON.stringify(sourcesArray));
formData.append('source', '');
} else if (selectedSourceIds.size === 1) {
const singleSourceId = Array.from(selectedSourceIds)[0];
const sourceDoc = sourceDocs?.find(
(source) =>
source.id === singleSourceId ||
source.retriever === singleSourceId ||
source.name === singleSourceId,
);
let finalSourceId;
if (sourceDoc?.name === 'Default' && !sourceDoc?.id)
finalSourceId = 'default';
else finalSourceId = sourceDoc?.id || singleSourceId;
formData.append('source', String(finalSourceId));
formData.append('sources', JSON.stringify([]));
} else {
formData.append('source', '');
formData.append('sources', JSON.stringify([]));
}
formData.append('chunks', agent.chunks);
formData.append('retriever', agent.retriever);
formData.append('prompt_id', agent.prompt_id);
@@ -293,9 +362,33 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
throw new Error('Failed to fetch agent');
}
const data = await response.json();
if (data.source) setSelectedSourceIds(new Set([data.source]));
else if (data.retriever)
if (data.sources && data.sources.length > 0) {
const mappedSources = data.sources.map((sourceId: string) => {
if (sourceId === 'default') {
const defaultSource = sourceDocs?.find(
(source) => source.name === 'Default',
);
return defaultSource?.retriever || 'classic';
}
return sourceId;
});
setSelectedSourceIds(new Set(mappedSources));
} else if (data.source) {
if (data.source === 'default') {
const defaultSource = sourceDocs?.find(
(source) => source.name === 'Default',
);
setSelectedSourceIds(
new Set([defaultSource?.retriever || 'classic']),
);
} else {
setSelectedSourceIds(new Set([data.source]));
}
} else if (data.retriever) {
setSelectedSourceIds(new Set([data.retriever]));
}
if (data.tools) setSelectedToolIds(new Set(data.tools));
if (data.status === 'draft') setEffectiveMode('draft');
if (data.json_schema) {
@@ -311,24 +404,56 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
}, [agentId, mode, token]);
useEffect(() => {
const selectedSource = Array.from(selectedSourceIds).map((id) =>
const selectedSources = Array.from(selectedSourceIds)
.map((id) =>
sourceDocs?.find(
(source) =>
source.id === id || source.retriever === id || source.name === id,
),
);
if (selectedSource[0]?.model === embeddingsName) {
if (selectedSource[0] && 'id' in selectedSource[0]) {
)
.filter(Boolean);
if (selectedSources.length > 0) {
// Handle multiple sources
if (selectedSources.length > 1) {
// Multiple sources selected - store in sources array
const sourceIds = selectedSources
.map((source) => source?.id)
.filter((id): id is string => Boolean(id));
setAgent((prev) => ({
...prev,
source: selectedSource[0]?.id || 'default',
sources: sourceIds,
source: '', // Clear single source for multiple sources
retriever: '',
}));
} else
} else {
// Single source selected - maintain backward compatibility
const selectedSource = selectedSources[0];
if (selectedSource?.model === embeddingsName) {
if (selectedSource && 'id' in selectedSource) {
setAgent((prev) => ({
...prev,
source: selectedSource?.id || 'default',
sources: [], // Clear sources array for single source
retriever: '',
}));
} else {
setAgent((prev) => ({
...prev,
source: '',
retriever: selectedSource[0]?.retriever || 'classic',
sources: [], // Clear sources array
retriever: selectedSource?.retriever || 'classic',
}));
}
}
}
} else {
// No sources selected
setAgent((prev) => ({
...prev,
source: '',
sources: [],
retriever: '',
}));
}
}, [selectedSourceIds]);
@@ -510,7 +635,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
)
.filter(Boolean)
.join(', ')
: 'Select source'}
: 'Select sources'}
</button>
<MultiSelectPopup
isOpen={isSourcePopupOpen}
@@ -526,12 +651,10 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
selectedIds={selectedSourceIds}
onSelectionChange={(newSelectedIds: Set<string | number>) => {
setSelectedSourceIds(newSelectedIds);
setIsSourcePopupOpen(false);
}}
title="Select Source"
title="Select Sources"
searchPlaceholder="Search sources..."
noOptionsMessage="No source available"
singleSelect={true}
noOptionsMessage="No sources available"
/>
</div>
<div className="mt-3">

View File

@@ -10,6 +10,7 @@ export type Agent = {
description: string;
image: string;
source: string;
sources?: string[];
chunks: string;
retriever: string;
prompt_id: string;

View File

@@ -57,6 +57,8 @@ const endpoints = {
DIRECTORY_STRUCTURE: (docId: string) =>
`/api/directory_structure?id=${docId}`,
MANAGE_SOURCE_FILES: '/api/manage_source_files',
MCP_TEST_CONNECTION: '/api/mcp_server/test',
MCP_SAVE_SERVER: '/api/mcp_server/save',
},
CONVERSATION: {
ANSWER: '/api/answer',

View File

@@ -90,7 +90,10 @@ const userService = {
path?: string,
search?: string,
): Promise<any> =>
apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search), token),
apiClient.get(
endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search),
token,
),
addChunk: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.ADD_CHUNK, data, token),
deleteChunk: (
@@ -105,16 +108,24 @@ const userService = {
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
syncConnector: (docId: string, provider: string, token: string | null): Promise<any> => {
testMCPConnection: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
saveMCPServer: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
syncConnector: (
docId: string,
provider: string,
token: string | null,
): Promise<any> => {
const sessionToken = getSessionToken(provider);
return apiClient.post(
endpoints.USER.SYNC_CONNECTOR,
{
source_id: docId,
session_token: sessionToken,
provider: provider
provider: provider,
},
token
token,
);
},
};

View File

View File

@@ -16,7 +16,12 @@ const providerLabel = (provider: string) => {
return map[provider] || provider.replace(/_/g, ' ');
};
const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onError, label }) => {
const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
provider,
onSuccess,
onError,
label,
}) => {
const token = useSelector(selectToken);
const completedRef = useRef(false);
const intervalRef = useRef<number | null>(null);
@@ -31,8 +36,12 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onEr
const handleAuthMessage = (event: MessageEvent) => {
const successGeneric = event.data?.type === 'connector_auth_success';
const successProvider = event.data?.type === `${provider}_auth_success` || event.data?.type === 'google_drive_auth_success';
const errorProvider = event.data?.type === `${provider}_auth_error` || event.data?.type === 'google_drive_auth_error';
const successProvider =
event.data?.type === `${provider}_auth_success` ||
event.data?.type === 'google_drive_auth_success';
const errorProvider =
event.data?.type === `${provider}_auth_error` ||
event.data?.type === 'google_drive_auth_error';
if (successGeneric || successProvider) {
completedRef.current = true;
@@ -54,12 +63,17 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onEr
cleanup();
const apiHost = import.meta.env.VITE_API_HOST;
const authResponse = await fetch(`${apiHost}/api/connectors/auth?provider=${provider}`, {
const authResponse = await fetch(
`${apiHost}/api/connectors/auth?provider=${provider}`,
{
headers: { Authorization: `Bearer ${token}` },
});
},
);
if (!authResponse.ok) {
throw new Error(`Failed to get authorization URL: ${authResponse.status}`);
throw new Error(
`Failed to get authorization URL: ${authResponse.status}`,
);
}
const authData = await authResponse.json();
@@ -70,10 +84,12 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onEr
const authWindow = window.open(
authData.authorization_url,
`${provider}-auth`,
'width=500,height=600,scrollbars=yes,resizable=yes'
'width=500,height=600,scrollbars=yes,resizable=yes',
);
if (!authWindow) {
throw new Error('Failed to open authentication window. Please allow popups.');
throw new Error(
'Failed to open authentication window. Please allow popups.',
);
}
window.addEventListener('message', handleAuthMessage as any);
@@ -98,10 +114,13 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onEr
return (
<button
onClick={handleAuth}
className="w-full flex items-center justify-center gap-2 rounded-lg bg-blue-500 px-4 py-3 text-white hover:bg-blue-600 transition-colors"
className="flex w-full items-center justify-center gap-2 rounded-lg bg-blue-500 px-4 py-3 text-white transition-colors hover:bg-blue-600"
>
<svg className="h-5 w-5" viewBox="0 0 24 24">
<path fill="currentColor" d="M6.28 3l5.72 10H24l-5.72-10H6.28zm11.44 0L12 13l5.72 10H24L18.28 3h-.56zM0 13l5.72 10h5.72L5.72 13H0z"/>
<path
fill="currentColor"
d="M6.28 3l5.72 10H24l-5.72-10H6.28zm11.44 0L12 13l5.72 10H24L18.28 3h-.56zM0 13l5.72 10h5.72L5.72 13H0z"
/>
</svg>
{buttonLabel}
</button>
@@ -109,4 +128,3 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onEr
};
export default ConnectorAuth;

View File

@@ -240,8 +240,6 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
return current;
};
const getMenuRef = (id: string) => {
if (!menuRefs.current[id]) {
menuRefs.current[id] = React.createRef();

View File

@@ -136,8 +136,6 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
}
}, [docId, token]);
const navigateToDirectory = (dirName: string) => {
setCurrentPath((prev) => [...prev, dirName]);
};
@@ -445,18 +443,18 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
const renderPathNavigation = () => {
return (
<div className="mb-0 min-h-[38px] flex flex-col gap-2 text-base sm:flex-row sm:items-center sm:justify-between">
<div className="mb-0 flex min-h-[38px] flex-col gap-2 text-base sm:flex-row sm:items-center sm:justify-between">
{/* Left side with path navigation */}
<div className="flex w-full items-center sm:w-auto">
<button
className="mr-3 flex h-[29px] w-[29px] items-center justify-center rounded-full border p-2 text-sm text-gray-400 dark:border-0 dark:bg-[#28292D] dark:text-gray-500 dark:hover:bg-[#2E2F34] font-medium"
className="mr-3 flex h-[29px] w-[29px] items-center justify-center rounded-full border p-2 text-sm font-medium text-gray-400 dark:border-0 dark:bg-[#28292D] dark:text-gray-500 dark:hover:bg-[#2E2F34]"
onClick={handleBackNavigation}
>
<img src={ArrowLeft} alt="left-arrow" className="h-3 w-3" />
</button>
<div className="flex flex-wrap items-center">
<span className="text-[#7D54D1] font-semibold break-words">
<span className="font-semibold break-words text-[#7D54D1]">
{sourceName}
</span>
{currentPath.length > 0 && (
@@ -487,8 +485,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
</div>
</div>
<div className="flex relative flex-row flex-nowrap items-center gap-2 w-full sm:w-auto justify-end mt-2 sm:mt-0">
<div className="relative mt-2 flex w-full flex-row flex-nowrap items-center justify-end gap-2 sm:mt-0 sm:w-auto">
{processingRef.current && (
<div className="text-sm text-gray-500">
{currentOpRef.current === 'add'
@@ -503,7 +500,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
{!processingRef.current && (
<button
onClick={handleAddFile}
className="bg-purple-30 hover:bg-violets-are-blue flex h-[38px] min-w-[108px] items-center justify-center rounded-full px-4 text-[14px] whitespace-nowrap text-white font-medium"
className="bg-purple-30 hover:bg-violets-are-blue flex h-[38px] min-w-[108px] items-center justify-center rounded-full px-4 text-[14px] font-medium whitespace-nowrap text-white"
title={t('settings.sources.addFile')}
>
{t('settings.sources.addFile')}
@@ -609,7 +606,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
<div ref={menuRef} className="relative">
<button
onClick={(e) => handleMenuClick(e, itemId)}
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E] font-medium"
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md font-medium transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E]"
aria-label={t('settings.sources.menuAlt')}
>
<img
@@ -664,7 +661,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
<div ref={menuRef} className="relative">
<button
onClick={(e) => handleMenuClick(e, itemId)}
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E] font-medium"
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md font-medium transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E]"
aria-label={t('settings.sources.menuAlt')}
>
<img
@@ -756,14 +753,12 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
}
}}
placeholder={t('settings.sources.searchFiles')}
className={`w-full h-[38px] border border-[#D1D9E0] px-4 py-2 dark:border-[#6A6A6A]
${searchQuery ? 'rounded-t-[24px]' : 'rounded-[24px]'}
bg-transparent focus:outline-none dark:text-[#E0E0E0]`}
className={`h-[38px] w-full border border-[#D1D9E0] px-4 py-2 dark:border-[#6A6A6A] ${searchQuery ? 'rounded-t-[24px]' : 'rounded-[24px]'} bg-transparent focus:outline-none dark:text-[#E0E0E0]`}
/>
{searchQuery && (
<div className="absolute top-full left-0 right-0 z-10 max-h-[calc(100vh-200px)] w-full overflow-hidden rounded-b-[12px] border border-t-0 border-[#D1D9E0] bg-white shadow-lg dark:border-[#6A6A6A] dark:bg-[#1F2023] transition-all duration-200">
<div className="max-h-[calc(100vh-200px)] overflow-y-auto overflow-x-hidden overscroll-contain">
<div className="absolute top-full right-0 left-0 z-10 max-h-[calc(100vh-200px)] w-full overflow-hidden rounded-b-[12px] border border-t-0 border-[#D1D9E0] bg-white shadow-lg transition-all duration-200 dark:border-[#6A6A6A] dark:bg-[#1F2023]">
<div className="max-h-[calc(100vh-200px)] overflow-x-hidden overflow-y-auto overscroll-contain">
{searchResults.length === 0 ? (
<div className="py-2 text-center text-sm text-gray-500 dark:text-gray-400">
{t('settings.sources.noResults')}
@@ -774,7 +769,8 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
key={index}
onClick={() => handleSearchSelect(result)}
title={result.path}
className={`flex min-w-0 cursor-pointer items-center px-3 py-2 hover:bg-[#ECEEEF] dark:hover:bg-[#27282D] ${index !== searchResults.length - 1
className={`flex min-w-0 cursor-pointer items-center px-3 py-2 hover:bg-[#ECEEEF] dark:hover:bg-[#27282D] ${
index !== searchResults.length - 1
? 'border-b border-[#D1D9E0] dark:border-[#6A6A6A]'
: ''
}`}
@@ -788,7 +784,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
}
className="mr-2 h-4 w-4 flex-shrink-0"
/>
<span className="text-sm dark:text-[#E0E0E0] truncate flex-1">
<span className="flex-1 truncate text-sm dark:text-[#E0E0E0]">
{result.path.split('/').pop() || result.path}
</span>
</div>
@@ -870,7 +866,9 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
message={
itemToDelete?.isFile
? t('settings.sources.confirmDelete')
: t('settings.sources.deleteDirectoryWarning', { name: itemToDelete?.name })
: t('settings.sources.deleteDirectoryWarning', {
name: itemToDelete?.name,
})
}
modalState={deleteModalState}
setModalState={setDeleteModalState}

View File

@@ -368,8 +368,8 @@ export default function MessageInput({
className="xs:px-3 xs:py-1.5 dark:border-purple-taupe flex max-w-[130px] items-center rounded-[32px] border border-[#AAAAAA] px-2 py-1 transition-colors hover:bg-gray-100 sm:max-w-[150px] dark:hover:bg-[#2C2E3C]"
onClick={() => setIsSourcesPopupOpen(!isSourcesPopupOpen)}
title={
selectedDocs
? selectedDocs.name
selectedDocs && selectedDocs.length > 0
? selectedDocs.map((doc) => doc.name).join(', ')
: t('conversation.sources.title')
}
>
@@ -379,8 +379,10 @@ export default function MessageInput({
className="mr-1 h-3.5 w-3.5 shrink-0 sm:mr-1.5 sm:h-4"
/>
<span className="xs:text-[12px] dark:text-bright-gray truncate overflow-hidden text-[10px] font-medium text-[#5D5D5D] sm:text-[14px]">
{selectedDocs
? selectedDocs.name
{selectedDocs && selectedDocs.length > 0
? selectedDocs.length === 1
? selectedDocs[0].name
: `${selectedDocs.length} sources selected`
: t('conversation.sources.title')}
</span>
{!isTouch && (

View File

@@ -17,7 +17,7 @@ type SourcesPopupProps = {
isOpen: boolean;
onClose: () => void;
anchorRef: React.RefObject<HTMLButtonElement | null>;
handlePostDocumentSelect: (doc: Doc | null) => void;
handlePostDocumentSelect: (doc: Doc[] | null) => void;
setUploadModalState: React.Dispatch<React.SetStateAction<ActiveState>>;
};
@@ -149,9 +149,13 @@ export default function SourcesPopup({
if (option.model === embeddingsName) {
const isSelected =
selectedDocs &&
(option.id
? selectedDocs.id === option.id
: selectedDocs.date === option.date);
Array.isArray(selectedDocs) &&
selectedDocs.length > 0 &&
selectedDocs.some((doc) =>
option.id
? doc.id === option.id
: doc.date === option.date,
);
return (
<div
@@ -159,11 +163,29 @@ export default function SourcesPopup({
className="border-opacity-80 dark:border-dim-gray flex cursor-pointer items-center border-b border-[#D9D9D9] p-3 transition-colors hover:bg-gray-100 dark:text-[14px] dark:hover:bg-[#2C2E3C]"
onClick={() => {
if (isSelected) {
dispatch(setSelectedDocs(null));
handlePostDocumentSelect(null);
const updatedDocs =
selectedDocs && Array.isArray(selectedDocs)
? selectedDocs.filter((doc) =>
option.id
? doc.id !== option.id
: doc.date !== option.date,
)
: [];
dispatch(
setSelectedDocs(
updatedDocs.length > 0 ? updatedDocs : null,
),
);
handlePostDocumentSelect(
updatedDocs.length > 0 ? updatedDocs : null,
);
} else {
dispatch(setSelectedDocs(option));
handlePostDocumentSelect(option);
const updatedDocs =
selectedDocs && Array.isArray(selectedDocs)
? [...selectedDocs, option]
: [option];
dispatch(setSelectedDocs(updatedDocs));
handlePostDocumentSelect(updatedDocs);
}
}}
>

View File

@@ -1,6 +1,6 @@
import 'katex/dist/katex.min.css';
import { forwardRef, Fragment, useRef, useState, useEffect } from 'react';
import { forwardRef, Fragment, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import ReactMarkdown from 'react-markdown';
import { useSelector } from 'react-redux';
@@ -12,12 +12,13 @@ import {
import rehypeKatex from 'rehype-katex';
import remarkGfm from 'remark-gfm';
import remarkMath from 'remark-math';
import DocumentationDark from '../assets/documentation-dark.svg';
import ChevronDown from '../assets/chevron-down.svg';
import Cloud from '../assets/cloud.svg';
import DocsGPT3 from '../assets/cute_docsgpt3.svg';
import Dislike from '../assets/dislike.svg?react';
import Document from '../assets/document.svg';
import DocumentationDark from '../assets/documentation-dark.svg';
import Edit from '../assets/edit.svg';
import Like from '../assets/like.svg?react';
import Link from '../assets/link.svg';
@@ -761,7 +762,11 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
Response
</span>{' '}
<CopyButton
textToCopy={JSON.stringify(toolCall.result, null, 2)}
textToCopy={
toolCall.status === 'error'
? toolCall.error || 'Unknown error'
: JSON.stringify(toolCall.result, null, 2)
}
/>
</p>
{toolCall.status === 'pending' && (
@@ -779,6 +784,16 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
</span>
</p>
)}
{toolCall.status === 'error' && (
<p className="dark:bg-raisin-black rounded-b-2xl p-2 font-mono text-sm break-words">
<span
className="leading-[23px] text-red-500 dark:text-red-400"
style={{ fontFamily: 'IBMPlexMono-Medium' }}
>
{toolCall.error}
</span>
</p>
)}
</div>
</div>
</Accordion>

View File

@@ -7,7 +7,7 @@ export function handleFetchAnswer(
question: string,
signal: AbortSignal,
token: string | null,
selectedDocs: Doc | null,
selectedDocs: Doc[] | null,
conversationId: string | null,
promptId: string | null,
chunks: string,
@@ -52,10 +52,17 @@ export function handleFetchAnswer(
payload.attachments = attachments;
}
if (selectedDocs && 'id' in selectedDocs) {
payload.active_docs = selectedDocs.id as string;
if (selectedDocs && Array.isArray(selectedDocs)) {
if (selectedDocs.length > 1) {
// Handle multiple documents
payload.active_docs = selectedDocs.map((doc) => doc.id!);
payload.retriever = selectedDocs[0]?.retriever as string;
} else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) {
// Handle single document (backward compatibility)
payload.active_docs = selectedDocs[0].id as string;
payload.retriever = selectedDocs[0].retriever as string;
}
}
payload.retriever = selectedDocs?.retriever as string;
return conversationService
.answer(payload, token, signal)
.then((response) => {
@@ -84,7 +91,7 @@ export function handleFetchAnswerSteaming(
question: string,
signal: AbortSignal,
token: string | null,
selectedDocs: Doc | null,
selectedDocs: Doc[] | null,
conversationId: string | null,
promptId: string | null,
chunks: string,
@@ -112,10 +119,17 @@ export function handleFetchAnswerSteaming(
payload.attachments = attachments;
}
if (selectedDocs && 'id' in selectedDocs) {
payload.active_docs = selectedDocs.id as string;
if (selectedDocs && Array.isArray(selectedDocs)) {
if (selectedDocs.length > 1) {
// Handle multiple documents
payload.active_docs = selectedDocs.map((doc) => doc.id!);
payload.retriever = selectedDocs[0]?.retriever as string;
} else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) {
// Handle single document (backward compatibility)
payload.active_docs = selectedDocs[0].id as string;
payload.retriever = selectedDocs[0].retriever as string;
}
}
payload.retriever = selectedDocs?.retriever as string;
return new Promise<Answer>((resolve, reject) => {
conversationService
@@ -171,7 +185,7 @@ export function handleFetchAnswerSteaming(
export function handleSearch(
question: string,
token: string | null,
selectedDocs: Doc | null,
selectedDocs: Doc[] | null,
conversation_id: string | null,
chunks: string,
token_limit: number,
@@ -183,9 +197,17 @@ export function handleSearch(
token_limit: token_limit,
isNoneDoc: selectedDocs === null,
};
if (selectedDocs && 'id' in selectedDocs)
payload.active_docs = selectedDocs.id as string;
payload.retriever = selectedDocs?.retriever as string;
if (selectedDocs && Array.isArray(selectedDocs)) {
if (selectedDocs.length > 1) {
// Handle multiple documents
payload.active_docs = selectedDocs.map((doc) => doc.id!);
payload.retriever = selectedDocs[0]?.retriever as string;
} else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) {
// Handle single document (backward compatibility)
payload.active_docs = selectedDocs[0].id as string;
payload.retriever = selectedDocs[0].retriever as string;
}
}
return conversationService
.search(payload, token)
.then((response) => response.json())

View File

@@ -54,7 +54,7 @@ export interface Query {
export interface RetrievalPayload {
question: string;
active_docs?: string;
active_docs?: string | string[];
retriever?: string;
conversation_id: string | null;
prompt_id?: string | null;

View File

@@ -4,5 +4,6 @@ export type ToolCallsType = {
call_id: string;
arguments: Record<string, any>;
result?: Record<string, any>;
status?: 'pending' | 'completed';
error?: string;
status?: 'pending' | 'completed' | 'error';
};

View File

@@ -18,11 +18,14 @@ export default function useDefaultDocument() {
const fetchDocs = () => {
getDocs(token).then((data) => {
dispatch(setSourceDocs(data));
if (!selectedDoc)
if (
!selectedDoc ||
(Array.isArray(selectedDoc) && selectedDoc.length === 0)
)
Array.isArray(data) &&
data?.forEach((doc: Doc) => {
if (doc.model && doc.name === 'default') {
dispatch(setSelectedDocs(doc));
dispatch(setSelectedDocs([doc]));
}
});
});

View File

@@ -185,7 +185,39 @@
"cancel": "Cancel",
"addNew": "Add New",
"name": "Name",
"type": "Type"
"type": "Type",
"mcp": {
"addServer": "Add MCP Server",
"editServer": "Edit Server",
"serverName": "Server Name",
"serverUrl": "Server URL",
"headerName": "Header Name",
"timeout": "Timeout (seconds)",
"testConnection": "Test Connection",
"testing": "Testing...",
"saving": "Saving...",
"save": "Save",
"cancel": "Cancel",
"noAuth": "No Authentication",
"placeholders": {
"serverUrl": "https://api.example.com",
"apiKey": "Your secret API key",
"bearerToken": "Your secret token",
"username": "Your username",
"password": "Your password"
},
"errors": {
"nameRequired": "Server name is required",
"urlRequired": "Server URL is required",
"invalidUrl": "Please enter a valid URL",
"apiKeyRequired": "API key is required",
"tokenRequired": "Bearer token is required",
"usernameRequired": "Username is required",
"passwordRequired": "Password is required",
"testFailed": "Connection test failed",
"saveFailed": "Failed to save MCP server"
}
}
}
},
"modals": {

View File

@@ -8,6 +8,7 @@ import { useOutsideAlerter } from '../hooks';
import { ActiveState } from '../models/misc';
import { selectToken } from '../preferences/preferenceSlice';
import ConfigToolModal from './ConfigToolModal';
import MCPServerModal from './MCPServerModal';
import { AvailableToolType } from './types';
import WrapperComponent from './WrapperModal';
@@ -34,6 +35,8 @@ export default function AddToolModal({
React.useState<AvailableToolType | null>(null);
const [configModalState, setConfigModalState] =
React.useState<ActiveState>('INACTIVE');
const [mcpModalState, setMcpModalState] =
React.useState<ActiveState>('INACTIVE');
const [loading, setLoading] = React.useState(false);
useOutsideAlerter(modalRef, () => {
@@ -86,6 +89,9 @@ export default function AddToolModal({
.catch((error) => {
console.error('Failed to create tool:', error);
});
} else if (tool.name === 'mcp_tool') {
setModalState('INACTIVE');
setMcpModalState('ACTIVE');
} else {
setModalState('INACTIVE');
setConfigModalState('ACTIVE');
@@ -95,6 +101,12 @@ export default function AddToolModal({
React.useEffect(() => {
if (modalState === 'ACTIVE') getAvailableTools();
}, [modalState]);
const handleMcpServerAdded = () => {
getUserTools();
setMcpModalState('INACTIVE');
};
return (
<>
{modalState === 'ACTIVE' && (
@@ -166,6 +178,11 @@ export default function AddToolModal({
tool={selectedTool}
getUserTools={getUserTools}
/>
<MCPServerModal
modalState={mcpModalState}
setModalState={setMcpModalState}
onServerSaved={handleMcpServerAdded}
/>
</>
);
}

View File

@@ -0,0 +1,482 @@
import { useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import Dropdown from '../components/Dropdown';
import Input from '../components/Input';
import Spinner from '../components/Spinner';
import { useOutsideAlerter } from '../hooks';
import { ActiveState } from '../models/misc';
import { selectToken } from '../preferences/preferenceSlice';
import WrapperComponent from './WrapperModal';
interface MCPServerModalProps {
modalState: ActiveState;
setModalState: (state: ActiveState) => void;
server?: any;
onServerSaved: () => void;
}
const authTypes = [
{ label: 'No Authentication', value: 'none' },
{ label: 'API Key', value: 'api_key' },
{ label: 'Bearer Token', value: 'bearer' },
// { label: 'Basic Authentication', value: 'basic' },
];
export default function MCPServerModal({
modalState,
setModalState,
server,
onServerSaved,
}: MCPServerModalProps) {
const { t } = useTranslation();
const token = useSelector(selectToken);
const modalRef = useRef<HTMLDivElement>(null);
const [formData, setFormData] = useState({
name: server?.displayName || 'My MCP Server',
server_url: server?.server_url || '',
auth_type: server?.auth_type || 'none',
api_key: '',
header_name: 'X-API-Key',
bearer_token: '',
username: '',
password: '',
timeout: server?.timeout || 30,
});
const [loading, setLoading] = useState(false);
const [testing, setTesting] = useState(false);
const [testResult, setTestResult] = useState<{
success: boolean;
message: string;
} | null>(null);
const [errors, setErrors] = useState<{ [key: string]: string }>({});
useOutsideAlerter(modalRef, () => {
if (modalState === 'ACTIVE') {
setModalState('INACTIVE');
resetForm();
}
}, [modalState]);
const resetForm = () => {
setFormData({
name: 'My MCP Server',
server_url: '',
auth_type: 'none',
api_key: '',
header_name: 'X-API-Key',
bearer_token: '',
username: '',
password: '',
timeout: 30,
});
setErrors({});
setTestResult(null);
};
const validateForm = () => {
const requiredFields: { [key: string]: boolean } = {
name: !formData.name.trim(),
server_url: !formData.server_url.trim(),
};
const authFieldChecks: { [key: string]: () => void } = {
api_key: () => {
if (!formData.api_key.trim())
newErrors.api_key = t('settings.tools.mcp.errors.apiKeyRequired');
},
bearer: () => {
if (!formData.bearer_token.trim())
newErrors.bearer_token = t('settings.tools.mcp.errors.tokenRequired');
},
basic: () => {
if (!formData.username.trim())
newErrors.username = t('settings.tools.mcp.errors.usernameRequired');
if (!formData.password.trim())
newErrors.password = t('settings.tools.mcp.errors.passwordRequired');
},
};
const newErrors: { [key: string]: string } = {};
Object.entries(requiredFields).forEach(([field, isEmpty]) => {
if (isEmpty)
newErrors[field] = t(
`settings.tools.mcp.errors.${field === 'name' ? 'nameRequired' : 'urlRequired'}`,
);
});
if (formData.server_url.trim()) {
try {
new URL(formData.server_url);
} catch {
newErrors.server_url = t('settings.tools.mcp.errors.invalidUrl');
}
}
const timeoutValue = formData.timeout === '' ? 30 : formData.timeout;
if (
typeof timeoutValue === 'number' &&
(timeoutValue < 1 || timeoutValue > 300)
)
newErrors.timeout = 'Timeout must be between 1 and 300 seconds';
if (authFieldChecks[formData.auth_type])
authFieldChecks[formData.auth_type]();
setErrors(newErrors);
return Object.keys(newErrors).length === 0;
};
const handleInputChange = (name: string, value: string | number) => {
setFormData((prev) => ({ ...prev, [name]: value }));
if (errors[name]) {
setErrors((prev) => ({ ...prev, [name]: '' }));
}
setTestResult(null);
};
const buildToolConfig = () => {
const config: any = {
server_url: formData.server_url.trim(),
auth_type: formData.auth_type,
timeout: formData.timeout === '' ? 30 : formData.timeout,
};
if (formData.auth_type === 'api_key') {
config.api_key = formData.api_key.trim();
config.api_key_header = formData.header_name.trim() || 'X-API-Key';
} else if (formData.auth_type === 'bearer') {
config.bearer_token = formData.bearer_token.trim();
} else if (formData.auth_type === 'basic') {
config.username = formData.username.trim();
config.password = formData.password.trim();
}
return config;
};
const testConnection = async () => {
if (!validateForm()) return;
setTesting(true);
setTestResult(null);
try {
const config = buildToolConfig();
const response = await userService.testMCPConnection({ config }, token);
const result = await response.json();
setTestResult(result);
} catch (error) {
setTestResult({
success: false,
message: t('settings.tools.mcp.errors.testFailed'),
});
} finally {
setTesting(false);
}
};
const handleSave = async () => {
if (!validateForm()) return;
setLoading(true);
try {
const config = buildToolConfig();
const serverData = {
displayName: formData.name,
config,
status: true,
...(server?.id && { id: server.id }),
};
const response = await userService.saveMCPServer(serverData, token);
const result = await response.json();
if (response.ok && result.success) {
setTestResult({
success: true,
message: result.message,
});
onServerSaved();
setModalState('INACTIVE');
resetForm();
} else {
setErrors({
general: result.error || t('settings.tools.mcp.errors.saveFailed'),
});
}
} catch (error) {
console.error('Error saving MCP server:', error);
setErrors({ general: t('settings.tools.mcp.errors.saveFailed') });
} finally {
setLoading(false);
}
};
const renderAuthFields = () => {
switch (formData.auth_type) {
case 'api_key':
return (
<div className="mb-10">
<div className="mt-6">
<Input
name="api_key"
type="text"
className="rounded-md"
value={formData.api_key}
onChange={(e) => handleInputChange('api_key', e.target.value)}
placeholder={t('settings.tools.mcp.placeholders.apiKey')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.api_key && (
<p className="mt-1 text-sm text-red-600">{errors.api_key}</p>
)}
</div>
<div className="mt-5">
<Input
name="header_name"
type="text"
className="rounded-md"
value={formData.header_name}
onChange={(e) =>
handleInputChange('header_name', e.target.value)
}
placeholder={t('settings.tools.mcp.headerName')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
</div>
</div>
);
case 'bearer':
return (
<div className="mb-10">
<Input
name="bearer_token"
type="text"
className="rounded-md"
value={formData.bearer_token}
onChange={(e) =>
handleInputChange('bearer_token', e.target.value)
}
placeholder={t('settings.tools.mcp.placeholders.bearerToken')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.bearer_token && (
<p className="mt-1 text-sm text-red-600">{errors.bearer_token}</p>
)}
</div>
);
case 'basic':
return (
<div className="mb-10">
<div className="mt-6">
<Input
name="username"
type="text"
className="rounded-md"
value={formData.username}
onChange={(e) => handleInputChange('username', e.target.value)}
placeholder={t('settings.tools.mcp.username')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.username && (
<p className="mt-1 text-sm text-red-600">{errors.username}</p>
)}
</div>
<div className="mt-5">
<Input
name="password"
type="text"
className="rounded-md"
value={formData.password}
onChange={(e) => handleInputChange('password', e.target.value)}
placeholder={t('settings.tools.mcp.password')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.password && (
<p className="mt-1 text-sm text-red-600">{errors.password}</p>
)}
</div>
</div>
);
default:
return null;
}
};
return (
modalState === 'ACTIVE' && (
<WrapperComponent
close={() => {
setModalState('INACTIVE');
resetForm();
}}
className="max-w-[600px] md:w-[80vw] lg:w-[60vw]"
>
<div className="flex h-full flex-col">
<div className="px-6 py-4">
<h2 className="text-jet dark:text-bright-gray text-xl font-semibold">
{server
? t('settings.tools.mcp.editServer')
: t('settings.tools.mcp.addServer')}
</h2>
</div>
<div className="flex-1 px-6">
<div className="space-y-6 py-6">
<div>
<Input
name="name"
type="text"
className="rounded-md"
value={formData.name}
onChange={(e) => handleInputChange('name', e.target.value)}
borderVariant="thin"
placeholder={t('settings.tools.mcp.serverName')}
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.name && (
<p className="mt-1 text-sm text-red-600">{errors.name}</p>
)}
</div>
<div>
<Input
name="server_url"
type="text"
className="rounded-md"
value={formData.server_url}
onChange={(e) =>
handleInputChange('server_url', e.target.value)
}
placeholder={t('settings.tools.mcp.serverUrl')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.server_url && (
<p className="mt-1 text-sm text-red-600">
{errors.server_url}
</p>
)}
</div>
<Dropdown
placeholder={t('settings.tools.mcp.authType')}
selectedValue={
authTypes.find((type) => type.value === formData.auth_type)
?.label || null
}
onSelect={(selection: { label: string; value: string }) => {
handleInputChange('auth_type', selection.value);
}}
options={authTypes}
size="w-full"
rounded="3xl"
border="border"
/>
{renderAuthFields()}
<div>
<Input
name="timeout"
type="number"
className="rounded-md"
value={formData.timeout}
onChange={(e) => {
const value = e.target.value;
if (value === '') {
handleInputChange('timeout', '');
} else {
const numValue = parseInt(value);
if (!isNaN(numValue) && numValue >= 1) {
handleInputChange('timeout', numValue);
}
}
}}
placeholder={t('settings.tools.mcp.timeout')}
borderVariant="thin"
labelBgClassName="bg-white dark:bg-charleston-green-2"
/>
{errors.timeout && (
<p className="mt-2 text-sm text-red-600">{errors.timeout}</p>
)}
</div>
{testResult && (
<div
className={`rounded-md p-5 ${
testResult.success
? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
: 'bg-red-50 text-red-700 dark:bg-red-900 dark:text-red-300'
}`}
>
{testResult.message}
</div>
)}
{errors.general && (
<div className="rounded-2xl bg-red-50 p-5 text-red-700 dark:bg-red-900 dark:text-red-300">
{errors.general}
</div>
)}
</div>
</div>
<div className="px-6 py-2">
<div className="flex flex-col gap-4 sm:flex-row sm:justify-between">
<button
onClick={testConnection}
disabled={testing}
className="border-silver dark:border-dim-gray dark:text-light-gray w-full rounded-3xl border px-6 py-2 text-sm font-medium transition-all hover:bg-gray-100 disabled:opacity-50 sm:w-auto dark:hover:bg-[#767183]/50"
>
{testing ? (
<div className="flex items-center justify-center">
<Spinner size="small" />
<span className="ml-2">
{t('settings.tools.mcp.testing')}
</span>
</div>
) : (
t('settings.tools.mcp.testConnection')
)}
</button>
<div className="flex flex-col-reverse gap-3 sm:flex-row sm:gap-3">
<button
onClick={() => {
setModalState('INACTIVE');
resetForm();
}}
className="dark:text-light-gray w-full cursor-pointer rounded-3xl px-6 py-2 text-sm font-medium hover:bg-gray-100 sm:w-auto dark:bg-transparent dark:hover:bg-[#767183]/50"
>
{t('settings.tools.mcp.cancel')}
</button>
<button
onClick={handleSave}
disabled={loading}
className="bg-purple-30 hover:bg-violets-are-blue w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
>
{loading ? (
<div className="flex items-center justify-center">
<Spinner size="small" />
<span className="ml-2">
{t('settings.tools.mcp.saving')}
</span>
</div>
) : (
t('settings.tools.mcp.save')
)}
</button>
</div>
</div>
</div>
</div>
</WrapperComponent>
)
);
}

View File

@@ -60,7 +60,7 @@ export const ShareConversationModal = ({
const [sourcePath, setSourcePath] = useState<{
label: string;
value: string;
} | null>(preSelectedDoc ? extractDocPaths([preSelectedDoc])[0] : null);
} | null>(preSelectedDoc ? extractDocPaths(preSelectedDoc)[0] : null);
const handleCopyKey = (url: string) => {
navigator.clipboard.writeText(url);
@@ -105,14 +105,14 @@ export const ShareConversationModal = ({
return (
<WrapperModal close={close}>
<div className="flex flex-col gap-2">
<h2 className="text-xl font-medium text-eerie-black dark:text-chinese-white">
<h2 className="text-eerie-black dark:text-chinese-white text-xl font-medium">
{t('modals.shareConv.label')}
</h2>
<p className="text-sm text-eerie-black dark:text-silver/60">
<p className="text-eerie-black dark:text-silver/60 text-sm">
{t('modals.shareConv.note')}
</p>
<div className="flex items-center justify-between">
<span className="text-lg text-eerie-black dark:text-white">
<span className="text-eerie-black text-lg dark:text-white">
{t('modals.shareConv.option')}
</span>
<ToggleSwitch
@@ -136,19 +136,19 @@ export const ShareConversationModal = ({
</div>
)}
<div className="flex items-baseline justify-between gap-2">
<span className="no-scrollbar w-full overflow-x-auto whitespace-nowrap rounded-full border-2 border-silver px-4 py-3 text-eerie-black dark:border-silver/40 dark:text-white">
<span className="no-scrollbar border-silver text-eerie-black dark:border-silver/40 w-full overflow-x-auto rounded-full border-2 px-4 py-3 whitespace-nowrap dark:text-white">
{`${domain}/share/${identifier ?? '....'}`}
</span>
{status === 'fetched' ? (
<button
className="my-1 h-10 w-28 rounded-full bg-purple-30 p-2 text-sm text-white hover:bg-violets-are-blue"
className="bg-purple-30 hover:bg-violets-are-blue my-1 h-10 w-28 rounded-full p-2 text-sm text-white"
onClick={() => handleCopyKey(`${domain}/share/${identifier}`)}
>
{isCopied ? t('modals.saveKey.copied') : t('modals.saveKey.copy')}
</button>
) : (
<button
className="my-1 flex h-10 w-28 items-center justify-evenly rounded-full bg-purple-30 p-2 text-center text-sm font-normal text-white hover:bg-violets-are-blue"
className="bg-purple-30 hover:bg-violets-are-blue my-1 flex h-10 w-28 items-center justify-evenly rounded-full p-2 text-center text-sm font-normal text-white"
onClick={() => {
shareCoversationPublicly(allowPrompt);
}}

View File

@@ -90,9 +90,9 @@ export function getLocalApiKey(): string | null {
return key;
}
export function getLocalRecentDocs(): string | null {
const doc = localStorage.getItem('DocsGPTRecentDocs');
return doc;
export function getLocalRecentDocs(): Doc[] | null {
const docs = localStorage.getItem('DocsGPTRecentDocs');
return docs ? (JSON.parse(docs) as Doc[]) : null;
}
export function getLocalPrompt(): string | null {
@@ -108,19 +108,20 @@ export function setLocalPrompt(prompt: string): void {
localStorage.setItem('DocsGPTPrompt', prompt);
}
export function setLocalRecentDocs(doc: Doc | null): void {
localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(doc));
export function setLocalRecentDocs(docs: Doc[] | null): void {
if (docs && docs.length > 0) {
localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(docs));
docs.forEach((doc) => {
let docPath = 'default';
if (doc?.type === 'local') {
if (doc.type === 'local') {
docPath = 'local' + '/' + doc.name + '/';
}
userService
.checkDocs(
{
docs: docPath,
},
null,
)
.checkDocs({ docs: docPath }, null)
.then((response) => response.json());
});
} else {
localStorage.removeItem('DocsGPTRecentDocs');
}
}

View File

@@ -15,7 +15,7 @@ export interface Preference {
prompt: { name: string; id: string; type: string };
chunks: string;
token_limit: number;
selectedDocs: Doc | null;
selectedDocs: Doc[] | null;
sourceDocs: Doc[] | null;
conversations: {
data: { name: string; id: string }[] | null;
@@ -34,15 +34,16 @@ const initialState: Preference = {
prompt: { name: 'default', id: 'default', type: 'public' },
chunks: '2',
token_limit: 2000,
selectedDocs: {
selectedDocs: [
{
id: 'default',
name: 'default',
type: 'remote',
date: 'default',
docLink: 'default',
model: 'openai_text-embedding-ada-002',
retriever: 'classic',
} as Doc,
},
] as Doc[],
sourceDocs: null,
conversations: {
data: null,

View File

@@ -1,4 +1,3 @@
import React, { useCallback, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector } from 'react-redux';
@@ -319,7 +318,7 @@ export default function Sources({
setSearchTerm(e.target.value);
setCurrentPage(1);
}}
className="w-full h-[32px] rounded-full border border-silver dark:border-silver/40 bg-transparent px-3 text-sm text-jet dark:text-bright-gray placeholder:text-gray-400 dark:placeholder:text-gray-500 outline-none focus:border-silver dark:focus:border-silver/60"
className="border-silver dark:border-silver/40 text-jet dark:text-bright-gray focus:border-silver dark:focus:border-silver/60 h-[32px] w-full rounded-full border bg-transparent px-3 text-sm outline-none placeholder:text-gray-400 dark:placeholder:text-gray-500"
/>
</div>
</div>
@@ -336,7 +335,7 @@ export default function Sources({
</div>
<div className="relative w-full">
{loading ? (
<div className="w-full grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6 px-2 py-4">
<div className="grid w-full grid-cols-1 gap-6 px-2 py-4 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4">
<SkeletonLoader component="sourceCards" count={rowsPerPage} />
</div>
) : !currentDocuments?.length ? (
@@ -351,14 +350,15 @@ export default function Sources({
</p>
</div>
) : (
<div className="w-full grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6 px-2 py-4">
<div className="grid w-full grid-cols-1 gap-6 px-2 py-4 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4">
{currentDocuments.map((document, index) => {
const docId = document.id ? document.id.toString() : '';
return (
<div key={docId} className="relative">
<div
className={`flex h-[130px] w-full flex-col rounded-2xl bg-[#F9F9F9] p-3 transition-all duration-200 dark:bg-[#383838] ${activeMenuId === docId || syncMenuState.docId === docId
className={`flex h-[130px] w-full flex-col rounded-2xl bg-[#F9F9F9] p-3 transition-all duration-200 dark:bg-[#383838] ${
activeMenuId === docId || syncMenuState.docId === docId
? 'scale-[1.05]'
: 'hover:scale-[1.05]'
}`}
@@ -426,7 +426,7 @@ export default function Sources({
<img
src={CalendarIcon}
alt=""
className="w-[14px] h-[14px]"
className="h-[14px] w-[14px]"
/>
<span className="font-inter text-[12px] leading-[18px] font-[500] text-[#848484] dark:text-[#848484]">
{document.date ? formatDate(document.date) : ''}
@@ -436,7 +436,7 @@ export default function Sources({
<img
src={DiscIcon}
alt=""
className="w-[14px] h-[14px]"
className="h-[14px] w-[14px]"
/>
<span className="font-inter text-[12px] leading-[18px] font-[500] text-[#848484] dark:text-[#848484]">
{document.tokens

View File

@@ -30,9 +30,22 @@ export default function ToolConfig({
handleGoBack: () => void;
}) {
const token = useSelector(selectToken);
const [authKey, setAuthKey] = React.useState<string>(
'token' in tool.config ? tool.config.token : '',
);
const [authKey, setAuthKey] = React.useState<string>(() => {
if (tool.name === 'mcp_tool') {
const config = tool.config as any;
if (config.auth_type === 'api_key') {
return config.api_key || '';
} else if (config.auth_type === 'bearer') {
return config.encrypted_token || '';
} else if (config.auth_type === 'basic') {
return config.password || '';
}
return '';
} else if ('token' in tool.config) {
return tool.config.token;
}
return '';
});
const [customName, setCustomName] = React.useState<string>(
tool.customName || '',
);
@@ -97,6 +110,26 @@ export default function ToolConfig({
};
const handleSaveChanges = () => {
let configToSave;
if (tool.name === 'api_tool') {
configToSave = tool.config;
} else if (tool.name === 'mcp_tool') {
configToSave = { ...tool.config } as any;
const mcpConfig = tool.config as any;
if (authKey.trim()) {
if (mcpConfig.auth_type === 'api_key') {
configToSave.api_key = authKey;
} else if (mcpConfig.auth_type === 'bearer') {
configToSave.encrypted_token = authKey;
} else if (mcpConfig.auth_type === 'basic') {
configToSave.password = authKey;
}
}
} else {
configToSave = { token: authKey };
}
userService
.updateTool(
{
@@ -105,7 +138,7 @@ export default function ToolConfig({
displayName: tool.displayName,
customName: customName,
description: tool.description,
config: tool.name === 'api_tool' ? tool.config : { token: authKey },
config: configToSave,
actions: 'actions' in tool ? tool.actions : [],
status: tool.status,
},
@@ -196,7 +229,15 @@ export default function ToolConfig({
<div className="mt-1">
{Object.keys(tool?.config).length !== 0 && tool.name !== 'api_tool' && (
<p className="text-eerie-black dark:text-bright-gray text-sm font-semibold">
{t('settings.tools.authentication')}
{tool.name === 'mcp_tool'
? (tool.config as any)?.auth_type === 'bearer'
? 'Bearer Token'
: (tool.config as any)?.auth_type === 'api_key'
? 'API Key'
: (tool.config as any)?.auth_type === 'basic'
? 'Password'
: t('settings.tools.authentication')
: t('settings.tools.authentication')}
</p>
)}
<div className="mt-4 flex flex-col items-start gap-2 sm:flex-row sm:items-center">
@@ -208,7 +249,17 @@ export default function ToolConfig({
value={authKey}
onChange={(e) => setAuthKey(e.target.value)}
borderVariant="thin"
placeholder={t('modals.configTool.apiKeyPlaceholder')}
placeholder={
tool.name === 'mcp_tool'
? (tool.config as any)?.auth_type === 'bearer'
? 'Bearer Token'
: (tool.config as any)?.auth_type === 'api_key'
? 'API Key'
: (tool.config as any)?.auth_type === 'basic'
? 'Password'
: t('modals.configTool.apiKeyPlaceholder')
: t('modals.configTool.apiKeyPlaceholder')
}
/>
</div>
)}
@@ -450,6 +501,26 @@ export default function ToolConfig({
setModalState={(state) => setShowUnsavedModal(state === 'ACTIVE')}
submitLabel={t('settings.tools.saveAndLeave')}
handleSubmit={() => {
let configToSave;
if (tool.name === 'api_tool') {
configToSave = tool.config;
} else if (tool.name === 'mcp_tool') {
configToSave = { ...tool.config } as any;
const mcpConfig = tool.config as any;
if (authKey.trim()) {
if (mcpConfig.auth_type === 'api_key') {
configToSave.api_key = authKey;
} else if (mcpConfig.auth_type === 'bearer') {
configToSave.encrypted_token = authKey;
} else if (mcpConfig.auth_type === 'basic') {
configToSave.password = authKey;
}
}
} else {
configToSave = { token: authKey };
}
userService
.updateTool(
{
@@ -458,10 +529,7 @@ export default function ToolConfig({
displayName: tool.displayName,
customName: customName,
description: tool.description,
config:
tool.name === 'api_tool'
? tool.config
: { token: authKey },
config: configToSave,
actions: 'actions' in tool ? tool.actions : [],
status: tool.status,
},

View File

@@ -4,8 +4,15 @@ import { useTranslation } from 'react-i18next';
import { useDispatch, useSelector } from 'react-redux';
import userService from '../api/services/userService';
import { getSessionToken } from '../utils/providerUtils';
import {
getSessionToken,
setSessionToken,
removeSessionToken,
} from '../utils/providerUtils';
import { formatDate } from '../utils/dateTimeUtils';
import { formatBytes } from '../utils/stringUtils';
import FileUpload from '../assets/file_upload.svg';
import WebsiteCollect from '../assets/website_collect.svg';
import Dropdown from '../components/Dropdown';
import Input from '../components/Input';
import ToggleSwitch from '../components/ToggleSwitch';
@@ -377,7 +384,8 @@ function Upload({
data?.find(
(d: Doc) => d.type?.toLowerCase() === 'local',
),
));
),
);
});
setProgress(
(progress) =>

View File

@@ -3,7 +3,6 @@
* Follows the convention: {provider}_session_token
*/
export const getSessionToken = (provider: string): string | null => {
return localStorage.getItem(`${provider}_session_token`);
};