Compare commits
1 Commits
pr/1988
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
37c672b891 |
2
.github/workflows/bandit.yaml
vendored
@@ -21,7 +21,7 @@ jobs:
|
|||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.12'
|
python-version: '3.12'
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
|
|||||||
2
.github/workflows/pytest.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
|
|||||||
33
.vscode/launch.json
vendored
@@ -2,11 +2,15 @@
|
|||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
"configurations": [
|
"configurations": [
|
||||||
{
|
{
|
||||||
"name": "Frontend Debug (npm)",
|
"name": "Docker Debug Frontend",
|
||||||
"type": "node-terminal",
|
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"command": "npm run dev",
|
"type": "chrome",
|
||||||
"cwd": "${workspaceFolder}/frontend"
|
"preLaunchTask": "docker-compose: debug:frontend",
|
||||||
|
"url": "http://127.0.0.1:5173",
|
||||||
|
"webRoot": "${workspaceFolder}/frontend",
|
||||||
|
"skipFiles": [
|
||||||
|
"<node_internals>/**"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Flask Debugger",
|
"name": "Flask Debugger",
|
||||||
@@ -45,27 +49,6 @@
|
|||||||
"--pool=solo"
|
"--pool=solo"
|
||||||
],
|
],
|
||||||
"cwd": "${workspaceFolder}"
|
"cwd": "${workspaceFolder}"
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Dev Containers (Mongo + Redis)",
|
|
||||||
"type": "node-terminal",
|
|
||||||
"request": "launch",
|
|
||||||
"command": "docker compose -f deployment/docker-compose-dev.yaml up --build",
|
|
||||||
"cwd": "${workspaceFolder}"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"compounds": [
|
|
||||||
{
|
|
||||||
"name": "DocsGPT: Full Stack",
|
|
||||||
"configurations": [
|
|
||||||
"Frontend Debug (npm)",
|
|
||||||
"Flask Debugger",
|
|
||||||
"Celery Debugger"
|
|
||||||
],
|
|
||||||
"presentation": {
|
|
||||||
"group": "DocsGPT",
|
|
||||||
"order": 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
21
.vscode/tasks.json
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
{
|
||||||
|
"version": "2.0.0",
|
||||||
|
"tasks": [
|
||||||
|
{
|
||||||
|
"type": "docker-compose",
|
||||||
|
"label": "docker-compose: debug:frontend",
|
||||||
|
"dockerCompose": {
|
||||||
|
"up": {
|
||||||
|
"detached": true,
|
||||||
|
"services": [
|
||||||
|
"frontend"
|
||||||
|
],
|
||||||
|
"build": true
|
||||||
|
},
|
||||||
|
"files": [
|
||||||
|
"${workspaceFolder}/docker-compose.yaml"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -55,11 +55,9 @@
|
|||||||
- [x] Agent optimisations (May 2025)
|
- [x] Agent optimisations (May 2025)
|
||||||
- [x] Filesystem sources update (July 2025)
|
- [x] Filesystem sources update (July 2025)
|
||||||
- [x] Json Responses (August 2025)
|
- [x] Json Responses (August 2025)
|
||||||
- [x] MCP support (August 2025)
|
- [ ] Sharepoint integration (August 2025)
|
||||||
- [x] Google Drive integration (September 2025)
|
- [ ] MCP support (August 2025)
|
||||||
- [ ] Add OAuth 2.0 authentication for MCP (September 2025)
|
- [ ] Add OAuth 2.0 authentication for tools and sources (August 2025)
|
||||||
- [ ] Sharepoint integration (October 2025)
|
|
||||||
- [ ] Deep Agents (October 2025)
|
|
||||||
- [ ] Agent scheduling
|
- [ ] Agent scheduling
|
||||||
|
|
||||||
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
||||||
|
|||||||
@@ -140,28 +140,28 @@ class BaseAgent(ABC):
|
|||||||
tool_id, action_name, call_args = parser.parse_args(call)
|
tool_id, action_name, call_args = parser.parse_args(call)
|
||||||
|
|
||||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||||
|
|
||||||
# Check if parsing failed
|
# Check if parsing failed
|
||||||
if tool_id is None or action_name is None:
|
if tool_id is None or action_name is None:
|
||||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
||||||
logger.error(error_message)
|
logger.error(error_message)
|
||||||
|
|
||||||
tool_call_data = {
|
tool_call_data = {
|
||||||
"tool_name": "unknown",
|
"tool_name": "unknown",
|
||||||
"call_id": call_id,
|
"call_id": call_id,
|
||||||
"action_name": getattr(call, "name", "unknown"),
|
"action_name": getattr(call, 'name', 'unknown'),
|
||||||
"arguments": call_args or {},
|
"arguments": call_args or {},
|
||||||
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
||||||
}
|
}
|
||||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||||
self.tool_calls.append(tool_call_data)
|
self.tool_calls.append(tool_call_data)
|
||||||
return "Failed to parse tool call.", call_id
|
return "Failed to parse tool call.", call_id
|
||||||
|
|
||||||
# Check if tool_id exists in available tools
|
# Check if tool_id exists in available tools
|
||||||
if tool_id not in tools_dict:
|
if tool_id not in tools_dict:
|
||||||
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||||
logger.error(error_message)
|
logger.error(error_message)
|
||||||
|
|
||||||
# Return error result
|
# Return error result
|
||||||
tool_call_data = {
|
tool_call_data = {
|
||||||
"tool_name": "unknown",
|
"tool_name": "unknown",
|
||||||
@@ -173,7 +173,7 @@ class BaseAgent(ABC):
|
|||||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||||
self.tool_calls.append(tool_call_data)
|
self.tool_calls.append(tool_call_data)
|
||||||
return f"Tool with ID {tool_id} not found.", call_id
|
return f"Tool with ID {tool_id} not found.", call_id
|
||||||
|
|
||||||
tool_call_data = {
|
tool_call_data = {
|
||||||
"tool_name": tools_dict[tool_id]["name"],
|
"tool_name": tools_dict[tool_id]["name"],
|
||||||
"call_id": call_id,
|
"call_id": call_id,
|
||||||
@@ -225,7 +225,6 @@ class BaseAgent(ABC):
|
|||||||
if tool_data["name"] == "api_tool"
|
if tool_data["name"] == "api_tool"
|
||||||
else tool_data["config"]
|
else tool_data["config"]
|
||||||
),
|
),
|
||||||
user_id=self.user, # Pass user ID for MCP tools credential decryption
|
|
||||||
)
|
)
|
||||||
if tool_data["name"] == "api_tool":
|
if tool_data["name"] == "api_tool":
|
||||||
print(
|
print(
|
||||||
@@ -264,15 +263,7 @@ class BaseAgent(ABC):
|
|||||||
query: str,
|
query: str,
|
||||||
retrieved_data: List[Dict],
|
retrieved_data: List[Dict],
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
docs_with_filenames = []
|
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||||
for doc in retrieved_data:
|
|
||||||
filename = doc.get("filename") or doc.get("title") or doc.get("source")
|
|
||||||
if filename:
|
|
||||||
chunk_header = str(filename)
|
|
||||||
docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
|
|
||||||
else:
|
|
||||||
docs_with_filenames.append(doc["text"])
|
|
||||||
docs_together = "\n\n".join(docs_with_filenames)
|
|
||||||
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
|
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
|
||||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||||
|
|
||||||
|
|||||||
@@ -1,861 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import base64
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from urllib.parse import parse_qs, urlparse
|
|
||||||
|
|
||||||
from application.agents.tools.base import Tool
|
|
||||||
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
|
||||||
from application.cache import get_redis_instance
|
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
|
||||||
|
|
||||||
from application.core.settings import settings
|
|
||||||
|
|
||||||
from application.security.encryption import decrypt_credentials
|
|
||||||
from fastmcp import Client
|
|
||||||
from fastmcp.client.auth import BearerAuth
|
|
||||||
from fastmcp.client.transports import (
|
|
||||||
SSETransport,
|
|
||||||
StdioTransport,
|
|
||||||
StreamableHttpTransport,
|
|
||||||
)
|
|
||||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
|
||||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
|
||||||
|
|
||||||
from pydantic import AnyHttpUrl, ValidationError
|
|
||||||
from redis import Redis
|
|
||||||
|
|
||||||
mongo = MongoDB.get_client()
|
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
|
||||||
|
|
||||||
_mcp_clients_cache = {}
|
|
||||||
|
|
||||||
|
|
||||||
class MCPTool(Tool):
|
|
||||||
"""
|
|
||||||
MCP Tool
|
|
||||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
|
|
||||||
"""
|
|
||||||
Initialize the MCP Tool with configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Dictionary containing MCP server configuration:
|
|
||||||
- server_url: URL of the remote MCP server
|
|
||||||
- transport_type: Transport type (auto, sse, http, stdio)
|
|
||||||
- auth_type: Type of authentication (bearer, oauth, api_key, basic, none)
|
|
||||||
- encrypted_credentials: Encrypted credentials (if available)
|
|
||||||
- timeout: Request timeout in seconds (default: 30)
|
|
||||||
- headers: Custom headers for requests
|
|
||||||
- command: Command for STDIO transport
|
|
||||||
- args: Arguments for STDIO transport
|
|
||||||
- oauth_scopes: OAuth scopes for oauth auth type
|
|
||||||
- oauth_client_name: OAuth client name for oauth auth type
|
|
||||||
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
|
|
||||||
"""
|
|
||||||
self.config = config
|
|
||||||
self.user_id = user_id
|
|
||||||
self.server_url = config.get("server_url", "")
|
|
||||||
self.transport_type = config.get("transport_type", "auto")
|
|
||||||
self.auth_type = config.get("auth_type", "none")
|
|
||||||
self.timeout = config.get("timeout", 30)
|
|
||||||
self.custom_headers = config.get("headers", {})
|
|
||||||
|
|
||||||
self.auth_credentials = {}
|
|
||||||
if config.get("encrypted_credentials") and user_id:
|
|
||||||
self.auth_credentials = decrypt_credentials(
|
|
||||||
config["encrypted_credentials"], user_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.auth_credentials = config.get("auth_credentials", {})
|
|
||||||
self.oauth_scopes = config.get("oauth_scopes", [])
|
|
||||||
self.oauth_task_id = config.get("oauth_task_id", None)
|
|
||||||
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
|
|
||||||
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
|
|
||||||
|
|
||||||
self.available_tools = []
|
|
||||||
self._cache_key = self._generate_cache_key()
|
|
||||||
self._client = None
|
|
||||||
|
|
||||||
# Only validate and setup if server_url is provided and not OAuth
|
|
||||||
|
|
||||||
if self.server_url and self.auth_type != "oauth":
|
|
||||||
self._setup_client()
|
|
||||||
|
|
||||||
def _generate_cache_key(self) -> str:
|
|
||||||
"""Generate a unique cache key for this MCP server configuration."""
|
|
||||||
auth_key = ""
|
|
||||||
if self.auth_type == "oauth":
|
|
||||||
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
|
|
||||||
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
|
|
||||||
elif self.auth_type in ["bearer"]:
|
|
||||||
token = self.auth_credentials.get(
|
|
||||||
"bearer_token", ""
|
|
||||||
) or self.auth_credentials.get("access_token", "")
|
|
||||||
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
|
|
||||||
elif self.auth_type == "api_key":
|
|
||||||
api_key = self.auth_credentials.get("api_key", "")
|
|
||||||
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
|
|
||||||
elif self.auth_type == "basic":
|
|
||||||
username = self.auth_credentials.get("username", "")
|
|
||||||
auth_key = f"basic:{username}"
|
|
||||||
else:
|
|
||||||
auth_key = "none"
|
|
||||||
return f"{self.server_url}#{self.transport_type}#{auth_key}"
|
|
||||||
|
|
||||||
def _setup_client(self):
|
|
||||||
"""Setup FastMCP client with proper transport and authentication."""
|
|
||||||
global _mcp_clients_cache
|
|
||||||
if self._cache_key in _mcp_clients_cache:
|
|
||||||
cached_data = _mcp_clients_cache[self._cache_key]
|
|
||||||
if time.time() - cached_data["created_at"] < 1800:
|
|
||||||
self._client = cached_data["client"]
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
del _mcp_clients_cache[self._cache_key]
|
|
||||||
transport = self._create_transport()
|
|
||||||
auth = None
|
|
||||||
|
|
||||||
if self.auth_type == "oauth":
|
|
||||||
redis_client = get_redis_instance()
|
|
||||||
auth = DocsGPTOAuth(
|
|
||||||
mcp_url=self.server_url,
|
|
||||||
scopes=self.oauth_scopes,
|
|
||||||
redis_client=redis_client,
|
|
||||||
redirect_uri=self.redirect_uri,
|
|
||||||
task_id=self.oauth_task_id,
|
|
||||||
db=db,
|
|
||||||
user_id=self.user_id,
|
|
||||||
)
|
|
||||||
elif self.auth_type == "bearer":
|
|
||||||
token = self.auth_credentials.get(
|
|
||||||
"bearer_token", ""
|
|
||||||
) or self.auth_credentials.get("access_token", "")
|
|
||||||
if token:
|
|
||||||
auth = BearerAuth(token)
|
|
||||||
self._client = Client(transport, auth=auth)
|
|
||||||
_mcp_clients_cache[self._cache_key] = {
|
|
||||||
"client": self._client,
|
|
||||||
"created_at": time.time(),
|
|
||||||
}
|
|
||||||
|
|
||||||
def _create_transport(self):
|
|
||||||
"""Create appropriate transport based on configuration."""
|
|
||||||
headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"}
|
|
||||||
headers.update(self.custom_headers)
|
|
||||||
|
|
||||||
if self.auth_type == "api_key":
|
|
||||||
api_key = self.auth_credentials.get("api_key", "")
|
|
||||||
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
|
|
||||||
if api_key:
|
|
||||||
headers[header_name] = api_key
|
|
||||||
elif self.auth_type == "basic":
|
|
||||||
username = self.auth_credentials.get("username", "")
|
|
||||||
password = self.auth_credentials.get("password", "")
|
|
||||||
if username and password:
|
|
||||||
credentials = base64.b64encode(
|
|
||||||
f"{username}:{password}".encode()
|
|
||||||
).decode()
|
|
||||||
headers["Authorization"] = f"Basic {credentials}"
|
|
||||||
if self.transport_type == "auto":
|
|
||||||
if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"):
|
|
||||||
transport_type = "sse"
|
|
||||||
else:
|
|
||||||
transport_type = "http"
|
|
||||||
else:
|
|
||||||
transport_type = self.transport_type
|
|
||||||
if transport_type == "sse":
|
|
||||||
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
|
|
||||||
return SSETransport(url=self.server_url, headers=headers)
|
|
||||||
elif transport_type == "http":
|
|
||||||
return StreamableHttpTransport(url=self.server_url, headers=headers)
|
|
||||||
elif transport_type == "stdio":
|
|
||||||
command = self.config.get("command", "python")
|
|
||||||
args = self.config.get("args", [])
|
|
||||||
env = self.auth_credentials if self.auth_credentials else None
|
|
||||||
return StdioTransport(command=command, args=args, env=env)
|
|
||||||
else:
|
|
||||||
return StreamableHttpTransport(url=self.server_url, headers=headers)
|
|
||||||
|
|
||||||
def _format_tools(self, tools_response) -> List[Dict]:
|
|
||||||
"""Format tools response to match expected format."""
|
|
||||||
if hasattr(tools_response, "tools"):
|
|
||||||
tools = tools_response.tools
|
|
||||||
elif isinstance(tools_response, list):
|
|
||||||
tools = tools_response
|
|
||||||
else:
|
|
||||||
tools = []
|
|
||||||
tools_dict = []
|
|
||||||
for tool in tools:
|
|
||||||
if hasattr(tool, "name"):
|
|
||||||
tool_dict = {
|
|
||||||
"name": tool.name,
|
|
||||||
"description": tool.description,
|
|
||||||
}
|
|
||||||
if hasattr(tool, "inputSchema"):
|
|
||||||
tool_dict["inputSchema"] = tool.inputSchema
|
|
||||||
tools_dict.append(tool_dict)
|
|
||||||
elif isinstance(tool, dict):
|
|
||||||
tools_dict.append(tool)
|
|
||||||
else:
|
|
||||||
if hasattr(tool, "model_dump"):
|
|
||||||
tools_dict.append(tool.model_dump())
|
|
||||||
else:
|
|
||||||
tools_dict.append({"name": str(tool), "description": ""})
|
|
||||||
return tools_dict
|
|
||||||
|
|
||||||
async def _execute_with_client(self, operation: str, *args, **kwargs):
|
|
||||||
"""Execute operation with FastMCP client."""
|
|
||||||
if not self._client:
|
|
||||||
raise Exception("FastMCP client not initialized")
|
|
||||||
async with self._client:
|
|
||||||
if operation == "ping":
|
|
||||||
return await self._client.ping()
|
|
||||||
elif operation == "list_tools":
|
|
||||||
tools_response = await self._client.list_tools()
|
|
||||||
self.available_tools = self._format_tools(tools_response)
|
|
||||||
return self.available_tools
|
|
||||||
elif operation == "call_tool":
|
|
||||||
tool_name = args[0]
|
|
||||||
tool_args = kwargs
|
|
||||||
return await self._client.call_tool(tool_name, tool_args)
|
|
||||||
elif operation == "list_resources":
|
|
||||||
return await self._client.list_resources()
|
|
||||||
elif operation == "list_prompts":
|
|
||||||
return await self._client.list_prompts()
|
|
||||||
else:
|
|
||||||
raise Exception(f"Unknown operation: {operation}")
|
|
||||||
|
|
||||||
def _run_async_operation(self, operation: str, *args, **kwargs):
|
|
||||||
"""Run async operation in sync context."""
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
import concurrent.futures
|
|
||||||
|
|
||||||
def run_in_thread():
|
|
||||||
new_loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(new_loop)
|
|
||||||
try:
|
|
||||||
return new_loop.run_until_complete(
|
|
||||||
self._execute_with_client(operation, *args, **kwargs)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
new_loop.close()
|
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
||||||
future = executor.submit(run_in_thread)
|
|
||||||
return future.result(timeout=self.timeout)
|
|
||||||
except RuntimeError:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
try:
|
|
||||||
return loop.run_until_complete(
|
|
||||||
self._execute_with_client(operation, *args, **kwargs)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error occurred while running async operation: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def discover_tools(self) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Discover available tools from the MCP server using FastMCP.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of tool definitions from the server
|
|
||||||
"""
|
|
||||||
if not self.server_url:
|
|
||||||
return []
|
|
||||||
if not self._client:
|
|
||||||
self._setup_client()
|
|
||||||
try:
|
|
||||||
tools = self._run_async_operation("list_tools")
|
|
||||||
self.available_tools = tools
|
|
||||||
return self.available_tools
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
|
|
||||||
|
|
||||||
def execute_action(self, action_name: str, **kwargs) -> Any:
|
|
||||||
"""
|
|
||||||
Execute an action on the remote MCP server using FastMCP.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_name: Name of the action to execute
|
|
||||||
**kwargs: Parameters for the action
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Result from the MCP server
|
|
||||||
"""
|
|
||||||
if not self.server_url:
|
|
||||||
raise Exception("No MCP server configured")
|
|
||||||
if not self._client:
|
|
||||||
self._setup_client()
|
|
||||||
cleaned_kwargs = {}
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
if value == "" or value is None:
|
|
||||||
continue
|
|
||||||
cleaned_kwargs[key] = value
|
|
||||||
try:
|
|
||||||
result = self._run_async_operation(
|
|
||||||
"call_tool", action_name, **cleaned_kwargs
|
|
||||||
)
|
|
||||||
return self._format_result(result)
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
|
|
||||||
|
|
||||||
def _format_result(self, result) -> Dict:
|
|
||||||
"""Format FastMCP result to match expected format."""
|
|
||||||
if hasattr(result, "content"):
|
|
||||||
content_list = []
|
|
||||||
for content_item in result.content:
|
|
||||||
if hasattr(content_item, "text"):
|
|
||||||
content_list.append({"type": "text", "text": content_item.text})
|
|
||||||
elif hasattr(content_item, "data"):
|
|
||||||
content_list.append({"type": "data", "data": content_item.data})
|
|
||||||
else:
|
|
||||||
content_list.append(
|
|
||||||
{"type": "unknown", "content": str(content_item)}
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"content": content_list,
|
|
||||||
"isError": getattr(result, "isError", False),
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return result
|
|
||||||
|
|
||||||
def test_connection(self) -> Dict:
|
|
||||||
"""
|
|
||||||
Test the connection to the MCP server and validate functionality.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with connection test results including tool count
|
|
||||||
"""
|
|
||||||
if not self.server_url:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": "No MCP server URL configured",
|
|
||||||
"tools_count": 0,
|
|
||||||
"transport_type": self.transport_type,
|
|
||||||
"auth_type": self.auth_type,
|
|
||||||
"error_type": "ConfigurationError",
|
|
||||||
}
|
|
||||||
if not self._client:
|
|
||||||
self._setup_client()
|
|
||||||
try:
|
|
||||||
if self.auth_type == "oauth":
|
|
||||||
return self._test_oauth_connection()
|
|
||||||
else:
|
|
||||||
return self._test_regular_connection()
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"Connection failed: {str(e)}",
|
|
||||||
"tools_count": 0,
|
|
||||||
"transport_type": self.transport_type,
|
|
||||||
"auth_type": self.auth_type,
|
|
||||||
"error_type": type(e).__name__,
|
|
||||||
}
|
|
||||||
|
|
||||||
def _test_regular_connection(self) -> Dict:
|
|
||||||
"""Test connection for non-OAuth auth types."""
|
|
||||||
try:
|
|
||||||
self._run_async_operation("ping")
|
|
||||||
ping_success = True
|
|
||||||
except Exception:
|
|
||||||
ping_success = False
|
|
||||||
tools = self.discover_tools()
|
|
||||||
|
|
||||||
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
|
|
||||||
if not ping_success:
|
|
||||||
message += " (Ping not supported, but tool discovery worked)"
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"message": message,
|
|
||||||
"tools_count": len(tools),
|
|
||||||
"transport_type": self.transport_type,
|
|
||||||
"auth_type": self.auth_type,
|
|
||||||
"ping_supported": ping_success,
|
|
||||||
"tools": [tool.get("name", "unknown") for tool in tools],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _test_oauth_connection(self) -> Dict:
|
|
||||||
"""Test connection for OAuth auth type with proper async handling."""
|
|
||||||
try:
|
|
||||||
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
|
|
||||||
if not task:
|
|
||||||
raise Exception("Failed to start OAuth authentication")
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"requires_oauth": True,
|
|
||||||
"task_id": task.id,
|
|
||||||
"status": "pending",
|
|
||||||
"message": "OAuth flow started",
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"message": f"OAuth connection failed: {str(e)}",
|
|
||||||
"tools_count": 0,
|
|
||||||
"transport_type": self.transport_type,
|
|
||||||
"auth_type": self.auth_type,
|
|
||||||
"error_type": type(e).__name__,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_actions_metadata(self) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Get metadata for all available actions.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of action metadata dictionaries
|
|
||||||
"""
|
|
||||||
actions = []
|
|
||||||
for tool in self.available_tools:
|
|
||||||
input_schema = (
|
|
||||||
tool.get("inputSchema")
|
|
||||||
or tool.get("input_schema")
|
|
||||||
or tool.get("schema")
|
|
||||||
or tool.get("parameters")
|
|
||||||
)
|
|
||||||
|
|
||||||
parameters_schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {},
|
|
||||||
"required": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
if input_schema:
|
|
||||||
if isinstance(input_schema, dict):
|
|
||||||
if "properties" in input_schema:
|
|
||||||
parameters_schema = {
|
|
||||||
"type": input_schema.get("type", "object"),
|
|
||||||
"properties": input_schema.get("properties", {}),
|
|
||||||
"required": input_schema.get("required", []),
|
|
||||||
}
|
|
||||||
|
|
||||||
for key in ["additionalProperties", "description"]:
|
|
||||||
if key in input_schema:
|
|
||||||
parameters_schema[key] = input_schema[key]
|
|
||||||
else:
|
|
||||||
parameters_schema["properties"] = input_schema
|
|
||||||
action = {
|
|
||||||
"name": tool.get("name", ""),
|
|
||||||
"description": tool.get("description", ""),
|
|
||||||
"parameters": parameters_schema,
|
|
||||||
}
|
|
||||||
actions.append(action)
|
|
||||||
return actions
|
|
||||||
|
|
||||||
def get_config_requirements(self) -> Dict:
|
|
||||||
"""Get configuration requirements for the MCP tool."""
|
|
||||||
return {
|
|
||||||
"server_url": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
|
|
||||||
"required": True,
|
|
||||||
},
|
|
||||||
"transport_type": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Transport type for connection",
|
|
||||||
"enum": ["auto", "sse", "http", "stdio"],
|
|
||||||
"default": "auto",
|
|
||||||
"required": False,
|
|
||||||
"help": {
|
|
||||||
"auto": "Automatically detect best transport",
|
|
||||||
"sse": "Server-Sent Events (for real-time streaming)",
|
|
||||||
"http": "HTTP streaming (recommended for production)",
|
|
||||||
"stdio": "Standard I/O (for local servers)",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"auth_type": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Authentication type",
|
|
||||||
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
|
|
||||||
"default": "none",
|
|
||||||
"required": True,
|
|
||||||
"help": {
|
|
||||||
"none": "No authentication",
|
|
||||||
"bearer": "Bearer token authentication",
|
|
||||||
"oauth": "OAuth 2.1 authentication (with frontend integration)",
|
|
||||||
"api_key": "API key authentication",
|
|
||||||
"basic": "Basic authentication",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"auth_credentials": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "Authentication credentials (varies by auth_type)",
|
|
||||||
"required": False,
|
|
||||||
"properties": {
|
|
||||||
"bearer_token": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Bearer token for bearer auth",
|
|
||||||
},
|
|
||||||
"access_token": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Access token for OAuth (if pre-obtained)",
|
|
||||||
},
|
|
||||||
"api_key": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "API key for api_key auth",
|
|
||||||
},
|
|
||||||
"api_key_header": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Header name for API key (default: X-API-Key)",
|
|
||||||
},
|
|
||||||
"username": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Username for basic auth",
|
|
||||||
},
|
|
||||||
"password": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Password for basic auth",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"oauth_scopes": {
|
|
||||||
"type": "array",
|
|
||||||
"description": "OAuth scopes to request (for oauth auth_type)",
|
|
||||||
"items": {"type": "string"},
|
|
||||||
"required": False,
|
|
||||||
"default": [],
|
|
||||||
},
|
|
||||||
"oauth_client_name": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Client name for OAuth registration (for oauth auth_type)",
|
|
||||||
"default": "DocsGPT-MCP",
|
|
||||||
"required": False,
|
|
||||||
},
|
|
||||||
"headers": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "Custom headers to send with requests",
|
|
||||||
"required": False,
|
|
||||||
},
|
|
||||||
"timeout": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Request timeout in seconds",
|
|
||||||
"default": 30,
|
|
||||||
"minimum": 1,
|
|
||||||
"maximum": 300,
|
|
||||||
"required": False,
|
|
||||||
},
|
|
||||||
"command": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Command to run for STDIO transport (e.g., 'python')",
|
|
||||||
"required": False,
|
|
||||||
},
|
|
||||||
"args": {
|
|
||||||
"type": "array",
|
|
||||||
"description": "Arguments for STDIO command",
|
|
||||||
"items": {"type": "string"},
|
|
||||||
"required": False,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DocsGPTOAuth(OAuthClientProvider):
    """
    Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.

    Instead of opening a local browser, the authorization URL is parked in
    Redis (keyed by the OAuth ``state``) for the frontend to present to the
    user, and the authorization code is later collected from Redis by polling
    (see ``callback_handler``). Token/client-registration persistence is
    delegated to ``DBTokenStorage``.
    """

    def __init__(
        self,
        mcp_url: str,
        redirect_uri: str,
        redis_client: Redis | None = None,
        redis_prefix: str = "mcp_oauth:",
        task_id: str | None = None,
        scopes: str | list[str] | None = None,
        client_name: str = "DocsGPT-MCP",
        user_id: str | None = None,
        db=None,
        additional_client_metadata: dict[str, Any] | None = None,
    ):
        """
        Initialize custom OAuth client provider for DocsGPT.

        Args:
            mcp_url: Full URL to the MCP endpoint
            redirect_uri: Custom redirect URI for DocsGPT frontend
            redis_client: Redis client for storing auth state
            redis_prefix: Prefix for Redis keys
            task_id: Task ID for tracking auth status
            scopes: OAuth scopes to request
            client_name: Name for this client during registration
            user_id: User ID for token storage
            db: Database instance for token storage (MongoDB-like client)
            additional_client_metadata: Extra fields for OAuthClientMetadata
        """

        self.redirect_uri = redirect_uri
        self.redis_client = redis_client
        self.redis_prefix = redis_prefix
        self.task_id = task_id
        self.user_id = user_id
        self.db = db

        # Tokens are stored per server origin, so strip the MCP path/query.
        parsed_url = urlparse(mcp_url)
        self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

        # OAuthClientMetadata expects a single space-separated scope string.
        if isinstance(scopes, list):
            scopes = " ".join(scopes)
        client_metadata = OAuthClientMetadata(
            client_name=client_name,
            redirect_uris=[AnyHttpUrl(redirect_uri)],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            scope=scopes,
            **(additional_client_metadata or {}),
        )

        # Persist tokens/client registration in the DB, keyed by
        # (server base URL, user id).
        storage = DBTokenStorage(
            server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
        )

        # Wire our Redis-backed redirect/callback handlers into the base
        # OAuthClientProvider flow.
        super().__init__(
            server_url=self.server_base_url,
            client_metadata=client_metadata,
            storage=storage,
            redirect_handler=self.redirect_handler,
            callback_handler=self.callback_handler,
        )

        # Populated by redirect_handler once the provider hands us the
        # authorization URL; callback_handler relies on extracted_state.
        self.auth_url = None
        self.extracted_state = None

    def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
        """Process authorization URL to extract state.

        Returns:
            Tuple of (authorization_url, state).

        Raises:
            Exception: if the URL cannot be parsed or carries no ``state``
                query parameter (the state is required to key Redis entries).
        """
        try:
            parsed_url = urlparse(authorization_url)
            query_params = parse_qs(parsed_url.query)

            # parse_qs returns a list per key; take the first state value.
            state_params = query_params.get("state", [])
            if state_params:
                state = state_params[0]
            else:
                raise ValueError("No state in auth URL")
            return authorization_url, state
        except Exception as e:
            raise Exception(f"Failed to process auth URL: {e}")

    async def redirect_handler(self, authorization_url: str) -> None:
        """Store auth URL and state in Redis for frontend to use."""
        auth_url, state = self._process_auth_url(authorization_url)
        logging.info(
            "[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
        )
        self.auth_url = auth_url
        self.extracted_state = state

        if self.redis_client and self.extracted_state:
            # 10-minute TTL: the user must complete the browser flow within
            # that window or the entries expire.
            key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
            self.redis_client.setex(key, 600, auth_url)
            logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)

            # NOTE(review): task status is only published when a Redis client
            # is configured — reconstructed nesting; confirm against upstream.
            if self.task_id:
                status_key = f"mcp_oauth_status:{self.task_id}"
                status_data = {
                    "status": "requires_redirect",
                    "message": "OAuth authorization required",
                    "authorization_url": self.auth_url,
                    "state": self.extracted_state,
                    "requires_oauth": True,
                    "task_id": self.task_id,
                }
                self.redis_client.setex(status_key, 600, json.dumps(status_data))

    async def callback_handler(self) -> tuple[str, str | None]:
        """Wait for auth code from Redis using the state value.

        Polls Redis once per second for up to 5 minutes for either the
        authorization code (written by MCPOAuthManager.handle_oauth_callback)
        or an error entry.

        Returns:
            Tuple of (authorization_code, state).

        Raises:
            Exception: if Redis/state are not configured, the provider
                reported an error, or no code arrives before the timeout.
        """
        if not self.redis_client or not self.extracted_state:
            raise Exception("Redis client or state not configured for OAuth")
        poll_interval = 1
        max_wait_time = 300
        code_key = f"{self.redis_prefix}code:{self.extracted_state}"

        if self.task_id:
            status_key = f"mcp_oauth_status:{self.task_id}"
            status_data = {
                "status": "awaiting_callback",
                "message": "Waiting for OAuth callback...",
                "authorization_url": self.auth_url,
                "state": self.extracted_state,
                "requires_oauth": True,
                "task_id": self.task_id,
            }
            self.redis_client.setex(status_key, 600, json.dumps(status_data))
        start_time = time.time()
        while time.time() - start_time < max_wait_time:
            code_data = self.redis_client.get(code_key)
            if code_data:
                code = code_data.decode()
                returned_state = self.extracted_state

                # Consume the one-shot entries so the code cannot be replayed.
                self.redis_client.delete(code_key)
                self.redis_client.delete(
                    f"{self.redis_prefix}auth_url:{self.extracted_state}"
                )
                self.redis_client.delete(
                    f"{self.redis_prefix}state:{self.extracted_state}"
                )

                if self.task_id:
                    status_data = {
                        "status": "callback_received",
                        "message": "OAuth callback received, completing authentication...",
                        "task_id": self.task_id,
                    }
                    self.redis_client.setex(status_key, 600, json.dumps(status_data))
                return code, returned_state
            # No code yet — check whether the provider reported an error.
            error_key = f"{self.redis_prefix}error:{self.extracted_state}"
            error_data = self.redis_client.get(error_key)
            if error_data:
                error_msg = error_data.decode()
                self.redis_client.delete(error_key)
                self.redis_client.delete(
                    f"{self.redis_prefix}auth_url:{self.extracted_state}"
                )
                self.redis_client.delete(
                    f"{self.redis_prefix}state:{self.extracted_state}"
                )
                raise Exception(f"OAuth error: {error_msg}")
            await asyncio.sleep(poll_interval)
        # Timed out: clean up so stale entries don't confuse a retry.
        self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
        self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
        raise Exception("OAuth callback timeout: no code received within 5 minutes")
|
|
||||||
|
|
||||||
|
|
||||||
class DBTokenStorage(TokenStorage):
    """MongoDB-backed storage for OAuth tokens and client-registration info.

    Documents live in the ``connector_sessions`` collection, keyed by the
    server's base URL plus the owning user id. All blocking pymongo calls
    are pushed off the event loop via ``asyncio.to_thread``.
    """

    def __init__(self, server_url: str, user_id: str, db_client):
        self.server_url = server_url
        self.user_id = user_id
        self.db_client = db_client
        self.collection = db_client["connector_sessions"]

    @staticmethod
    def get_base_url(url: str) -> str:
        """Reduce *url* to its origin (scheme://host), dropping path/query."""
        parts = urlparse(url)
        return f"{parts.scheme}://{parts.netloc}"

    def get_db_key(self) -> dict:
        """Filter document identifying this (server, user) pair."""
        return {
            "server_url": self.get_base_url(self.server_url),
            "user_id": self.user_id,
        }

    async def get_tokens(self) -> OAuthToken | None:
        """Load stored OAuth tokens; None when absent or unparsable."""
        record = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
        if not record or "tokens" not in record:
            return None
        try:
            return OAuthToken.model_validate(record["tokens"])
        except ValidationError as e:
            logging.error(f"Could not load tokens: {e}")
            return None

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Upsert the token payload for this server/user."""
        await asyncio.to_thread(
            self.collection.update_one,
            self.get_db_key(),
            {"$set": {"tokens": tokens.model_dump()}},
            upsert=True,
        )
        logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Load dynamic-registration client info; None when absent/invalid.

        When client info exists but no tokens do, the cached client info is
        wiped so the next flow performs a fresh registration.
        """
        record = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
        if not record or "client_info" not in record:
            return None
        try:
            info = OAuthClientInformationFull.model_validate(record["client_info"])
        except ValidationError as e:
            logging.error(f"Could not load client info: {e}")
            return None
        if await self.get_tokens() is not None:
            return info
        logging.debug(
            "No tokens found, clearing client info to force fresh registration."
        )
        await asyncio.to_thread(
            self.collection.update_one,
            self.get_db_key(),
            {"$unset": {"client_info": ""}},
        )
        return None

    def _serialize_client_info(self, info: dict) -> dict:
        """Coerce redirect URIs (AnyHttpUrl) to plain strings, in place."""
        uris = info.get("redirect_uris")
        if isinstance(uris, list):
            info["redirect_uris"] = [str(item) for item in uris]
        return info

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Upsert the registration info for this server/user."""
        payload = self._serialize_client_info(client_info.model_dump())
        await asyncio.to_thread(
            self.collection.update_one,
            self.get_db_key(),
            {"$set": {"client_info": payload}},
            upsert=True,
        )
        logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")

    async def clear(self) -> None:
        """Drop this server/user's cached tokens and client info."""
        await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
        logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")

    @classmethod
    async def clear_all(cls, db_client) -> None:
        """Drop every cached OAuth document for every user/server."""
        collection = db_client["connector_sessions"]
        await asyncio.to_thread(collection.delete_many, {})
        logging.info("Cleared all OAuth client cache data.")
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthManager:
    """Manager for handling MCP OAuth callbacks.

    Bridges the HTTP OAuth callback endpoint and the worker blocked in
    ``DocsGPTOAuth.callback_handler``: the code (or error) is dropped into
    Redis under the flow's ``state`` for the worker to pick up by polling.
    """

    def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
        self.redis_client = redis_client
        self.redis_prefix = redis_prefix

    def handle_oauth_callback(
        self, state: str, code: str, error: Optional[str] = None
    ) -> bool:
        """
        Handle OAuth callback from provider.

        Args:
            state: The state parameter from OAuth callback
            code: The authorization code from OAuth callback
            error: Error message if OAuth failed

        Returns:
            True if successful, False otherwise
        """
        try:
            if not self.redis_client or not state:
                raise Exception("Redis client or state not provided")
            if error:
                # Surface the provider error to the waiting callback handler
                # before failing this request.
                error_key = f"{self.redis_prefix}error:{state}"
                self.redis_client.setex(error_key, 300, error)
                raise Exception(f"OAuth error received: {error}")
            # Fix: previously a missing code fell through to setex(..., None),
            # which failed with an opaque Redis type error; validate explicitly.
            if not code:
                raise Exception("No authorization code provided in OAuth callback")
            # 5-minute TTL matches the callback handler's polling window.
            code_key = f"{self.redis_prefix}code:{state}"
            self.redis_client.setex(code_key, 300, code)

            state_key = f"{self.redis_prefix}state:{state}"
            self.redis_client.setex(state_key, 300, "completed")

            return True
        except Exception as e:
            # Callers only need a boolean; log the cause and report failure.
            logging.error(f"Error handling OAuth callback: {e}")
            return False

    def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
        """Get current status of OAuth flow using provided task_id."""
        if not task_id:
            return {"status": "not_started", "message": "OAuth flow not started"}
        # Delegates to the task-status helper defined elsewhere in the app.
        return mcp_oauth_status_task(task_id)
|
|
||||||
@@ -23,23 +23,16 @@ class ToolManager:
|
|||||||
tool_config = self.config.get(name, {})
|
tool_config = self.config.get(name, {})
|
||||||
self.tools[name] = obj(tool_config)
|
self.tools[name] = obj(tool_config)
|
||||||
|
|
||||||
def load_tool(self, tool_name, tool_config, user_id=None):
|
def load_tool(self, tool_name, tool_config):
|
||||||
self.config[tool_name] = tool_config
|
self.config[tool_name] = tool_config
|
||||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||||
if issubclass(obj, Tool) and obj is not Tool:
|
if issubclass(obj, Tool) and obj is not Tool:
|
||||||
if tool_name == "mcp_tool" and user_id:
|
return obj(tool_config)
|
||||||
return obj(tool_config, user_id)
|
|
||||||
else:
|
|
||||||
return obj(tool_config)
|
|
||||||
|
|
||||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
def execute_action(self, tool_name, action_name, **kwargs):
|
||||||
if tool_name not in self.tools:
|
if tool_name not in self.tools:
|
||||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||||
if tool_name == "mcp_tool" and user_id:
|
|
||||||
tool_config = self.config.get(tool_name, {})
|
|
||||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
|
||||||
return tool.execute_action(action_name, **kwargs)
|
|
||||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||||
|
|
||||||
def get_all_actions_metadata(self):
|
def get_all_actions_metadata(self):
|
||||||
|
|||||||
@@ -69,8 +69,11 @@ class StreamProcessor:
|
|||||||
self.decoded_token.get("sub") if self.decoded_token is not None else None
|
self.decoded_token.get("sub") if self.decoded_token is not None else None
|
||||||
)
|
)
|
||||||
self.conversation_id = self.data.get("conversation_id")
|
self.conversation_id = self.data.get("conversation_id")
|
||||||
self.source = {}
|
self.source = (
|
||||||
self.all_sources = []
|
{"active_docs": self.data["active_docs"]}
|
||||||
|
if "active_docs" in self.data
|
||||||
|
else {}
|
||||||
|
)
|
||||||
self.attachments = []
|
self.attachments = []
|
||||||
self.history = []
|
self.history = []
|
||||||
self.agent_config = {}
|
self.agent_config = {}
|
||||||
@@ -82,8 +85,6 @@ class StreamProcessor:
|
|||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
"""Initialize all required components for processing"""
|
"""Initialize all required components for processing"""
|
||||||
self._configure_agent()
|
|
||||||
self._configure_source()
|
|
||||||
self._configure_retriever()
|
self._configure_retriever()
|
||||||
self._configure_agent()
|
self._configure_agent()
|
||||||
self._load_conversation_history()
|
self._load_conversation_history()
|
||||||
@@ -170,77 +171,13 @@ class StreamProcessor:
|
|||||||
source = data.get("source")
|
source = data.get("source")
|
||||||
if isinstance(source, DBRef):
|
if isinstance(source, DBRef):
|
||||||
source_doc = self.db.dereference(source)
|
source_doc = self.db.dereference(source)
|
||||||
if source_doc:
|
data["source"] = str(source_doc["_id"])
|
||||||
data["source"] = str(source_doc["_id"])
|
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
|
||||||
else:
|
|
||||||
data["source"] = None
|
|
||||||
elif source == "default":
|
|
||||||
data["source"] = "default"
|
|
||||||
else:
|
else:
|
||||||
data["source"] = None
|
data["source"] = None
|
||||||
# Handle multiple sources
|
|
||||||
|
|
||||||
sources = data.get("sources", [])
|
|
||||||
if sources and isinstance(sources, list):
|
|
||||||
sources_list = []
|
|
||||||
for i, source_ref in enumerate(sources):
|
|
||||||
if source_ref == "default":
|
|
||||||
processed_source = {
|
|
||||||
"id": "default",
|
|
||||||
"retriever": "classic",
|
|
||||||
"chunks": data.get("chunks", "2"),
|
|
||||||
}
|
|
||||||
sources_list.append(processed_source)
|
|
||||||
elif isinstance(source_ref, DBRef):
|
|
||||||
source_doc = self.db.dereference(source_ref)
|
|
||||||
if source_doc:
|
|
||||||
processed_source = {
|
|
||||||
"id": str(source_doc["_id"]),
|
|
||||||
"retriever": source_doc.get("retriever", "classic"),
|
|
||||||
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
|
|
||||||
}
|
|
||||||
sources_list.append(processed_source)
|
|
||||||
data["sources"] = sources_list
|
|
||||||
else:
|
|
||||||
data["sources"] = []
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _configure_source(self):
|
|
||||||
"""Configure the source based on agent data"""
|
|
||||||
api_key = self.data.get("api_key") or self.agent_key
|
|
||||||
|
|
||||||
if api_key:
|
|
||||||
agent_data = self._get_data_from_api_key(api_key)
|
|
||||||
|
|
||||||
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
|
||||||
source_ids = [
|
|
||||||
source["id"] for source in agent_data["sources"] if source.get("id")
|
|
||||||
]
|
|
||||||
if source_ids:
|
|
||||||
self.source = {"active_docs": source_ids}
|
|
||||||
else:
|
|
||||||
self.source = {}
|
|
||||||
self.all_sources = agent_data["sources"]
|
|
||||||
elif agent_data.get("source"):
|
|
||||||
self.source = {"active_docs": agent_data["source"]}
|
|
||||||
self.all_sources = [
|
|
||||||
{
|
|
||||||
"id": agent_data["source"],
|
|
||||||
"retriever": agent_data.get("retriever", "classic"),
|
|
||||||
}
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
self.source = {}
|
|
||||||
self.all_sources = []
|
|
||||||
return
|
|
||||||
if "active_docs" in self.data:
|
|
||||||
self.source = {"active_docs": self.data["active_docs"]}
|
|
||||||
return
|
|
||||||
self.source = {}
|
|
||||||
self.all_sources = []
|
|
||||||
|
|
||||||
def _configure_agent(self):
|
def _configure_agent(self):
|
||||||
"""Configure the agent based on request data"""
|
"""Configure the agent based on request data"""
|
||||||
agent_id = self.data.get("agent_id")
|
agent_id = self.data.get("agent_id")
|
||||||
@@ -266,13 +203,7 @@ class StreamProcessor:
|
|||||||
if data_key.get("retriever"):
|
if data_key.get("retriever"):
|
||||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||||
if data_key.get("chunks") is not None:
|
if data_key.get("chunks") is not None:
|
||||||
try:
|
self.retriever_config["chunks"] = data_key["chunks"]
|
||||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
logger.warning(
|
|
||||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
|
||||||
)
|
|
||||||
self.retriever_config["chunks"] = 2
|
|
||||||
elif self.agent_key:
|
elif self.agent_key:
|
||||||
data_key = self._get_data_from_api_key(self.agent_key)
|
data_key = self._get_data_from_api_key(self.agent_key)
|
||||||
self.agent_config.update(
|
self.agent_config.update(
|
||||||
@@ -293,13 +224,7 @@ class StreamProcessor:
|
|||||||
if data_key.get("retriever"):
|
if data_key.get("retriever"):
|
||||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||||
if data_key.get("chunks") is not None:
|
if data_key.get("chunks") is not None:
|
||||||
try:
|
self.retriever_config["chunks"] = data_key["chunks"]
|
||||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
logger.warning(
|
|
||||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
|
||||||
)
|
|
||||||
self.retriever_config["chunks"] = 2
|
|
||||||
else:
|
else:
|
||||||
self.agent_config.update(
|
self.agent_config.update(
|
||||||
{
|
{
|
||||||
@@ -318,8 +243,7 @@ class StreamProcessor:
|
|||||||
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
|
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
|
||||||
}
|
}
|
||||||
|
|
||||||
api_key = self.data.get("api_key") or self.agent_key
|
if "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
|
||||||
self.retriever_config["chunks"] = 0
|
self.retriever_config["chunks"] = 0
|
||||||
|
|
||||||
def create_agent(self):
|
def create_agent(self):
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import base64
|
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import uuid
|
import logging
|
||||||
|
|
||||||
|
|
||||||
from bson.objectid import ObjectId
|
from bson.objectid import ObjectId
|
||||||
@@ -15,6 +14,8 @@ from flask import (
|
|||||||
from flask_restx import fields, Namespace, Resource
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from application.api.user.tasks import (
|
from application.api.user.tasks import (
|
||||||
ingest_connector_task,
|
ingest_connector_task,
|
||||||
)
|
)
|
||||||
@@ -172,7 +173,7 @@ class ConnectorSources(Resource):
|
|||||||
return make_response(jsonify({"success": False}), 401)
|
return make_response(jsonify({"success": False}), 401)
|
||||||
user = decoded_token.get("sub")
|
user = decoded_token.get("sub")
|
||||||
try:
|
try:
|
||||||
sources = sources_collection.find({"user": user, "type": "connector:file"}).sort("date", -1)
|
sources = sources_collection.find({"user": user, "type": "connector"}).sort("date", -1)
|
||||||
connector_sources = []
|
connector_sources = []
|
||||||
for source in sources:
|
for source in sources:
|
||||||
connector_sources.append({
|
connector_sources.append({
|
||||||
@@ -234,24 +235,8 @@ class ConnectorAuth(Resource):
|
|||||||
if not ConnectorCreator.is_supported(provider):
|
if not ConnectorCreator.is_supported(provider):
|
||||||
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
|
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
|
||||||
|
|
||||||
decoded_token = request.decoded_token
|
import uuid
|
||||||
if not decoded_token:
|
state = str(uuid.uuid4())
|
||||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
|
||||||
user_id = decoded_token.get('sub')
|
|
||||||
|
|
||||||
now = datetime.datetime.now(datetime.timezone.utc)
|
|
||||||
result = sessions_collection.insert_one({
|
|
||||||
"provider": provider,
|
|
||||||
"user": user_id,
|
|
||||||
"status": "pending",
|
|
||||||
"created_at": now
|
|
||||||
})
|
|
||||||
state_dict = {
|
|
||||||
"provider": provider,
|
|
||||||
"object_id": str(result.inserted_id)
|
|
||||||
}
|
|
||||||
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
|
||||||
|
|
||||||
auth = ConnectorCreator.create_auth(provider)
|
auth = ConnectorCreator.create_auth(provider)
|
||||||
authorization_url = auth.get_authorization_url(state=state)
|
authorization_url = auth.get_authorization_url(state=state)
|
||||||
return make_response(jsonify({
|
return make_response(jsonify({
|
||||||
@@ -272,30 +257,25 @@ class ConnectorsCallback(Resource):
|
|||||||
try:
|
try:
|
||||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||||
from flask import request, redirect
|
from flask import request, redirect
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
provider = request.args.get('provider', 'google_drive')
|
||||||
authorization_code = request.args.get('code')
|
authorization_code = request.args.get('code')
|
||||||
state = request.args.get('state')
|
_ = request.args.get('state')
|
||||||
error = request.args.get('error')
|
error = request.args.get('error')
|
||||||
|
|
||||||
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
|
||||||
provider = state_dict["provider"]
|
|
||||||
state_object_id = state_dict["object_id"]
|
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
if error == "access_denied":
|
return redirect(f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider={provider}")
|
||||||
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
|
|
||||||
else:
|
|
||||||
current_app.logger.warning(f"OAuth error in callback: {error}")
|
|
||||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
|
||||||
|
|
||||||
if not authorization_code:
|
if not authorization_code:
|
||||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
return redirect(f"/api/connectors/callback-status?status=error&message=Authorization+code+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider={provider}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth = ConnectorCreator.create_auth(provider)
|
auth = ConnectorCreator.create_auth(provider)
|
||||||
token_info = auth.exchange_code_for_tokens(authorization_code)
|
token_info = auth.exchange_code_for_tokens(authorization_code)
|
||||||
|
|
||||||
session_token = str(uuid.uuid4())
|
session_token = str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
credentials = auth.create_credentials_from_token_info(token_info)
|
credentials = auth.create_credentials_from_token_info(token_info)
|
||||||
@@ -310,31 +290,30 @@ class ConnectorsCallback(Resource):
|
|||||||
"access_token": token_info.get("access_token"),
|
"access_token": token_info.get("access_token"),
|
||||||
"refresh_token": token_info.get("refresh_token"),
|
"refresh_token": token_info.get("refresh_token"),
|
||||||
"token_uri": token_info.get("token_uri"),
|
"token_uri": token_info.get("token_uri"),
|
||||||
"expiry": token_info.get("expiry")
|
"expiry": token_info.get("expiry"),
|
||||||
|
"scopes": token_info.get("scopes")
|
||||||
}
|
}
|
||||||
|
|
||||||
sessions_collection.find_one_and_update(
|
user_id = request.decoded_token.get("sub") if getattr(request, "decoded_token", None) else None
|
||||||
{"_id": ObjectId(state_object_id), "provider": provider},
|
sessions_collection.insert_one({
|
||||||
{
|
"session_token": session_token,
|
||||||
"$set": {
|
"user": user_id,
|
||||||
"session_token": session_token,
|
"token_info": sanitized_token_info,
|
||||||
"token_info": sanitized_token_info,
|
"created_at": datetime.datetime.now(datetime.timezone.utc),
|
||||||
"user_email": user_email,
|
"user_email": user_email,
|
||||||
"status": "authorized"
|
"provider": provider
|
||||||
}
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Redirect to success page with session token and user email
|
# Redirect to success page with session token and user email
|
||||||
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
|
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
||||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+exchange+authorization+code+for+tokens:+{str(e)}&provider={provider}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error handling connector callback: {e}")
|
current_app.logger.error(f"Error handling connector callback: {e}")
|
||||||
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
|
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+complete+connector+authentication:+{str(e)}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.")
|
||||||
|
|
||||||
|
|
||||||
@connectors_ns.route("/api/connectors/refresh")
|
@connectors_ns.route("/api/connectors/refresh")
|
||||||
@@ -360,15 +339,8 @@ class ConnectorRefresh(Resource):
|
|||||||
|
|
||||||
@connectors_ns.route("/api/connectors/files")
|
@connectors_ns.route("/api/connectors/files")
|
||||||
class ConnectorFiles(Resource):
|
class ConnectorFiles(Resource):
|
||||||
@api.expect(api.model("ConnectorFilesModel", {
|
@api.expect(api.model("ConnectorFilesModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True), "folder_id": fields.String(required=False), "limit": fields.Integer(required=False), "page_token": fields.String(required=False)}))
|
||||||
"provider": fields.String(required=True),
|
@api.doc(description="List files from a connector provider (supports pagination)")
|
||||||
"session_token": fields.String(required=True),
|
|
||||||
"folder_id": fields.String(required=False),
|
|
||||||
"limit": fields.Integer(required=False),
|
|
||||||
"page_token": fields.String(required=False),
|
|
||||||
"search_query": fields.String(required=False)
|
|
||||||
}))
|
|
||||||
@api.doc(description="List files from a connector provider (supports pagination and search)")
|
|
||||||
def post(self):
|
def post(self):
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
@@ -377,11 +349,10 @@ class ConnectorFiles(Resource):
|
|||||||
folder_id = data.get('folder_id')
|
folder_id = data.get('folder_id')
|
||||||
limit = data.get('limit', 10)
|
limit = data.get('limit', 10)
|
||||||
page_token = data.get('page_token')
|
page_token = data.get('page_token')
|
||||||
search_query = data.get('search_query')
|
|
||||||
|
|
||||||
if not provider or not session_token:
|
if not provider or not session_token:
|
||||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||||
|
|
||||||
|
|
||||||
decoded_token = request.decoded_token
|
decoded_token = request.decoded_token
|
||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||||
@@ -391,17 +362,13 @@ class ConnectorFiles(Resource):
|
|||||||
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||||
|
|
||||||
loader = ConnectorCreator.create_connector(provider, session_token)
|
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||||
input_config = {
|
documents = loader.load_data({
|
||||||
'limit': limit,
|
'limit': limit,
|
||||||
'list_only': True,
|
'list_only': True,
|
||||||
'session_token': session_token,
|
'session_token': session_token,
|
||||||
'folder_id': folder_id,
|
'folder_id': folder_id,
|
||||||
'page_token': page_token
|
'page_token': page_token
|
||||||
}
|
})
|
||||||
if search_query:
|
|
||||||
input_config['search_query'] = search_query
|
|
||||||
|
|
||||||
documents = loader.load_data(input_config)
|
|
||||||
|
|
||||||
files = []
|
files = []
|
||||||
for doc in documents[:limit]:
|
for doc in documents[:limit]:
|
||||||
@@ -419,20 +386,13 @@ class ConnectorFiles(Resource):
|
|||||||
'name': metadata.get('file_name', 'Unknown File'),
|
'name': metadata.get('file_name', 'Unknown File'),
|
||||||
'type': metadata.get('mime_type', 'unknown'),
|
'type': metadata.get('mime_type', 'unknown'),
|
||||||
'size': metadata.get('size', None),
|
'size': metadata.get('size', None),
|
||||||
'modifiedTime': formatted_time,
|
'modifiedTime': formatted_time
|
||||||
'isFolder': metadata.get('is_folder', False)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
next_token = getattr(loader, 'next_page_token', None)
|
next_token = getattr(loader, 'next_page_token', None)
|
||||||
has_more = bool(next_token)
|
has_more = bool(next_token)
|
||||||
|
|
||||||
return make_response(jsonify({
|
return make_response(jsonify({"success": True, "files": files, "total": len(files), "next_page_token": next_token, "has_more": has_more}), 200)
|
||||||
"success": True,
|
|
||||||
"files": files,
|
|
||||||
"total": len(files),
|
|
||||||
"next_page_token": next_token,
|
|
||||||
"has_more": has_more
|
|
||||||
}), 200)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error loading connector files: {e}")
|
current_app.logger.error(f"Error loading connector files: {e}")
|
||||||
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
|
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
|
||||||
@@ -441,7 +401,7 @@ class ConnectorFiles(Resource):
|
|||||||
@connectors_ns.route("/api/connectors/validate-session")
|
@connectors_ns.route("/api/connectors/validate-session")
|
||||||
class ConnectorValidateSession(Resource):
|
class ConnectorValidateSession(Resource):
|
||||||
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
|
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
|
||||||
@api.doc(description="Validate connector session token and return user info and access token")
|
@api.doc(description="Validate connector session token and return user info")
|
||||||
def post(self):
|
def post(self):
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
@@ -450,6 +410,7 @@ class ConnectorValidateSession(Resource):
|
|||||||
if not provider or not session_token:
|
if not provider or not session_token:
|
||||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||||
|
|
||||||
|
|
||||||
decoded_token = request.decoded_token
|
decoded_token = request.decoded_token
|
||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||||
@@ -463,36 +424,10 @@ class ConnectorValidateSession(Resource):
|
|||||||
auth = ConnectorCreator.create_auth(provider)
|
auth = ConnectorCreator.create_auth(provider)
|
||||||
is_expired = auth.is_token_expired(token_info)
|
is_expired = auth.is_token_expired(token_info)
|
||||||
|
|
||||||
if is_expired and token_info.get('refresh_token'):
|
|
||||||
try:
|
|
||||||
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
|
||||||
sanitized_token_info = {
|
|
||||||
"access_token": refreshed_token_info.get("access_token"),
|
|
||||||
"refresh_token": refreshed_token_info.get("refresh_token"),
|
|
||||||
"token_uri": refreshed_token_info.get("token_uri"),
|
|
||||||
"expiry": refreshed_token_info.get("expiry")
|
|
||||||
}
|
|
||||||
sessions_collection.update_one(
|
|
||||||
{"session_token": session_token},
|
|
||||||
{"$set": {"token_info": sanitized_token_info}}
|
|
||||||
)
|
|
||||||
token_info = sanitized_token_info
|
|
||||||
is_expired = False
|
|
||||||
except Exception as refresh_error:
|
|
||||||
current_app.logger.error(f"Failed to refresh token: {refresh_error}")
|
|
||||||
|
|
||||||
if is_expired:
|
|
||||||
return make_response(jsonify({
|
|
||||||
"success": False,
|
|
||||||
"expired": True,
|
|
||||||
"error": "Session token has expired. Please reconnect."
|
|
||||||
}), 401)
|
|
||||||
|
|
||||||
return make_response(jsonify({
|
return make_response(jsonify({
|
||||||
"success": True,
|
"success": True,
|
||||||
"expired": False,
|
"expired": is_expired,
|
||||||
"user_email": session.get('user_email', 'Connected User'),
|
"user_email": session.get('user_email', 'Connected User')
|
||||||
"access_token": token_info.get('access_token')
|
|
||||||
}), 200)
|
}), 200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error validating connector session: {e}")
|
current_app.logger.error(f"Error validating connector session: {e}")
|
||||||
@@ -652,23 +587,20 @@ class ConnectorCallbackStatus(Resource):
|
|||||||
.container {{ max-width: 600px; margin: 0 auto; }}
|
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||||
.success {{ color: #4CAF50; }}
|
.success {{ color: #4CAF50; }}
|
||||||
.error {{ color: #F44336; }}
|
.error {{ color: #F44336; }}
|
||||||
.cancelled {{ color: #FF9800; }}
|
|
||||||
</style>
|
</style>
|
||||||
<script>
|
<script>
|
||||||
window.onload = function() {{
|
window.onload = function() {{
|
||||||
const status = "{status}";
|
const status = "{status}";
|
||||||
const sessionToken = "{session_token}";
|
const sessionToken = "{session_token}";
|
||||||
const userEmail = "{user_email}";
|
const userEmail = "{user_email}";
|
||||||
|
|
||||||
if (status === "success" && window.opener) {{
|
if (status === "success" && window.opener) {{
|
||||||
window.opener.postMessage({{
|
window.opener.postMessage({{
|
||||||
type: '{provider}_auth_success',
|
type: '{provider}_auth_success',
|
||||||
session_token: sessionToken,
|
session_token: sessionToken,
|
||||||
user_email: userEmail
|
user_email: userEmail
|
||||||
}}, '*');
|
}}, '*');
|
||||||
|
|
||||||
setTimeout(() => window.close(), 3000);
|
|
||||||
}} else if (status === "cancelled" || status === "error") {{
|
|
||||||
setTimeout(() => window.close(), 3000);
|
setTimeout(() => window.close(), 3000);
|
||||||
}}
|
}}
|
||||||
}};
|
}};
|
||||||
@@ -681,7 +613,7 @@ class ConnectorCallbackStatus(Resource):
|
|||||||
<p>{message}</p>
|
<p>{message}</p>
|
||||||
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
||||||
</div>
|
</div>
|
||||||
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else ''}</small></p>
|
||||||
</div>
|
</div>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|||||||
@@ -5,8 +5,6 @@ from application.worker import (
|
|||||||
agent_webhook_worker,
|
agent_webhook_worker,
|
||||||
attachment_worker,
|
attachment_worker,
|
||||||
ingest_worker,
|
ingest_worker,
|
||||||
mcp_oauth,
|
|
||||||
mcp_oauth_status,
|
|
||||||
remote_worker,
|
remote_worker,
|
||||||
sync_worker,
|
sync_worker,
|
||||||
)
|
)
|
||||||
@@ -27,7 +25,6 @@ def ingest_remote(self, source_data, job_name, user, loader):
|
|||||||
@celery.task(bind=True)
|
@celery.task(bind=True)
|
||||||
def reingest_source_task(self, source_id, user):
|
def reingest_source_task(self, source_id, user):
|
||||||
from application.worker import reingest_source_worker
|
from application.worker import reingest_source_worker
|
||||||
|
|
||||||
resp = reingest_source_worker(self, source_id, user)
|
resp = reingest_source_worker(self, source_id, user)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@@ -63,10 +60,9 @@ def ingest_connector_task(
|
|||||||
retriever="classic",
|
retriever="classic",
|
||||||
operation_mode="upload",
|
operation_mode="upload",
|
||||||
doc_id=None,
|
doc_id=None,
|
||||||
sync_frequency="never",
|
sync_frequency="never"
|
||||||
):
|
):
|
||||||
from application.worker import ingest_connector
|
from application.worker import ingest_connector
|
||||||
|
|
||||||
resp = ingest_connector(
|
resp = ingest_connector(
|
||||||
self,
|
self,
|
||||||
job_name,
|
job_name,
|
||||||
@@ -79,7 +75,7 @@ def ingest_connector_task(
|
|||||||
retriever=retriever,
|
retriever=retriever,
|
||||||
operation_mode=operation_mode,
|
operation_mode=operation_mode,
|
||||||
doc_id=doc_id,
|
doc_id=doc_id,
|
||||||
sync_frequency=sync_frequency,
|
sync_frequency=sync_frequency
|
||||||
)
|
)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@@ -98,15 +94,3 @@ def setup_periodic_tasks(sender, **kwargs):
|
|||||||
timedelta(days=30),
|
timedelta(days=30),
|
||||||
schedule_syncs.s("monthly"),
|
schedule_syncs.s("monthly"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@celery.task(bind=True)
|
|
||||||
def mcp_oauth_task(self, config, user):
|
|
||||||
resp = mcp_oauth(self, config, user)
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
@celery.task(bind=True)
|
|
||||||
def mcp_oauth_status_task(self, task_id):
|
|
||||||
resp = mcp_oauth_status(self, task_id)
|
|
||||||
return resp
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ class Settings(BaseSettings):
|
|||||||
"gpt-4o-mini": 128000,
|
"gpt-4o-mini": 128000,
|
||||||
"gpt-3.5-turbo": 4096,
|
"gpt-3.5-turbo": 4096,
|
||||||
"claude-2": 1e5,
|
"claude-2": 1e5,
|
||||||
"gemini-2.5-flash": 1e6,
|
"gemini-2.0-flash-exp": 1e6,
|
||||||
}
|
}
|
||||||
UPLOAD_FOLDER: str = "inputs"
|
UPLOAD_FOLDER: str = "inputs"
|
||||||
PARSE_PDF_AS_IMAGE: bool = False
|
PARSE_PDF_AS_IMAGE: bool = False
|
||||||
@@ -43,7 +43,8 @@ class Settings(BaseSettings):
|
|||||||
# Google Drive integration
|
# Google Drive integration
|
||||||
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
|
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
|
||||||
GOOGLE_CLIENT_SECRET: Optional[str] = None# Replace with your actual Google OAuth client secret
|
GOOGLE_CLIENT_SECRET: Optional[str] = None# Replace with your actual Google OAuth client secret
|
||||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback"
|
||||||
|
##append ?provider={provider_name} in your Provider console like http://127.0.0.1:7091/api/connectors/callback?provider=google_drive
|
||||||
|
|
||||||
|
|
||||||
# LLM Cache
|
# LLM Cache
|
||||||
@@ -95,7 +96,7 @@ class Settings(BaseSettings):
|
|||||||
QDRANT_HOST: Optional[str] = None
|
QDRANT_HOST: Optional[str] = None
|
||||||
QDRANT_PATH: Optional[str] = None
|
QDRANT_PATH: Optional[str] = None
|
||||||
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
||||||
|
|
||||||
# PGVector vectorstore config
|
# PGVector vectorstore config
|
||||||
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
||||||
# Milvus vectorstore config
|
# Milvus vectorstore config
|
||||||
@@ -115,9 +116,6 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
JWT_SECRET_KEY: str = ""
|
JWT_SECRET_KEY: str = ""
|
||||||
|
|
||||||
# Encryption settings
|
|
||||||
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
|
|
||||||
|
|
||||||
|
|
||||||
path = Path(__file__).parent.parent.absolute()
|
path = Path(__file__).parent.parent.absolute()
|
||||||
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")
|
||||||
|
|||||||
@@ -143,7 +143,6 @@ class GoogleLLM(BaseLLM):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _clean_messages_google(self, messages):
|
def _clean_messages_google(self, messages):
|
||||||
"""Convert OpenAI format messages to Google AI format."""
|
|
||||||
cleaned_messages = []
|
cleaned_messages = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
role = message.get("role")
|
role = message.get("role")
|
||||||
@@ -151,8 +150,6 @@ class GoogleLLM(BaseLLM):
|
|||||||
|
|
||||||
if role == "assistant":
|
if role == "assistant":
|
||||||
role = "model"
|
role = "model"
|
||||||
elif role == "tool":
|
|
||||||
role = "model"
|
|
||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
if role and content is not None:
|
if role and content is not None:
|
||||||
@@ -191,63 +188,11 @@ class GoogleLLM(BaseLLM):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||||
|
|
||||||
if parts:
|
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
|
||||||
|
|
||||||
return cleaned_messages
|
return cleaned_messages
|
||||||
|
|
||||||
def _clean_schema(self, schema_obj):
|
|
||||||
"""
|
|
||||||
Recursively remove unsupported fields from schema objects
|
|
||||||
and validate required properties.
|
|
||||||
"""
|
|
||||||
if not isinstance(schema_obj, dict):
|
|
||||||
return schema_obj
|
|
||||||
allowed_fields = {
|
|
||||||
"type",
|
|
||||||
"description",
|
|
||||||
"items",
|
|
||||||
"properties",
|
|
||||||
"required",
|
|
||||||
"enum",
|
|
||||||
"pattern",
|
|
||||||
"minimum",
|
|
||||||
"maximum",
|
|
||||||
"nullable",
|
|
||||||
"default",
|
|
||||||
}
|
|
||||||
|
|
||||||
cleaned = {}
|
|
||||||
for key, value in schema_obj.items():
|
|
||||||
if key not in allowed_fields:
|
|
||||||
continue
|
|
||||||
elif key == "type" and isinstance(value, str):
|
|
||||||
cleaned[key] = value.upper()
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
cleaned[key] = self._clean_schema(value)
|
|
||||||
elif isinstance(value, list):
|
|
||||||
cleaned[key] = [self._clean_schema(item) for item in value]
|
|
||||||
else:
|
|
||||||
cleaned[key] = value
|
|
||||||
|
|
||||||
# Validate that required properties actually exist in properties
|
|
||||||
if "required" in cleaned and "properties" in cleaned:
|
|
||||||
valid_required = []
|
|
||||||
properties_keys = set(cleaned["properties"].keys())
|
|
||||||
for required_prop in cleaned["required"]:
|
|
||||||
if required_prop in properties_keys:
|
|
||||||
valid_required.append(required_prop)
|
|
||||||
if valid_required:
|
|
||||||
cleaned["required"] = valid_required
|
|
||||||
else:
|
|
||||||
cleaned.pop("required", None)
|
|
||||||
elif "required" in cleaned and "properties" not in cleaned:
|
|
||||||
cleaned.pop("required", None)
|
|
||||||
|
|
||||||
return cleaned
|
|
||||||
|
|
||||||
def _clean_tools_format(self, tools_list):
|
def _clean_tools_format(self, tools_list):
|
||||||
"""Convert OpenAI format tools to Google AI format."""
|
|
||||||
genai_tools = []
|
genai_tools = []
|
||||||
for tool_data in tools_list:
|
for tool_data in tools_list:
|
||||||
if tool_data["type"] == "function":
|
if tool_data["type"] == "function":
|
||||||
@@ -256,16 +201,18 @@ class GoogleLLM(BaseLLM):
|
|||||||
properties = parameters.get("properties", {})
|
properties = parameters.get("properties", {})
|
||||||
|
|
||||||
if properties:
|
if properties:
|
||||||
cleaned_properties = {}
|
|
||||||
for k, v in properties.items():
|
|
||||||
cleaned_properties[k] = self._clean_schema(v)
|
|
||||||
|
|
||||||
genai_function = dict(
|
genai_function = dict(
|
||||||
name=function["name"],
|
name=function["name"],
|
||||||
description=function["description"],
|
description=function["description"],
|
||||||
parameters={
|
parameters={
|
||||||
"type": "OBJECT",
|
"type": "OBJECT",
|
||||||
"properties": cleaned_properties,
|
"properties": {
|
||||||
|
k: {
|
||||||
|
**v,
|
||||||
|
"type": v["type"].upper() if v["type"] else None,
|
||||||
|
}
|
||||||
|
for k, v in properties.items()
|
||||||
|
},
|
||||||
"required": (
|
"required": (
|
||||||
parameters["required"]
|
parameters["required"]
|
||||||
if "required" in parameters
|
if "required" in parameters
|
||||||
@@ -295,7 +242,6 @@ class GoogleLLM(BaseLLM):
|
|||||||
response_schema=None,
|
response_schema=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Generate content using Google AI API without streaming."""
|
|
||||||
client = genai.Client(api_key=self.api_key)
|
client = genai.Client(api_key=self.api_key)
|
||||||
if formatting == "openai":
|
if formatting == "openai":
|
||||||
messages = self._clean_messages_google(messages)
|
messages = self._clean_messages_google(messages)
|
||||||
@@ -335,7 +281,6 @@ class GoogleLLM(BaseLLM):
|
|||||||
response_schema=None,
|
response_schema=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Generate content using Google AI API with streaming."""
|
|
||||||
client = genai.Client(api_key=self.api_key)
|
client = genai.Client(api_key=self.api_key)
|
||||||
if formatting == "openai":
|
if formatting == "openai":
|
||||||
messages = self._clean_messages_google(messages)
|
messages = self._clean_messages_google(messages)
|
||||||
@@ -386,15 +331,12 @@ class GoogleLLM(BaseLLM):
|
|||||||
yield chunk.text
|
yield chunk.text
|
||||||
|
|
||||||
def _supports_tools(self):
|
def _supports_tools(self):
|
||||||
"""Return whether this LLM supports function calling."""
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _supports_structured_output(self):
|
def _supports_structured_output(self):
|
||||||
"""Return whether this LLM supports structured JSON output."""
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def prepare_structured_output_format(self, json_schema):
|
def prepare_structured_output_format(self, json_schema):
|
||||||
"""Convert JSON schema to Google AI structured output format."""
|
|
||||||
if not json_schema:
|
if not json_schema:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -205,6 +205,7 @@ class LLMHandler(ABC):
|
|||||||
except StopIteration as e:
|
except StopIteration as e:
|
||||||
tool_response, call_id = e.value
|
tool_response, call_id = e.value
|
||||||
break
|
break
|
||||||
|
|
||||||
updated_messages.append(
|
updated_messages.append(
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
@@ -221,36 +222,17 @@ class LLMHandler(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
updated_messages.append(self.create_tool_message(call, tool_response))
|
updated_messages.append(self.create_tool_message(call, tool_response))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
|
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
|
||||||
error_call = ToolCall(
|
updated_messages.append(
|
||||||
id=call.id, name=call.name, arguments=call.arguments
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"content": f"Error executing tool: {str(e)}",
|
||||||
|
"tool_call_id": call.id,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
error_response = f"Error executing tool: {str(e)}"
|
|
||||||
error_message = self.create_tool_message(error_call, error_response)
|
|
||||||
updated_messages.append(error_message)
|
|
||||||
|
|
||||||
call_parts = call.name.split("_")
|
|
||||||
if len(call_parts) >= 2:
|
|
||||||
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
|
|
||||||
action_name = "_".join(call_parts[:-1])
|
|
||||||
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
|
|
||||||
full_action_name = f"{action_name}_{tool_id}"
|
|
||||||
else:
|
|
||||||
tool_name = "unknown_tool"
|
|
||||||
action_name = call.name
|
|
||||||
full_action_name = call.name
|
|
||||||
yield {
|
|
||||||
"type": "tool_call",
|
|
||||||
"data": {
|
|
||||||
"tool_name": tool_name,
|
|
||||||
"call_id": call.id,
|
|
||||||
"action_name": full_action_name,
|
|
||||||
"arguments": call.arguments,
|
|
||||||
"error": error_response,
|
|
||||||
"status": "error",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return updated_messages
|
return updated_messages
|
||||||
|
|
||||||
def handle_non_streaming(
|
def handle_non_streaming(
|
||||||
@@ -281,11 +263,13 @@ class LLMHandler(ABC):
|
|||||||
except StopIteration as e:
|
except StopIteration as e:
|
||||||
messages = e.value
|
messages = e.value
|
||||||
break
|
break
|
||||||
|
|
||||||
response = agent.llm.gen(
|
response = agent.llm.gen(
|
||||||
model=agent.gpt_model, messages=messages, tools=agent.tools
|
model=agent.gpt_model, messages=messages, tools=agent.tools
|
||||||
)
|
)
|
||||||
parsed = self.parse_response(response)
|
parsed = self.parse_response(response)
|
||||||
self.llm_calls.append(build_stack_data(agent.llm))
|
self.llm_calls.append(build_stack_data(agent.llm))
|
||||||
|
|
||||||
return parsed.content
|
return parsed.content
|
||||||
|
|
||||||
def handle_streaming(
|
def handle_streaming(
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class GoogleLLMHandler(LLMHandler):
|
|||||||
finish_reason="stop",
|
finish_reason="stop",
|
||||||
raw_response=response,
|
raw_response=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(response, "candidates"):
|
if hasattr(response, "candidates"):
|
||||||
parts = response.candidates[0].content.parts if response.candidates else []
|
parts = response.candidates[0].content.parts if response.candidates else []
|
||||||
tool_calls = [
|
tool_calls = [
|
||||||
@@ -40,6 +41,7 @@ class GoogleLLMHandler(LLMHandler):
|
|||||||
finish_reason="tool_calls" if tool_calls else "stop",
|
finish_reason="tool_calls" if tool_calls else "stop",
|
||||||
raw_response=response,
|
raw_response=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
if hasattr(response, "function_call"):
|
if hasattr(response, "function_call"):
|
||||||
@@ -59,16 +61,14 @@ class GoogleLLMHandler(LLMHandler):
|
|||||||
|
|
||||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||||
"""Create Google-style tool message."""
|
"""Create Google-style tool message."""
|
||||||
|
from google.genai import types
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"role": "model",
|
"role": "tool",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
types.Part.from_function_response(
|
||||||
"function_response": {
|
name=tool_call.name, response={"result": result}
|
||||||
"name": tool_call.name,
|
).to_json_dict()
|
||||||
"response": {"result": result},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,13 +17,14 @@ class GoogleDriveAuth(BaseConnectorAuth):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
SCOPES = [
|
SCOPES = [
|
||||||
'https://www.googleapis.com/auth/drive.file'
|
'https://www.googleapis.com/auth/drive.readonly',
|
||||||
|
'https://www.googleapis.com/auth/drive.metadata.readonly'
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.client_id = settings.GOOGLE_CLIENT_ID
|
self.client_id = settings.GOOGLE_CLIENT_ID
|
||||||
self.client_secret = settings.GOOGLE_CLIENT_SECRET
|
self.client_secret = settings.GOOGLE_CLIENT_SECRET
|
||||||
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}"
|
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}?provider=google_drive"
|
||||||
|
|
||||||
if not self.client_id or not self.client_secret:
|
if not self.client_id or not self.client_secret:
|
||||||
raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.")
|
raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.")
|
||||||
@@ -49,7 +50,7 @@ class GoogleDriveAuth(BaseConnectorAuth):
|
|||||||
authorization_url, _ = flow.authorization_url(
|
authorization_url, _ = flow.authorization_url(
|
||||||
access_type='offline',
|
access_type='offline',
|
||||||
prompt='consent',
|
prompt='consent',
|
||||||
include_granted_scopes='false',
|
include_granted_scopes='true',
|
||||||
state=state
|
state=state
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -32,10 +32,6 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
|||||||
'text/plain': '.txt',
|
'text/plain': '.txt',
|
||||||
'text/csv': '.csv',
|
'text/csv': '.csv',
|
||||||
'text/html': '.html',
|
'text/html': '.html',
|
||||||
'text/markdown': '.md',
|
|
||||||
'text/x-rst': '.rst',
|
|
||||||
'application/json': '.json',
|
|
||||||
'application/epub+zip': '.epub',
|
|
||||||
'application/rtf': '.rtf',
|
'application/rtf': '.rtf',
|
||||||
'image/jpeg': '.jpg',
|
'image/jpeg': '.jpg',
|
||||||
'image/jpg': '.jpg',
|
'image/jpg': '.jpg',
|
||||||
@@ -124,7 +120,6 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
|||||||
list_only = inputs.get('list_only', False)
|
list_only = inputs.get('list_only', False)
|
||||||
load_content = not list_only
|
load_content = not list_only
|
||||||
page_token = inputs.get('page_token')
|
page_token = inputs.get('page_token')
|
||||||
search_query = inputs.get('search_query')
|
|
||||||
self.next_page_token = None
|
self.next_page_token = None
|
||||||
|
|
||||||
if file_ids:
|
if file_ids:
|
||||||
@@ -133,18 +128,12 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
|||||||
try:
|
try:
|
||||||
doc = self._load_file_by_id(file_id, load_content=load_content)
|
doc = self._load_file_by_id(file_id, load_content=load_content)
|
||||||
if doc:
|
if doc:
|
||||||
if not search_query or (
|
documents.append(doc)
|
||||||
search_query.lower() in doc.extra_info.get('file_name', '').lower()
|
|
||||||
):
|
|
||||||
documents.append(doc)
|
|
||||||
elif hasattr(self, '_credential_refreshed') and self._credential_refreshed:
|
elif hasattr(self, '_credential_refreshed') and self._credential_refreshed:
|
||||||
self._credential_refreshed = False
|
self._credential_refreshed = False
|
||||||
logging.info(f"Retrying load of file {file_id} after credential refresh")
|
logging.info(f"Retrying load of file {file_id} after credential refresh")
|
||||||
doc = self._load_file_by_id(file_id, load_content=load_content)
|
doc = self._load_file_by_id(file_id, load_content=load_content)
|
||||||
if doc and (
|
if doc:
|
||||||
not search_query or
|
|
||||||
search_query.lower() in doc.extra_info.get('file_name', '').lower()
|
|
||||||
):
|
|
||||||
documents.append(doc)
|
documents.append(doc)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error loading file {file_id}: {e}")
|
logging.error(f"Error loading file {file_id}: {e}")
|
||||||
@@ -152,13 +141,7 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
|||||||
else:
|
else:
|
||||||
# Browsing mode: list immediate children of provided folder or root
|
# Browsing mode: list immediate children of provided folder or root
|
||||||
parent_id = folder_id if folder_id else 'root'
|
parent_id = folder_id if folder_id else 'root'
|
||||||
documents = self._list_items_in_parent(
|
documents = self._list_items_in_parent(parent_id, limit=limit, load_content=load_content, page_token=page_token)
|
||||||
parent_id,
|
|
||||||
limit=limit,
|
|
||||||
load_content=load_content,
|
|
||||||
page_token=page_token,
|
|
||||||
search_query=search_query
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(f"Loaded {len(documents)} documents from Google Drive")
|
logging.info(f"Loaded {len(documents)} documents from Google Drive")
|
||||||
return documents
|
return documents
|
||||||
@@ -201,18 +184,13 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
|
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None) -> List[Document]:
|
||||||
self._ensure_service()
|
self._ensure_service()
|
||||||
|
|
||||||
documents: List[Document] = []
|
documents: List[Document] = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = f"'{parent_id}' in parents and trashed=false"
|
query = f"'{parent_id}' in parents and trashed=false"
|
||||||
|
|
||||||
if search_query:
|
|
||||||
safe_search = search_query.replace("'", "\\'")
|
|
||||||
query += f" and name contains '{safe_search}'"
|
|
||||||
|
|
||||||
next_token_out: Optional[str] = None
|
next_token_out: Optional[str] = None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -227,8 +205,7 @@ class GoogleDriveLoader(BaseConnectorLoader):
|
|||||||
q=query,
|
q=query,
|
||||||
fields='nextPageToken,files(id,name,mimeType,size,createdTime,modifiedTime,parents)',
|
fields='nextPageToken,files(id,name,mimeType,size,createdTime,modifiedTime,parents)',
|
||||||
pageToken=page_token,
|
pageToken=page_token,
|
||||||
pageSize=page_size,
|
pageSize=page_size
|
||||||
orderBy='name'
|
|
||||||
).execute()
|
).execute()
|
||||||
|
|
||||||
items = results.get('files', [])
|
items = results.get('files', [])
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ anthropic==0.49.0
|
|||||||
boto3==1.38.18
|
boto3==1.38.18
|
||||||
beautifulsoup4==4.13.4
|
beautifulsoup4==4.13.4
|
||||||
celery==5.4.0
|
celery==5.4.0
|
||||||
cryptography==42.0.8
|
|
||||||
dataclasses-json==0.6.7
|
dataclasses-json==0.6.7
|
||||||
docx2txt==0.8
|
docx2txt==0.8
|
||||||
duckduckgo-search==7.5.2
|
duckduckgo-search==7.5.2
|
||||||
@@ -12,7 +11,6 @@ esprima==4.0.1
|
|||||||
esutils==1.0.1
|
esutils==1.0.1
|
||||||
Flask==3.1.1
|
Flask==3.1.1
|
||||||
faiss-cpu==1.9.0.post1
|
faiss-cpu==1.9.0.post1
|
||||||
fastmcp==2.11.0
|
|
||||||
flask-restx==1.3.0
|
flask-restx==1.3.0
|
||||||
google-genai==1.3.0
|
google-genai==1.3.0
|
||||||
google-api-python-client==2.179.0
|
google-api-python-client==2.179.0
|
||||||
@@ -57,13 +55,13 @@ prompt-toolkit==3.0.51
|
|||||||
protobuf==5.29.3
|
protobuf==5.29.3
|
||||||
psycopg2-binary==2.9.10
|
psycopg2-binary==2.9.10
|
||||||
py==1.11.0
|
py==1.11.0
|
||||||
pydantic
|
pydantic==2.10.6
|
||||||
pydantic-core
|
pydantic-core==2.27.2
|
||||||
pydantic-settings
|
pydantic-settings==2.7.1
|
||||||
pymongo==4.11.3
|
pymongo==4.11.3
|
||||||
pypdf==5.5.0
|
pypdf==5.5.0
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
python-dotenv
|
python-dotenv==1.0.1
|
||||||
python-jose==3.4.0
|
python-jose==3.4.0
|
||||||
python-pptx==1.0.2
|
python-pptx==1.0.2
|
||||||
redis==5.2.1
|
redis==5.2.1
|
||||||
@@ -83,7 +81,7 @@ tzdata==2024.2
|
|||||||
urllib3==2.3.0
|
urllib3==2.3.0
|
||||||
vine==5.1.0
|
vine==5.1.0
|
||||||
wcwidth==0.2.13
|
wcwidth==0.2.13
|
||||||
werkzeug>=3.1.0,<3.1.2
|
werkzeug==3.1.3
|
||||||
yarl==1.20.0
|
yarl==1.20.0
|
||||||
markdownify==1.1.0
|
markdownify==1.1.0
|
||||||
tldextract==5.1.3
|
tldextract==5.1.3
|
||||||
|
|||||||
@@ -5,6 +5,10 @@ class BaseRetriever(ABC):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def gen(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(self, *args, **kwargs):
|
def search(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
from application.retriever.base import BaseRetriever
|
from application.retriever.base import BaseRetriever
|
||||||
@@ -22,20 +20,10 @@ class ClassicRAG(BaseRetriever):
|
|||||||
api_key=settings.API_KEY,
|
api_key=settings.API_KEY,
|
||||||
decoded_token=None,
|
decoded_token=None,
|
||||||
):
|
):
|
||||||
"""Initialize ClassicRAG retriever with vectorstore sources and LLM configuration"""
|
self.original_question = ""
|
||||||
self.original_question = source.get("question", "")
|
|
||||||
self.chat_history = chat_history if chat_history is not None else []
|
self.chat_history = chat_history if chat_history is not None else []
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
if isinstance(chunks, str):
|
self.chunks = chunks
|
||||||
try:
|
|
||||||
self.chunks = int(chunks)
|
|
||||||
except ValueError:
|
|
||||||
logging.warning(
|
|
||||||
f"Invalid chunks value '{chunks}', using default value 2"
|
|
||||||
)
|
|
||||||
self.chunks = 2
|
|
||||||
else:
|
|
||||||
self.chunks = chunks
|
|
||||||
self.gpt_model = gpt_model
|
self.gpt_model = gpt_model
|
||||||
self.token_limit = (
|
self.token_limit = (
|
||||||
token_limit
|
token_limit
|
||||||
@@ -56,52 +44,25 @@ class ClassicRAG(BaseRetriever):
|
|||||||
user_api_key=self.user_api_key,
|
user_api_key=self.user_api_key,
|
||||||
decoded_token=decoded_token,
|
decoded_token=decoded_token,
|
||||||
)
|
)
|
||||||
|
self.vectorstore = source["active_docs"] if "active_docs" in source else None
|
||||||
if "active_docs" in source and source["active_docs"] is not None:
|
|
||||||
if isinstance(source["active_docs"], list):
|
|
||||||
self.vectorstores = source["active_docs"]
|
|
||||||
else:
|
|
||||||
self.vectorstores = [source["active_docs"]]
|
|
||||||
else:
|
|
||||||
self.vectorstores = []
|
|
||||||
self.question = self._rephrase_query()
|
self.question = self._rephrase_query()
|
||||||
self.decoded_token = decoded_token
|
self.decoded_token = decoded_token
|
||||||
self._validate_vectorstore_config()
|
|
||||||
|
|
||||||
def _validate_vectorstore_config(self):
|
|
||||||
"""Validate vectorstore IDs and remove any empty/invalid entries"""
|
|
||||||
if not self.vectorstores:
|
|
||||||
logging.warning("No vectorstores configured for retrieval")
|
|
||||||
return
|
|
||||||
invalid_ids = [
|
|
||||||
vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
|
|
||||||
]
|
|
||||||
if invalid_ids:
|
|
||||||
logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
|
|
||||||
self.vectorstores = [
|
|
||||||
vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
|
|
||||||
]
|
|
||||||
|
|
||||||
def _rephrase_query(self):
|
def _rephrase_query(self):
|
||||||
"""Rephrase user query with chat history context for better retrieval"""
|
|
||||||
if (
|
if (
|
||||||
not self.original_question
|
not self.original_question
|
||||||
or not self.chat_history
|
or not self.chat_history
|
||||||
or self.chat_history == []
|
or self.chat_history == []
|
||||||
or self.chunks == 0
|
or self.chunks == 0
|
||||||
or not self.vectorstores
|
or self.vectorstore is None
|
||||||
):
|
):
|
||||||
return self.original_question
|
return self.original_question
|
||||||
prompt = f"""Given the following conversation history:
|
|
||||||
|
|
||||||
|
prompt = f"""Given the following conversation history:
|
||||||
{self.chat_history}
|
{self.chat_history}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Rephrase the following user question to be a standalone search query
|
Rephrase the following user question to be a standalone search query
|
||||||
|
|
||||||
that captures all relevant context from the conversation:
|
that captures all relevant context from the conversation:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -118,75 +79,44 @@ class ClassicRAG(BaseRetriever):
|
|||||||
return self.original_question
|
return self.original_question
|
||||||
|
|
||||||
def _get_data(self):
|
def _get_data(self):
|
||||||
"""Retrieve relevant documents from configured vectorstores"""
|
if self.chunks == 0 or self.vectorstore is None:
|
||||||
if self.chunks == 0 or not self.vectorstores:
|
docs = []
|
||||||
return []
|
else:
|
||||||
all_docs = []
|
docsearch = VectorCreator.create_vectorstore(
|
||||||
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
|
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
|
||||||
|
)
|
||||||
|
docs_temp = docsearch.search(self.question, k=self.chunks)
|
||||||
|
docs = [
|
||||||
|
{
|
||||||
|
"title": i.metadata.get(
|
||||||
|
"title", i.metadata.get("post_title", i.page_content)
|
||||||
|
).split("/")[-1],
|
||||||
|
"text": i.page_content,
|
||||||
|
"source": (
|
||||||
|
i.metadata.get("source")
|
||||||
|
if i.metadata.get("source")
|
||||||
|
else "local"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for i in docs_temp
|
||||||
|
]
|
||||||
|
|
||||||
for vectorstore_id in self.vectorstores:
|
return docs
|
||||||
if vectorstore_id:
|
|
||||||
try:
|
|
||||||
docsearch = VectorCreator.create_vectorstore(
|
|
||||||
settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
|
|
||||||
)
|
|
||||||
docs_temp = docsearch.search(self.question, k=chunks_per_source)
|
|
||||||
|
|
||||||
for doc in docs_temp:
|
def gen():
|
||||||
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
pass
|
||||||
page_content = doc.page_content
|
|
||||||
metadata = doc.metadata
|
|
||||||
else:
|
|
||||||
page_content = doc.get("text", doc.get("page_content", ""))
|
|
||||||
metadata = doc.get("metadata", {})
|
|
||||||
title = metadata.get(
|
|
||||||
"title", metadata.get("post_title", page_content)
|
|
||||||
)
|
|
||||||
if not isinstance(title, str):
|
|
||||||
title = str(title)
|
|
||||||
title = title.split("/")[-1]
|
|
||||||
|
|
||||||
filename = (
|
|
||||||
metadata.get("filename")
|
|
||||||
or metadata.get("file_name")
|
|
||||||
or metadata.get("source")
|
|
||||||
)
|
|
||||||
if isinstance(filename, str):
|
|
||||||
filename = os.path.basename(filename) or filename
|
|
||||||
else:
|
|
||||||
filename = title
|
|
||||||
if not filename:
|
|
||||||
filename = title
|
|
||||||
source_path = metadata.get("source") or vectorstore_id
|
|
||||||
all_docs.append(
|
|
||||||
{
|
|
||||||
"title": title,
|
|
||||||
"text": page_content,
|
|
||||||
"source": source_path,
|
|
||||||
"filename": filename,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(
|
|
||||||
f"Error searching vectorstore {vectorstore_id}: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
return all_docs
|
|
||||||
|
|
||||||
def search(self, query: str = ""):
|
def search(self, query: str = ""):
|
||||||
"""Search for documents using optional query override"""
|
|
||||||
if query:
|
if query:
|
||||||
self.original_question = query
|
self.original_question = query
|
||||||
self.question = self._rephrase_query()
|
self.question = self._rephrase_query()
|
||||||
return self._get_data()
|
return self._get_data()
|
||||||
|
|
||||||
def get_params(self):
|
def get_params(self):
|
||||||
"""Return current retriever configuration parameters"""
|
|
||||||
return {
|
return {
|
||||||
"question": self.original_question,
|
"question": self.original_question,
|
||||||
"rephrased_question": self.question,
|
"rephrased_question": self.question,
|
||||||
"sources": self.vectorstores,
|
"source": self.vectorstore,
|
||||||
"chunks": self.chunks,
|
"chunks": self.chunks,
|
||||||
"token_limit": self.token_limit,
|
"token_limit": self.token_limit,
|
||||||
"gpt_model": self.gpt_model,
|
"gpt_model": self.gpt_model,
|
||||||
|
|||||||
@@ -1,85 +0,0 @@
|
|||||||
import base64
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.hazmat.primitives import hashes
|
|
||||||
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
|
|
||||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
||||||
|
|
||||||
from application.core.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
def _derive_key(user_id: str, salt: bytes) -> bytes:
|
|
||||||
app_secret = settings.ENCRYPTION_SECRET_KEY
|
|
||||||
|
|
||||||
password = f"{app_secret}#{user_id}".encode()
|
|
||||||
|
|
||||||
kdf = PBKDF2HMAC(
|
|
||||||
algorithm=hashes.SHA256(),
|
|
||||||
length=32,
|
|
||||||
salt=salt,
|
|
||||||
iterations=100000,
|
|
||||||
backend=default_backend(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return kdf.derive(password)
|
|
||||||
|
|
||||||
|
|
||||||
def encrypt_credentials(credentials: dict, user_id: str) -> str:
|
|
||||||
if not credentials:
|
|
||||||
return ""
|
|
||||||
try:
|
|
||||||
salt = os.urandom(16)
|
|
||||||
iv = os.urandom(16)
|
|
||||||
key = _derive_key(user_id, salt)
|
|
||||||
|
|
||||||
json_str = json.dumps(credentials)
|
|
||||||
|
|
||||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
|
||||||
encryptor = cipher.encryptor()
|
|
||||||
|
|
||||||
padded_data = _pad_data(json_str.encode())
|
|
||||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
|
||||||
|
|
||||||
result = salt + iv + encrypted_data
|
|
||||||
return base64.b64encode(result).decode()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Failed to encrypt credentials: {e}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
|
|
||||||
if not encrypted_data:
|
|
||||||
return {}
|
|
||||||
try:
|
|
||||||
data = base64.b64decode(encrypted_data.encode())
|
|
||||||
|
|
||||||
salt = data[:16]
|
|
||||||
iv = data[16:32]
|
|
||||||
encrypted_content = data[32:]
|
|
||||||
|
|
||||||
key = _derive_key(user_id, salt)
|
|
||||||
|
|
||||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
|
||||||
decryptor = cipher.decryptor()
|
|
||||||
|
|
||||||
decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
|
|
||||||
decrypted_data = _unpad_data(decrypted_padded)
|
|
||||||
|
|
||||||
return json.loads(decrypted_data.decode())
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Warning: Failed to decrypt credentials: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def _pad_data(data: bytes) -> bytes:
|
|
||||||
block_size = 16
|
|
||||||
padding_len = block_size - (len(data) % block_size)
|
|
||||||
padding = bytes([padding_len]) * padding_len
|
|
||||||
return data + padding
|
|
||||||
|
|
||||||
|
|
||||||
def _unpad_data(data: bytes) -> bytes:
|
|
||||||
padding_len = data[-1]
|
|
||||||
return data[:-padding_len]
|
|
||||||
@@ -1,28 +1,20 @@
|
|||||||
import os
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
import os
|
||||||
from langchain_openai import OpenAIEmbeddings
|
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from langchain_openai import OpenAIEmbeddings
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsWrapper:
|
class EmbeddingsWrapper:
|
||||||
def __init__(self, model_name, *args, **kwargs):
|
def __init__(self, model_name, *args, **kwargs):
|
||||||
self.model = SentenceTransformer(
|
self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
|
||||||
model_name,
|
|
||||||
config_kwargs={"allow_dangerous_deserialization": True},
|
|
||||||
*args,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||||
|
|
||||||
def embed_query(self, query: str):
|
def embed_query(self, query: str):
|
||||||
return self.model.encode(query).tolist()
|
return self.model.encode(query).tolist()
|
||||||
|
|
||||||
def embed_documents(self, documents: list):
|
def embed_documents(self, documents: list):
|
||||||
return self.model.encode(documents).tolist()
|
return self.model.encode(documents).tolist()
|
||||||
|
|
||||||
def __call__(self, text):
|
def __call__(self, text):
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
return self.embed_query(text)
|
return self.embed_query(text)
|
||||||
@@ -32,14 +24,15 @@ class EmbeddingsWrapper:
|
|||||||
raise ValueError("Input must be a string or a list of strings")
|
raise ValueError("Input must be a string or a list of strings")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsSingleton:
|
class EmbeddingsSingleton:
|
||||||
_instances = {}
|
_instances = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_instance(embeddings_name, *args, **kwargs):
|
def get_instance(embeddings_name, *args, **kwargs):
|
||||||
if embeddings_name not in EmbeddingsSingleton._instances:
|
if embeddings_name not in EmbeddingsSingleton._instances:
|
||||||
EmbeddingsSingleton._instances[embeddings_name] = (
|
EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(
|
||||||
EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
|
embeddings_name, *args, **kwargs
|
||||||
)
|
)
|
||||||
return EmbeddingsSingleton._instances[embeddings_name]
|
return EmbeddingsSingleton._instances[embeddings_name]
|
||||||
|
|
||||||
@@ -47,15 +40,9 @@ class EmbeddingsSingleton:
|
|||||||
def _create_instance(embeddings_name, *args, **kwargs):
|
def _create_instance(embeddings_name, *args, **kwargs):
|
||||||
embeddings_factory = {
|
embeddings_factory = {
|
||||||
"openai_text-embedding-ada-002": OpenAIEmbeddings,
|
"openai_text-embedding-ada-002": OpenAIEmbeddings,
|
||||||
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
|
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
|
||||||
"sentence-transformers/all-mpnet-base-v2"
|
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
|
||||||
),
|
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
|
||||||
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper(
|
|
||||||
"sentence-transformers/all-mpnet-base-v2"
|
|
||||||
),
|
|
||||||
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper(
|
|
||||||
"hkunlp/instructor-large"
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if embeddings_name in embeddings_factory:
|
if embeddings_name in embeddings_factory:
|
||||||
@@ -63,63 +50,34 @@ class EmbeddingsSingleton:
|
|||||||
else:
|
else:
|
||||||
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
|
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class BaseVectorStore(ABC):
|
class BaseVectorStore(ABC):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(self, *args, **kwargs):
|
def search(self, *args, **kwargs):
|
||||||
"""Search for similar documents/chunks in the vectorstore"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_texts(self, texts, metadatas=None, *args, **kwargs):
|
|
||||||
"""Add texts with their embeddings to the vectorstore"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def delete_index(self, *args, **kwargs):
|
|
||||||
"""Delete the entire index/collection"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def save_local(self, *args, **kwargs):
|
|
||||||
"""Save vectorstore to local storage"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_chunks(self, *args, **kwargs):
|
|
||||||
"""Get all chunks from the vectorstore"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def add_chunk(self, text, metadata=None, *args, **kwargs):
|
|
||||||
"""Add a single chunk to the vectorstore"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def delete_chunk(self, chunk_id, *args, **kwargs):
|
|
||||||
"""Delete a specific chunk from the vectorstore"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def is_azure_configured(self):
|
def is_azure_configured(self):
|
||||||
return (
|
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
|
||||||
settings.OPENAI_API_BASE
|
|
||||||
and settings.OPENAI_API_VERSION
|
|
||||||
and settings.AZURE_DEPLOYMENT_NAME
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_embeddings(self, embeddings_name, embeddings_key=None):
|
def _get_embeddings(self, embeddings_name, embeddings_key=None):
|
||||||
if embeddings_name == "openai_text-embedding-ada-002":
|
if embeddings_name == "openai_text-embedding-ada-002":
|
||||||
if self.is_azure_configured():
|
if self.is_azure_configured():
|
||||||
os.environ["OPENAI_API_TYPE"] = "azure"
|
os.environ["OPENAI_API_TYPE"] = "azure"
|
||||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||||
embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
|
embeddings_name,
|
||||||
|
model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||||
embeddings_name, openai_api_key=embeddings_key
|
embeddings_name,
|
||||||
|
openai_api_key=embeddings_key
|
||||||
)
|
)
|
||||||
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
|
||||||
if os.path.exists("./models/all-mpnet-base-v2"):
|
if os.path.exists("./models/all-mpnet-base-v2"):
|
||||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||||
embeddings_name="./models/all-mpnet-base-v2",
|
embeddings_name = "./models/all-mpnet-base-v2",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
embedding_instance = EmbeddingsSingleton.get_instance(
|
embedding_instance = EmbeddingsSingleton.get_instance(
|
||||||
@@ -129,3 +87,4 @@ class BaseVectorStore(ABC):
|
|||||||
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
|
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
|
||||||
|
|
||||||
return embedding_instance
|
return embedding_instance
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from bson.objectid import ObjectId
|
|||||||
from application.agents.agent_creator import AgentCreator
|
from application.agents.agent_creator import AgentCreator
|
||||||
from application.api.answer.services.stream_processor import get_prompt
|
from application.api.answer.services.stream_processor import get_prompt
|
||||||
|
|
||||||
from application.cache import get_redis_instance
|
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.parser.chunking import Chunker
|
from application.parser.chunking import Chunker
|
||||||
@@ -215,7 +214,8 @@ def run_agent_logic(agent_config, input_data):
|
|||||||
|
|
||||||
|
|
||||||
def ingest_worker(
|
def ingest_worker(
|
||||||
self, directory, formats, job_name, file_path, filename, user, retriever="classic"
|
self, directory, formats, job_name, file_path, filename, user,
|
||||||
|
retriever="classic"
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Ingest and process documents.
|
Ingest and process documents.
|
||||||
@@ -240,7 +240,7 @@ def ingest_worker(
|
|||||||
sample = False
|
sample = False
|
||||||
|
|
||||||
storage = StorageCreator.get_storage()
|
storage = StorageCreator.get_storage()
|
||||||
|
|
||||||
logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name})
|
logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name})
|
||||||
|
|
||||||
# Create temporary working directory
|
# Create temporary working directory
|
||||||
@@ -253,32 +253,30 @@ def ingest_worker(
|
|||||||
# Handle directory case
|
# Handle directory case
|
||||||
logging.info(f"Processing directory: {file_path}")
|
logging.info(f"Processing directory: {file_path}")
|
||||||
files_list = storage.list_files(file_path)
|
files_list = storage.list_files(file_path)
|
||||||
|
|
||||||
for storage_file_path in files_list:
|
for storage_file_path in files_list:
|
||||||
if storage.is_directory(storage_file_path):
|
if storage.is_directory(storage_file_path):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Create relative path structure in temp directory
|
# Create relative path structure in temp directory
|
||||||
rel_path = os.path.relpath(storage_file_path, file_path)
|
rel_path = os.path.relpath(storage_file_path, file_path)
|
||||||
local_file_path = os.path.join(temp_dir, rel_path)
|
local_file_path = os.path.join(temp_dir, rel_path)
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
||||||
|
|
||||||
# Download file
|
# Download file
|
||||||
try:
|
try:
|
||||||
file_data = storage.get_file(storage_file_path)
|
file_data = storage.get_file(storage_file_path)
|
||||||
with open(local_file_path, "wb") as f:
|
with open(local_file_path, "wb") as f:
|
||||||
f.write(file_data.read())
|
f.write(file_data.read())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(
|
logging.error(f"Error downloading file {storage_file_path}: {e}")
|
||||||
f"Error downloading file {storage_file_path}: {e}"
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
# Handle single file case
|
# Handle single file case
|
||||||
temp_filename = os.path.basename(file_path)
|
temp_filename = os.path.basename(file_path)
|
||||||
temp_file_path = os.path.join(temp_dir, temp_filename)
|
temp_file_path = os.path.join(temp_dir, temp_filename)
|
||||||
|
|
||||||
file_data = storage.get_file(file_path)
|
file_data = storage.get_file(file_path)
|
||||||
with open(temp_file_path, "wb") as f:
|
with open(temp_file_path, "wb") as f:
|
||||||
f.write(file_data.read())
|
f.write(file_data.read())
|
||||||
@@ -287,10 +285,7 @@ def ingest_worker(
|
|||||||
if temp_filename.endswith(".zip"):
|
if temp_filename.endswith(".zip"):
|
||||||
logging.info(f"Extracting zip file: {temp_filename}")
|
logging.info(f"Extracting zip file: {temp_filename}")
|
||||||
extract_zip_recursive(
|
extract_zip_recursive(
|
||||||
temp_file_path,
|
temp_file_path, temp_dir, current_depth=0, max_depth=RECURSION_DEPTH
|
||||||
temp_dir,
|
|
||||||
current_depth=0,
|
|
||||||
max_depth=RECURSION_DEPTH,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||||
@@ -305,8 +300,8 @@ def ingest_worker(
|
|||||||
file_metadata=metadata_from_filename,
|
file_metadata=metadata_from_filename,
|
||||||
)
|
)
|
||||||
raw_docs = reader.load_data()
|
raw_docs = reader.load_data()
|
||||||
|
|
||||||
directory_structure = getattr(reader, "directory_structure", {})
|
directory_structure = getattr(reader, 'directory_structure', {})
|
||||||
logging.info(f"Directory structure from reader: {directory_structure}")
|
logging.info(f"Directory structure from reader: {directory_structure}")
|
||||||
|
|
||||||
chunker = Chunker(
|
chunker = Chunker(
|
||||||
@@ -376,10 +371,7 @@ def reingest_source_worker(self, source_id, user):
|
|||||||
try:
|
try:
|
||||||
from application.vectorstore.vector_creator import VectorCreator
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
|
|
||||||
self.update_state(
|
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing re-ingestion scan"})
|
||||||
state="PROGRESS",
|
|
||||||
meta={"current": 10, "status": "Initializing re-ingestion scan"},
|
|
||||||
)
|
|
||||||
|
|
||||||
source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
|
source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
|
||||||
if not source:
|
if not source:
|
||||||
@@ -388,9 +380,7 @@ def reingest_source_worker(self, source_id, user):
|
|||||||
storage = StorageCreator.get_storage()
|
storage = StorageCreator.get_storage()
|
||||||
source_file_path = source.get("file_path", "")
|
source_file_path = source.get("file_path", "")
|
||||||
|
|
||||||
self.update_state(
|
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Scanning current files"})
|
||||||
state="PROGRESS", meta={"current": 20, "status": "Scanning current files"}
|
|
||||||
)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
# Download all files from storage to temp directory, preserving directory structure
|
# Download all files from storage to temp directory, preserving directory structure
|
||||||
@@ -401,6 +391,7 @@ def reingest_source_worker(self, source_id, user):
|
|||||||
if storage.is_directory(storage_file_path):
|
if storage.is_directory(storage_file_path):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
rel_path = os.path.relpath(storage_file_path, source_file_path)
|
rel_path = os.path.relpath(storage_file_path, source_file_path)
|
||||||
local_file_path = os.path.join(temp_dir, rel_path)
|
local_file_path = os.path.join(temp_dir, rel_path)
|
||||||
|
|
||||||
@@ -412,39 +403,23 @@ def reingest_source_worker(self, source_id, user):
|
|||||||
with open(local_file_path, "wb") as f:
|
with open(local_file_path, "wb") as f:
|
||||||
f.write(file_data.read())
|
f.write(file_data.read())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(
|
logging.error(f"Error downloading file {storage_file_path}: {e}")
|
||||||
f"Error downloading file {storage_file_path}: {e}"
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
reader = SimpleDirectoryReader(
|
reader = SimpleDirectoryReader(
|
||||||
input_dir=temp_dir,
|
input_dir=temp_dir,
|
||||||
recursive=True,
|
recursive=True,
|
||||||
required_exts=[
|
required_exts=[
|
||||||
".rst",
|
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
|
||||||
".md",
|
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
|
||||||
".pdf",
|
".jpg", ".jpeg",
|
||||||
".txt",
|
|
||||||
".docx",
|
|
||||||
".csv",
|
|
||||||
".epub",
|
|
||||||
".html",
|
|
||||||
".mdx",
|
|
||||||
".json",
|
|
||||||
".xlsx",
|
|
||||||
".pptx",
|
|
||||||
".png",
|
|
||||||
".jpg",
|
|
||||||
".jpeg",
|
|
||||||
],
|
],
|
||||||
exclude_hidden=True,
|
exclude_hidden=True,
|
||||||
file_metadata=metadata_from_filename,
|
file_metadata=metadata_from_filename,
|
||||||
)
|
)
|
||||||
reader.load_data()
|
reader.load_data()
|
||||||
directory_structure = reader.directory_structure
|
directory_structure = reader.directory_structure
|
||||||
logging.info(
|
logging.info(f"Directory structure built with token counts: {directory_structure}")
|
||||||
f"Directory structure built with token counts: {directory_structure}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
old_directory_structure = source.get("directory_structure") or {}
|
old_directory_structure = source.get("directory_structure") or {}
|
||||||
@@ -458,17 +433,11 @@ def reingest_source_worker(self, source_id, user):
|
|||||||
files = set()
|
files = set()
|
||||||
if isinstance(struct, dict):
|
if isinstance(struct, dict):
|
||||||
for name, meta in struct.items():
|
for name, meta in struct.items():
|
||||||
current_path = (
|
current_path = os.path.join(prefix, name) if prefix else name
|
||||||
os.path.join(prefix, name) if prefix else name
|
if isinstance(meta, dict) and ("type" in meta and "size_bytes" in meta):
|
||||||
)
|
|
||||||
if isinstance(meta, dict) and (
|
|
||||||
"type" in meta and "size_bytes" in meta
|
|
||||||
):
|
|
||||||
files.add(current_path)
|
files.add(current_path)
|
||||||
elif isinstance(meta, dict):
|
elif isinstance(meta, dict):
|
||||||
files |= _flatten_directory_structure(
|
files |= _flatten_directory_structure(meta, current_path)
|
||||||
meta, current_path
|
|
||||||
)
|
|
||||||
return files
|
return files
|
||||||
|
|
||||||
old_files = _flatten_directory_structure(old_directory_structure)
|
old_files = _flatten_directory_structure(old_directory_structure)
|
||||||
@@ -488,9 +457,7 @@ def reingest_source_worker(self, source_id, user):
|
|||||||
logging.info("No files removed since last ingest.")
|
logging.info("No files removed since last ingest.")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(
|
logging.error(f"Error comparing directory structures: {e}", exc_info=True)
|
||||||
f"Error comparing directory structures: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
added_files = []
|
added_files = []
|
||||||
removed_files = []
|
removed_files = []
|
||||||
try:
|
try:
|
||||||
@@ -510,21 +477,14 @@ def reingest_source_worker(self, source_id, user):
|
|||||||
settings.EMBEDDINGS_KEY,
|
settings.EMBEDDINGS_KEY,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.update_state(
|
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing file changes"})
|
||||||
state="PROGRESS",
|
|
||||||
meta={"current": 40, "status": "Processing file changes"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1) Delete chunks from removed files
|
# 1) Delete chunks from removed files
|
||||||
deleted = 0
|
deleted = 0
|
||||||
if removed_files:
|
if removed_files:
|
||||||
try:
|
try:
|
||||||
for ch in vector_store.get_chunks() or []:
|
for ch in vector_store.get_chunks() or []:
|
||||||
metadata = (
|
metadata = ch.get("metadata", {}) if isinstance(ch, dict) else getattr(ch, "metadata", {})
|
||||||
ch.get("metadata", {})
|
|
||||||
if isinstance(ch, dict)
|
|
||||||
else getattr(ch, "metadata", {})
|
|
||||||
)
|
|
||||||
raw_source = metadata.get("source")
|
raw_source = metadata.get("source")
|
||||||
|
|
||||||
source_file = str(raw_source) if raw_source else ""
|
source_file = str(raw_source) if raw_source else ""
|
||||||
@@ -536,17 +496,10 @@ def reingest_source_worker(self, source_id, user):
|
|||||||
vector_store.delete_chunk(cid)
|
vector_store.delete_chunk(cid)
|
||||||
deleted += 1
|
deleted += 1
|
||||||
except Exception as de:
|
except Exception as de:
|
||||||
logging.error(
|
logging.error(f"Failed deleting chunk {cid}: {de}")
|
||||||
f"Failed deleting chunk {cid}: {de}"
|
logging.info(f"Deleted {deleted} chunks from {len(removed_files)} removed files")
|
||||||
)
|
|
||||||
logging.info(
|
|
||||||
f"Deleted {deleted} chunks from {len(removed_files)} removed files"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(
|
logging.error(f"Error during deletion of removed file chunks: {e}", exc_info=True)
|
||||||
f"Error during deletion of removed file chunks: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2) Add chunks from new files
|
# 2) Add chunks from new files
|
||||||
added = 0
|
added = 0
|
||||||
@@ -575,86 +528,58 @@ def reingest_source_worker(self, source_id, user):
|
|||||||
)
|
)
|
||||||
chunked_new = chunker_new.chunk(documents=raw_docs_new)
|
chunked_new = chunker_new.chunk(documents=raw_docs_new)
|
||||||
|
|
||||||
for (
|
for file_path, token_count in reader_new.file_token_counts.items():
|
||||||
file_path,
|
|
||||||
token_count,
|
|
||||||
) in reader_new.file_token_counts.items():
|
|
||||||
try:
|
try:
|
||||||
rel_path = os.path.relpath(
|
rel_path = os.path.relpath(file_path, start=temp_dir)
|
||||||
file_path, start=temp_dir
|
|
||||||
)
|
|
||||||
path_parts = rel_path.split(os.sep)
|
path_parts = rel_path.split(os.sep)
|
||||||
current_dir = directory_structure
|
current_dir = directory_structure
|
||||||
|
|
||||||
for part in path_parts[:-1]:
|
for part in path_parts[:-1]:
|
||||||
if part in current_dir and isinstance(
|
if part in current_dir and isinstance(current_dir[part], dict):
|
||||||
current_dir[part], dict
|
|
||||||
):
|
|
||||||
current_dir = current_dir[part]
|
current_dir = current_dir[part]
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
filename = path_parts[-1]
|
filename = path_parts[-1]
|
||||||
if filename in current_dir and isinstance(
|
if filename in current_dir and isinstance(current_dir[filename], dict):
|
||||||
current_dir[filename], dict
|
current_dir[filename]["token_count"] = token_count
|
||||||
):
|
logging.info(f"Updated token count for {rel_path}: {token_count}")
|
||||||
current_dir[filename][
|
|
||||||
"token_count"
|
|
||||||
] = token_count
|
|
||||||
logging.info(
|
|
||||||
f"Updated token count for {rel_path}: {token_count}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(
|
logging.warning(f"Could not update token count for {file_path}: {e}")
|
||||||
f"Could not update token count for {file_path}: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for d in chunked_new:
|
for d in chunked_new:
|
||||||
meta = dict(d.extra_info or {})
|
meta = dict(d.extra_info or {})
|
||||||
try:
|
try:
|
||||||
raw_src = meta.get("source")
|
raw_src = meta.get("source")
|
||||||
if isinstance(raw_src, str) and os.path.isabs(
|
if isinstance(raw_src, str) and os.path.isabs(raw_src):
|
||||||
raw_src
|
meta["source"] = os.path.relpath(raw_src, start=temp_dir)
|
||||||
):
|
|
||||||
meta["source"] = os.path.relpath(
|
|
||||||
raw_src, start=temp_dir
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
vector_store.add_chunk(d.text, metadata=meta)
|
vector_store.add_chunk(d.text, metadata=meta)
|
||||||
added += 1
|
added += 1
|
||||||
logging.info(
|
logging.info(f"Added {added} chunks from {len(added_files)} new files")
|
||||||
f"Added {added} chunks from {len(added_files)} new files"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(
|
logging.error(f"Error during ingestion of new files: {e}", exc_info=True)
|
||||||
f"Error during ingestion of new files: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3) Update source directory structure timestamp
|
# 3) Update source directory structure timestamp
|
||||||
try:
|
try:
|
||||||
total_tokens = sum(reader.file_token_counts.values())
|
total_tokens = sum(reader.file_token_counts.values())
|
||||||
|
|
||||||
sources_collection.update_one(
|
sources_collection.update_one(
|
||||||
{"_id": ObjectId(source_id)},
|
{"_id": ObjectId(source_id)},
|
||||||
{
|
{
|
||||||
"$set": {
|
"$set": {
|
||||||
"directory_structure": directory_structure,
|
"directory_structure": directory_structure,
|
||||||
"date": datetime.datetime.now(),
|
"date": datetime.datetime.now(),
|
||||||
"tokens": total_tokens,
|
"tokens": total_tokens
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(
|
logging.error(f"Error updating directory_structure in DB: {e}", exc_info=True)
|
||||||
f"Error updating directory_structure in DB: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.update_state(
|
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Re-ingestion completed"})
|
||||||
state="PROGRESS",
|
|
||||||
meta={"current": 100, "status": "Re-ingestion completed"},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"source_id": source_id,
|
"source_id": source_id,
|
||||||
@@ -666,16 +591,15 @@ def reingest_source_worker(self, source_id, user):
|
|||||||
"chunks_deleted": deleted,
|
"chunks_deleted": deleted,
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(
|
logging.error(f"Error while processing file changes: {e}", exc_info=True)
|
||||||
f"Error while processing file changes: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
|
logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def remote_worker(
|
def remote_worker(
|
||||||
self,
|
self,
|
||||||
source_data,
|
source_data,
|
||||||
@@ -727,7 +651,7 @@ def remote_worker(
|
|||||||
"id": str(id),
|
"id": str(id),
|
||||||
"type": loader,
|
"type": loader,
|
||||||
"remote_data": source_data,
|
"remote_data": source_data,
|
||||||
"sync_frequency": sync_frequency,
|
"sync_frequency": sync_frequency
|
||||||
}
|
}
|
||||||
|
|
||||||
if operation_mode == "sync":
|
if operation_mode == "sync":
|
||||||
@@ -788,7 +712,7 @@ def sync_worker(self, frequency):
|
|||||||
self, source_data, name, user, source_type, frequency, retriever, doc_id
|
self, source_data, name, user, source_type, frequency, retriever, doc_id
|
||||||
)
|
)
|
||||||
sync_counts["total_sync_count"] += 1
|
sync_counts["total_sync_count"] += 1
|
||||||
sync_counts[
|
sync_counts[
|
||||||
"sync_success" if resp["status"] == "success" else "sync_failure"
|
"sync_success" if resp["status"] == "success" else "sync_failure"
|
||||||
] += 1
|
] += 1
|
||||||
return {
|
return {
|
||||||
@@ -825,14 +749,15 @@ def attachment_worker(self, file_info, user):
|
|||||||
input_files=[local_path], exclude_hidden=True, errors="ignore"
|
input_files=[local_path], exclude_hidden=True, errors="ignore"
|
||||||
)
|
)
|
||||||
.load_data()[0]
|
.load_data()[0]
|
||||||
.text,
|
.text,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
token_count = num_tokens_from_string(content)
|
token_count = num_tokens_from_string(content)
|
||||||
if token_count > 100000:
|
if token_count > 100000:
|
||||||
content = content[:250000]
|
content = content[:250000]
|
||||||
token_count = num_tokens_from_string(content)
|
token_count = num_tokens_from_string(content)
|
||||||
|
|
||||||
self.update_state(
|
self.update_state(
|
||||||
state="PROGRESS", meta={"current": 80, "status": "Storing in database"}
|
state="PROGRESS", meta={"current": 80, "status": "Storing in database"}
|
||||||
)
|
)
|
||||||
@@ -947,49 +872,37 @@ def ingest_connector(
|
|||||||
doc_id: Document ID for sync operations (required when operation_mode="sync")
|
doc_id: Document ID for sync operations (required when operation_mode="sync")
|
||||||
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
|
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
|
||||||
"""
|
"""
|
||||||
logging.info(
|
logging.info(f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}")
|
||||||
f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}"
|
|
||||||
)
|
|
||||||
self.update_state(state="PROGRESS", meta={"current": 1})
|
self.update_state(state="PROGRESS", meta={"current": 1})
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
try:
|
try:
|
||||||
# Step 1: Initialize the appropriate loader
|
# Step 1: Initialize the appropriate loader
|
||||||
self.update_state(
|
self.update_state(state="PROGRESS", meta={"current": 10, "status": "Initializing connector"})
|
||||||
state="PROGRESS",
|
|
||||||
meta={"current": 10, "status": "Initializing connector"},
|
|
||||||
)
|
|
||||||
|
|
||||||
if not session_token:
|
if not session_token:
|
||||||
raise ValueError(f"{source_type} connector requires session_token")
|
raise ValueError(f"{source_type} connector requires session_token")
|
||||||
|
|
||||||
if not ConnectorCreator.is_supported(source_type):
|
if not ConnectorCreator.is_supported(source_type):
|
||||||
raise ValueError(
|
raise ValueError(f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}")
|
||||||
f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
remote_loader = ConnectorCreator.create_connector(
|
remote_loader = ConnectorCreator.create_connector(source_type, session_token)
|
||||||
source_type, session_token
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a clean config for storage
|
# Create a clean config for storage
|
||||||
api_source_config = {
|
api_source_config = {
|
||||||
"file_ids": file_ids or [],
|
"file_ids": file_ids or [],
|
||||||
"folder_ids": folder_ids or [],
|
"folder_ids": folder_ids or [],
|
||||||
"recursive": recursive,
|
"recursive": recursive
|
||||||
}
|
}
|
||||||
|
|
||||||
# Step 2: Download files to temp directory
|
# Step 2: Download files to temp directory
|
||||||
self.update_state(
|
self.update_state(state="PROGRESS", meta={"current": 20, "status": "Downloading files"})
|
||||||
state="PROGRESS", meta={"current": 20, "status": "Downloading files"}
|
|
||||||
)
|
|
||||||
download_info = remote_loader.download_to_directory(
|
download_info = remote_loader.download_to_directory(
|
||||||
temp_dir, api_source_config
|
temp_dir,
|
||||||
|
api_source_config
|
||||||
)
|
)
|
||||||
|
|
||||||
if download_info.get("empty_result", False) or not download_info.get(
|
if download_info.get("empty_result", False) or not download_info.get("files_downloaded", 0):
|
||||||
"files_downloaded", 0
|
|
||||||
):
|
|
||||||
logging.warning(f"No files were downloaded from {source_type}")
|
logging.warning(f"No files were downloaded from {source_type}")
|
||||||
# Create empty result directly instead of calling a separate method
|
# Create empty result directly instead of calling a separate method
|
||||||
return {
|
return {
|
||||||
@@ -1000,42 +913,28 @@ def ingest_connector(
|
|||||||
"source_config": api_source_config,
|
"source_config": api_source_config,
|
||||||
"directory_structure": "{}",
|
"directory_structure": "{}",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Step 3: Use SimpleDirectoryReader to process downloaded files
|
# Step 3: Use SimpleDirectoryReader to process downloaded files
|
||||||
self.update_state(
|
self.update_state(state="PROGRESS", meta={"current": 40, "status": "Processing files"})
|
||||||
state="PROGRESS", meta={"current": 40, "status": "Processing files"}
|
|
||||||
)
|
|
||||||
reader = SimpleDirectoryReader(
|
reader = SimpleDirectoryReader(
|
||||||
input_dir=temp_dir,
|
input_dir=temp_dir,
|
||||||
recursive=True,
|
recursive=True,
|
||||||
required_exts=[
|
required_exts=[
|
||||||
".rst",
|
".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub",
|
||||||
".md",
|
".html", ".mdx", ".json", ".xlsx", ".pptx", ".png",
|
||||||
".pdf",
|
".jpg", ".jpeg",
|
||||||
".txt",
|
|
||||||
".docx",
|
|
||||||
".csv",
|
|
||||||
".epub",
|
|
||||||
".html",
|
|
||||||
".mdx",
|
|
||||||
".json",
|
|
||||||
".xlsx",
|
|
||||||
".pptx",
|
|
||||||
".png",
|
|
||||||
".jpg",
|
|
||||||
".jpeg",
|
|
||||||
],
|
],
|
||||||
exclude_hidden=True,
|
exclude_hidden=True,
|
||||||
file_metadata=metadata_from_filename,
|
file_metadata=metadata_from_filename,
|
||||||
)
|
)
|
||||||
raw_docs = reader.load_data()
|
raw_docs = reader.load_data()
|
||||||
directory_structure = getattr(reader, "directory_structure", {})
|
directory_structure = getattr(reader, 'directory_structure', {})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Step 4: Process documents (chunking, embedding, etc.)
|
# Step 4: Process documents (chunking, embedding, etc.)
|
||||||
self.update_state(
|
self.update_state(state="PROGRESS", meta={"current": 60, "status": "Processing documents"})
|
||||||
state="PROGRESS", meta={"current": 60, "status": "Processing documents"}
|
|
||||||
)
|
|
||||||
|
|
||||||
chunker = Chunker(
|
chunker = Chunker(
|
||||||
chunking_strategy="classic_chunk",
|
chunking_strategy="classic_chunk",
|
||||||
max_tokens=MAX_TOKENS,
|
max_tokens=MAX_TOKENS,
|
||||||
@@ -1043,26 +942,22 @@ def ingest_connector(
|
|||||||
duplicate_headers=False,
|
duplicate_headers=False,
|
||||||
)
|
)
|
||||||
raw_docs = chunker.chunk(documents=raw_docs)
|
raw_docs = chunker.chunk(documents=raw_docs)
|
||||||
|
|
||||||
# Preserve source information in document metadata
|
# Preserve source information in document metadata
|
||||||
for doc in raw_docs:
|
for doc in raw_docs:
|
||||||
if hasattr(doc, "extra_info") and doc.extra_info:
|
if hasattr(doc, 'extra_info') and doc.extra_info:
|
||||||
source = doc.extra_info.get("source")
|
source = doc.extra_info.get('source')
|
||||||
if source and os.path.isabs(source):
|
if source and os.path.isabs(source):
|
||||||
# Convert absolute path to relative path
|
# Convert absolute path to relative path
|
||||||
doc.extra_info["source"] = os.path.relpath(
|
doc.extra_info['source'] = os.path.relpath(source, start=temp_dir)
|
||||||
source, start=temp_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
|
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
|
||||||
|
|
||||||
if operation_mode == "upload":
|
if operation_mode == "upload":
|
||||||
id = ObjectId()
|
id = ObjectId()
|
||||||
elif operation_mode == "sync":
|
elif operation_mode == "sync":
|
||||||
if not doc_id or not ObjectId.is_valid(doc_id):
|
if not doc_id or not ObjectId.is_valid(doc_id):
|
||||||
logging.error(
|
logging.error("Invalid doc_id provided for sync operation: %s", doc_id)
|
||||||
"Invalid doc_id provided for sync operation: %s", doc_id
|
|
||||||
)
|
|
||||||
raise ValueError("doc_id must be provided for sync operation.")
|
raise ValueError("doc_id must be provided for sync operation.")
|
||||||
id = ObjectId(doc_id)
|
id = ObjectId(doc_id)
|
||||||
else:
|
else:
|
||||||
@@ -1071,9 +966,7 @@ def ingest_connector(
|
|||||||
vector_store_path = os.path.join(temp_dir, "vector_store")
|
vector_store_path = os.path.join(temp_dir, "vector_store")
|
||||||
os.makedirs(vector_store_path, exist_ok=True)
|
os.makedirs(vector_store_path, exist_ok=True)
|
||||||
|
|
||||||
self.update_state(
|
self.update_state(state="PROGRESS", meta={"current": 80, "status": "Storing documents"})
|
||||||
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
|
|
||||||
)
|
|
||||||
embed_and_store_documents(docs, vector_store_path, id, self)
|
embed_and_store_documents(docs, vector_store_path, id, self)
|
||||||
|
|
||||||
tokens = count_tokens_docs(docs)
|
tokens = count_tokens_docs(docs)
|
||||||
@@ -1085,12 +978,13 @@ def ingest_connector(
|
|||||||
"tokens": tokens,
|
"tokens": tokens,
|
||||||
"retriever": retriever,
|
"retriever": retriever,
|
||||||
"id": str(id),
|
"id": str(id),
|
||||||
"type": "connector:file",
|
"type": "connector",
|
||||||
"remote_data": json.dumps(
|
"remote_data": json.dumps({
|
||||||
{"provider": source_type, **api_source_config}
|
"provider": source_type,
|
||||||
),
|
**api_source_config
|
||||||
|
}),
|
||||||
"directory_structure": json.dumps(directory_structure),
|
"directory_structure": json.dumps(directory_structure),
|
||||||
"sync_frequency": sync_frequency,
|
"sync_frequency": sync_frequency
|
||||||
}
|
}
|
||||||
|
|
||||||
if operation_mode == "sync":
|
if operation_mode == "sync":
|
||||||
@@ -1101,9 +995,7 @@ def ingest_connector(
|
|||||||
upload_index(vector_store_path, file_data)
|
upload_index(vector_store_path, file_data)
|
||||||
|
|
||||||
# Ensure we mark the task as complete
|
# Ensure we mark the task as complete
|
||||||
self.update_state(
|
self.update_state(state="PROGRESS", meta={"current": 100, "status": "Complete"})
|
||||||
state="PROGRESS", meta={"current": 100, "status": "Complete"}
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(f"Remote ingestion completed: {job_name}")
|
logging.info(f"Remote ingestion completed: {job_name}")
|
||||||
|
|
||||||
@@ -1113,136 +1005,9 @@ def ingest_connector(
|
|||||||
"tokens": tokens,
|
"tokens": tokens,
|
||||||
"type": source_type,
|
"type": source_type,
|
||||||
"id": str(id),
|
"id": str(id),
|
||||||
"status": "complete",
|
"status": "complete"
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
|
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
|
|
||||||
"""Worker to handle MCP OAuth flow asynchronously."""
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"[MCP OAuth] Worker started for user_id=%s, config=%s", user_id, config
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from application.agents.tools.mcp_tool import MCPTool
|
|
||||||
|
|
||||||
task_id = self.request.id
|
|
||||||
logging.info("[MCP OAuth] Task ID: %s", task_id)
|
|
||||||
redis_client = get_redis_instance()
|
|
||||||
|
|
||||||
def update_status(status_data: Dict[str, Any]):
|
|
||||||
logging.info("[MCP OAuth] Updating status: %s", status_data)
|
|
||||||
status_key = f"mcp_oauth_status:{task_id}"
|
|
||||||
redis_client.setex(status_key, 600, json.dumps(status_data))
|
|
||||||
|
|
||||||
update_status(
|
|
||||||
{
|
|
||||||
"status": "in_progress",
|
|
||||||
"message": "Starting OAuth flow...",
|
|
||||||
"task_id": task_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_config = config.copy()
|
|
||||||
tool_config["oauth_task_id"] = task_id
|
|
||||||
logging.info("[MCP OAuth] Initializing MCPTool with config: %s", tool_config)
|
|
||||||
mcp_tool = MCPTool(tool_config, user_id)
|
|
||||||
|
|
||||||
async def run_oauth_discovery():
|
|
||||||
if not mcp_tool._client:
|
|
||||||
mcp_tool._setup_client()
|
|
||||||
return await mcp_tool._execute_with_client("list_tools")
|
|
||||||
|
|
||||||
update_status(
|
|
||||||
{
|
|
||||||
"status": "awaiting_redirect",
|
|
||||||
"message": "Waiting for OAuth redirect...",
|
|
||||||
"task_id": task_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
logging.info("[MCP OAuth] Starting event loop for OAuth discovery...")
|
|
||||||
tools_response = loop.run_until_complete(run_oauth_discovery())
|
|
||||||
logging.info(
|
|
||||||
"[MCP OAuth] Tools response after async call: %s", tools_response
|
|
||||||
)
|
|
||||||
|
|
||||||
status_key = f"mcp_oauth_status:{task_id}"
|
|
||||||
redis_status = redis_client.get(status_key)
|
|
||||||
if redis_status:
|
|
||||||
logging.info(
|
|
||||||
"[MCP OAuth] Redis status after async call: %s", redis_status
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logging.warning(
|
|
||||||
"[MCP OAuth] No Redis status found after async call for key: %s",
|
|
||||||
status_key,
|
|
||||||
)
|
|
||||||
tools = mcp_tool.get_actions_metadata()
|
|
||||||
|
|
||||||
update_status(
|
|
||||||
{
|
|
||||||
"status": "completed",
|
|
||||||
"message": f"OAuth completed successfully. Found {len(tools)} tools.",
|
|
||||||
"tools": tools,
|
|
||||||
"tools_count": len(tools),
|
|
||||||
"task_id": task_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
"[MCP OAuth] OAuth flow completed successfully for task_id=%s", task_id
|
|
||||||
)
|
|
||||||
return {"success": True, "tools": tools, "tools_count": len(tools)}
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"OAuth flow failed: {str(e)}"
|
|
||||||
logging.error(
|
|
||||||
"[MCP OAuth] Exception in OAuth discovery: %s", error_msg, exc_info=True
|
|
||||||
)
|
|
||||||
update_status(
|
|
||||||
{
|
|
||||||
"status": "error",
|
|
||||||
"message": error_msg,
|
|
||||||
"error": str(e),
|
|
||||||
"task_id": task_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return {"success": False, "error": error_msg}
|
|
||||||
finally:
|
|
||||||
logging.info("[MCP OAuth] Closing event loop for task_id=%s", task_id)
|
|
||||||
loop.close()
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Failed to initialize OAuth flow: {str(e)}"
|
|
||||||
logging.error(
|
|
||||||
"[MCP OAuth] Exception during initialization: %s", error_msg, exc_info=True
|
|
||||||
)
|
|
||||||
update_status(
|
|
||||||
{
|
|
||||||
"status": "error",
|
|
||||||
"message": error_msg,
|
|
||||||
"error": str(e),
|
|
||||||
"task_id": task_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return {"success": False, "error": error_msg}
|
|
||||||
|
|
||||||
|
|
||||||
def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]:
|
|
||||||
"""Check the status of an MCP OAuth flow."""
|
|
||||||
redis_client = get_redis_instance()
|
|
||||||
status_key = f"mcp_oauth_status:{task_id}"
|
|
||||||
|
|
||||||
status_data = redis_client.get(status_key)
|
|
||||||
if status_data:
|
|
||||||
return json.loads(status_data)
|
|
||||||
return {"status": "not_found", "message": "Status not found"}
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
- VITE_API_HOST=http://localhost:7091
|
- VITE_API_HOST=http://localhost:7091
|
||||||
- VITE_API_STREAMING=$VITE_API_STREAMING
|
- VITE_API_STREAMING=$VITE_API_STREAMING
|
||||||
- VITE_GOOGLE_CLIENT_ID=$VITE_GOOGLE_CLIENT_ID
|
|
||||||
ports:
|
ports:
|
||||||
- "5173:5173"
|
- "5173:5173"
|
||||||
depends_on:
|
depends_on:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
- VITE_API_HOST=http://localhost:7091
|
- VITE_API_HOST=http://localhost:7091
|
||||||
- VITE_API_STREAMING=$VITE_API_STREAMING
|
- VITE_API_STREAMING=$VITE_API_STREAMING
|
||||||
- VITE_GOOGLE_CLIENT_ID=$VITE_GOOGLE_CLIENT_ID
|
|
||||||
ports:
|
ports:
|
||||||
- "5173:5173"
|
- "5173:5173"
|
||||||
depends_on:
|
depends_on:
|
||||||
|
|||||||
@@ -1,6 +0,0 @@
|
|||||||
{
|
|
||||||
"google-drive-connector": {
|
|
||||||
"title": "🔗 Google Drive",
|
|
||||||
"href": "/Guides/Integrations/google-drive-connector"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,212 +0,0 @@
|
|||||||
---
|
|
||||||
title: Google Drive Connector
|
|
||||||
description: Connect your Google Drive as an external knowledge base to upload and process files directly from your Google Drive account.
|
|
||||||
---
|
|
||||||
|
|
||||||
import { Callout } from 'nextra/components'
|
|
||||||
import { Steps } from 'nextra/components'
|
|
||||||
|
|
||||||
# Google Drive Connector
|
|
||||||
|
|
||||||
The Google Drive Connector allows you to seamlessly connect your Google Drive account as an external knowledge base. This integration enables you to upload and process files directly from your Google Drive without manually downloading and uploading them to DocsGPT.
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
- **Direct File Access**: Browse and select files directly from your Google Drive
|
|
||||||
- **Comprehensive File Support**: Supports all major document formats including:
|
|
||||||
- Google Workspace files (Docs, Sheets, Slides)
|
|
||||||
- Microsoft Office files (.docx, .xlsx, .pptx, .doc, .ppt, .xls)
|
|
||||||
- PDF documents
|
|
||||||
- Text files (.txt, .md, .rst, .html, .rtf)
|
|
||||||
- Data files (.csv, .json)
|
|
||||||
- Image files (.png, .jpg, .jpeg)
|
|
||||||
- E-books (.epub)
|
|
||||||
- **Secure Authentication**: Uses OAuth 2.0 for secure access to your Google Drive
|
|
||||||
- **Real-time Sync**: Process files directly from Google Drive without local downloads
|
|
||||||
|
|
||||||
<Callout type="info" emoji="ℹ️">
|
|
||||||
The Google Drive Connector requires proper configuration of Google API credentials. Follow the setup instructions below to enable this feature.
|
|
||||||
</Callout>
|
|
||||||
|
|
||||||
## Prerequisites
|
|
||||||
|
|
||||||
Before setting up the Google Drive Connector, you'll need:
|
|
||||||
|
|
||||||
1. A Google Cloud Platform (GCP) project
|
|
||||||
2. Google Drive API enabled
|
|
||||||
3. OAuth 2.0 credentials configured
|
|
||||||
4. DocsGPT instance with proper environment variables
|
|
||||||
|
|
||||||
## Setup Instructions
|
|
||||||
|
|
||||||
<Steps>
|
|
||||||
|
|
||||||
### Step 1: Create a Google Cloud Project
|
|
||||||
|
|
||||||
1. Go to the [Google Cloud Console](https://console.cloud.google.com/)
|
|
||||||
2. Create a new project or select an existing one
|
|
||||||
3. Note down your Project ID for later use
|
|
||||||
|
|
||||||
### Step 2: Enable Google Drive API
|
|
||||||
|
|
||||||
1. In the Google Cloud Console, navigate to **APIs & Services** > **Library**
|
|
||||||
2. Search for "Google Drive API"
|
|
||||||
3. Click on "Google Drive API" and click **Enable**
|
|
||||||
|
|
||||||
### Step 3: Create OAuth 2.0 Credentials
|
|
||||||
|
|
||||||
1. Go to **APIs & Services** > **Credentials**
|
|
||||||
2. Click **Create Credentials** > **OAuth client ID**
|
|
||||||
3. If prompted, configure the OAuth consent screen:
|
|
||||||
- Choose **External** user type (unless you're using Google Workspace)
|
|
||||||
- Fill in the required fields (App name, User support email, Developer contact)
|
|
||||||
- Add your domain to **Authorized domains** if deploying publicly
|
|
||||||
4. For Application type, select **Web application**
|
|
||||||
5. Add your DocsGPT frontend URL to **Authorized JavaScript origins**:
|
|
||||||
- For local development: `http://localhost:3000`
|
|
||||||
- For production: `https://yourdomain.com`
|
|
||||||
6. Add your DocsGPT callback URL to **Authorized redirect URIs**:
|
|
||||||
- For local development: `http://localhost:7091/api/connectors/callback?provider=google_drive`
|
|
||||||
- For production: `https://yourdomain.com/api/connectors/callback?provider=google_drive`
|
|
||||||
7. Click **Create** and note down the **Client ID** and **Client Secret**
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Step 4: Configure Backend Environment Variables
|
|
||||||
|
|
||||||
Add the following environment variables to your backend configuration:
|
|
||||||
|
|
||||||
**For Docker deployment**, add to your `.env` file in the root directory:
|
|
||||||
|
|
||||||
```env
|
|
||||||
# Google Drive Connector Configuration
|
|
||||||
GOOGLE_CLIENT_ID=your_google_client_id_here
|
|
||||||
GOOGLE_CLIENT_SECRET=your_google_client_secret_here
|
|
||||||
```
|
|
||||||
|
|
||||||
**For manual deployment**, set these environment variables in your system or application configuration.
|
|
||||||
|
|
||||||
### Step 5: Configure Frontend Environment Variables
|
|
||||||
|
|
||||||
Add the following environment variables to your frontend `.env` file:
|
|
||||||
|
|
||||||
```env
|
|
||||||
# Google Drive Frontend Configuration
|
|
||||||
VITE_GOOGLE_CLIENT_ID=your_google_client_id_here
|
|
||||||
```
|
|
||||||
|
|
||||||
<Callout type="warning" emoji="⚠️">
|
|
||||||
Make sure to use the same Google Client ID in both backend and frontend configurations.
|
|
||||||
</Callout>
|
|
||||||
|
|
||||||
### Step 6: Restart Your Application
|
|
||||||
|
|
||||||
After configuring the environment variables:
|
|
||||||
|
|
||||||
1. **For Docker**: Restart your Docker containers
|
|
||||||
```bash
|
|
||||||
docker-compose down
|
|
||||||
docker-compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **For manual deployment**: Restart both backend and frontend services
|
|
||||||
|
|
||||||
</Steps>
|
|
||||||
|
|
||||||
## Using the Google Drive Connector
|
|
||||||
|
|
||||||
Once configured, you can use the Google Drive Connector to upload files:
|
|
||||||
|
|
||||||
<Steps>
|
|
||||||
|
|
||||||
### Step 1: Access the Upload Interface
|
|
||||||
|
|
||||||
1. Navigate to the DocsGPT interface
|
|
||||||
2. Go to the upload/training section
|
|
||||||
3. You should now see "Google Drive" as an available upload option
|
|
||||||
|
|
||||||
### Step 2: Connect Your Google Account
|
|
||||||
|
|
||||||
1. Select "Google Drive" as your upload method
|
|
||||||
2. Click "Connect to Google Drive"
|
|
||||||
3. You'll be redirected to Google's OAuth consent screen
|
|
||||||
4. Grant the necessary permissions to DocsGPT
|
|
||||||
5. You'll be redirected back to DocsGPT with a successful connection
|
|
||||||
|
|
||||||
### Step 3: Select Files
|
|
||||||
|
|
||||||
1. Once connected, click "Select Files"
|
|
||||||
2. The Google Drive picker will open
|
|
||||||
3. Browse your Google Drive and select the files you want to process
|
|
||||||
4. Click "Select" to confirm your choices
|
|
||||||
|
|
||||||
### Step 4: Process Files
|
|
||||||
|
|
||||||
1. Review your selected files
|
|
||||||
2. Click "Train" or "Upload" to process the files
|
|
||||||
3. DocsGPT will download and process the files from your Google Drive
|
|
||||||
4. Once processing is complete, the files will be available in your knowledge base
|
|
||||||
|
|
||||||
</Steps>
|
|
||||||
|
|
||||||
## Supported File Types
|
|
||||||
|
|
||||||
The Google Drive Connector supports the following file types:
|
|
||||||
|
|
||||||
| File Type | Extensions | Description |
|
|
||||||
|-----------|------------|-------------|
|
|
||||||
| **Google Workspace** | - | Google Docs, Sheets, Slides (automatically converted) |
|
|
||||||
| **Microsoft Office** | .docx, .xlsx, .pptx | Modern Office formats |
|
|
||||||
| **Legacy Office** | .doc, .ppt, .xls | Older Office formats |
|
|
||||||
| **PDF Documents** | .pdf | Portable Document Format |
|
|
||||||
| **Text Files** | .txt, .md, .rst, .html, .rtf | Various text formats |
|
|
||||||
| **Data Files** | .csv, .json | Structured data formats |
|
|
||||||
| **Images** | .png, .jpg, .jpeg | Image files (with OCR if enabled) |
|
|
||||||
| **E-books** | .epub | Electronic publication format |
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
### Common Issues
|
|
||||||
|
|
||||||
**"Google Drive option not appearing"**
|
|
||||||
- Verify that `VITE_GOOGLE_CLIENT_ID` is set in frontend environment
|
|
||||||
- Check that `VITE_GOOGLE_CLIENT_ID` environment variable is present in your frontend configuration
|
|
||||||
- Check browser console for any JavaScript errors
|
|
||||||
- Ensure the frontend has been restarted after adding environment variables
|
|
||||||
|
|
||||||
**"Authentication failed"**
|
|
||||||
- Verify that your OAuth 2.0 credentials are correctly configured
|
|
||||||
- Check that the redirect URI `http://<your-domain>/api/connectors/callback?provider=google_drive` is correctly added in GCP console
|
|
||||||
- Ensure the Google Drive API is enabled in your GCP project
|
|
||||||
|
|
||||||
**"Permission denied" errors**
|
|
||||||
- Verify that the OAuth consent screen is properly configured
|
|
||||||
- Check that your Google account has access to the files you're trying to select
|
|
||||||
- Ensure the required scopes are granted during authentication
|
|
||||||
|
|
||||||
**"Files not processing"**
|
|
||||||
- Check that the backend environment variables are correctly set
|
|
||||||
- Verify that the OAuth credentials have the necessary permissions
|
|
||||||
- Check the backend logs for any error messages
|
|
||||||
|
|
||||||
### Environment Variable Checklist
|
|
||||||
|
|
||||||
**Backend (.env in root directory):**
|
|
||||||
- ✅ `GOOGLE_CLIENT_ID`
|
|
||||||
- ✅ `GOOGLE_CLIENT_SECRET`
|
|
||||||
|
|
||||||
**Frontend (.env in frontend directory):**
|
|
||||||
- ✅ `VITE_GOOGLE_CLIENT_ID`
|
|
||||||
|
|
||||||
### Security Considerations
|
|
||||||
|
|
||||||
- Keep your Google Client Secret secure and never expose it in frontend code
|
|
||||||
- Regularly rotate your OAuth credentials
|
|
||||||
- Use HTTPS in production to protect authentication tokens
|
|
||||||
- Ensure proper OAuth consent screen configuration for production use
|
|
||||||
|
|
||||||
<Callout type="tip" emoji="💡">
|
|
||||||
For production deployments, make sure to add your actual domain to the OAuth consent screen and authorized origins/redirect URIs.
|
|
||||||
</Callout>
|
|
||||||
|
|
||||||
|
|
||||||
@@ -20,8 +20,5 @@
|
|||||||
"Architecture": {
|
"Architecture": {
|
||||||
"title": "🏗️ Architecture",
|
"title": "🏗️ Architecture",
|
||||||
"href": "/Guides/Architecture"
|
"href": "/Guides/Architecture"
|
||||||
},
|
|
||||||
"Integrations": {
|
|
||||||
"title": "🔗 Integrations"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
11
frontend/package-lock.json
generated
@@ -21,7 +21,6 @@
|
|||||||
"react-chartjs-2": "^5.3.0",
|
"react-chartjs-2": "^5.3.0",
|
||||||
"react-dom": "^19.0.0",
|
"react-dom": "^19.0.0",
|
||||||
"react-dropzone": "^14.3.8",
|
"react-dropzone": "^14.3.8",
|
||||||
"react-google-drive-picker": "^1.2.2",
|
|
||||||
"react-i18next": "^15.4.0",
|
"react-i18next": "^15.4.0",
|
||||||
"react-markdown": "^9.0.1",
|
"react-markdown": "^9.0.1",
|
||||||
"react-redux": "^9.2.0",
|
"react-redux": "^9.2.0",
|
||||||
@@ -9383,16 +9382,6 @@
|
|||||||
"react": ">= 16.8 || 18.0.0"
|
"react": ">= 16.8 || 18.0.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/react-google-drive-picker": {
|
|
||||||
"version": "1.2.2",
|
|
||||||
"resolved": "https://registry.npmjs.org/react-google-drive-picker/-/react-google-drive-picker-1.2.2.tgz",
|
|
||||||
"integrity": "sha512-x30mYkt9MIwPCgL+fyK75HZ8E6G5L/WGW0bfMG6kbD4NG2kmdlmV9oH5lPa6P6d46y9hj5Y3btAMrZd4JRRkSA==",
|
|
||||||
"license": "MIT",
|
|
||||||
"peerDependencies": {
|
|
||||||
"react": ">=17.0.0",
|
|
||||||
"react-dom": ">=17.0.0"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/react-i18next": {
|
"node_modules/react-i18next": {
|
||||||
"version": "15.4.0",
|
"version": "15.4.0",
|
||||||
"resolved": "https://registry.npmjs.org/react-i18next/-/react-i18next-15.4.0.tgz",
|
"resolved": "https://registry.npmjs.org/react-i18next/-/react-i18next-15.4.0.tgz",
|
||||||
|
|||||||
@@ -32,7 +32,6 @@
|
|||||||
"react-chartjs-2": "^5.3.0",
|
"react-chartjs-2": "^5.3.0",
|
||||||
"react-dom": "^19.0.0",
|
"react-dom": "^19.0.0",
|
||||||
"react-dropzone": "^14.3.8",
|
"react-dropzone": "^14.3.8",
|
||||||
"react-google-drive-picker": "^1.2.2",
|
|
||||||
"react-i18next": "^15.4.0",
|
"react-i18next": "^15.4.0",
|
||||||
"react-markdown": "^9.0.1",
|
"react-markdown": "^9.0.1",
|
||||||
"react-redux": "^9.2.0",
|
"react-redux": "^9.2.0",
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="64" height="64" color="#000000" fill="none">
|
|
||||||
<path d="M3.49994 11.7501L11.6717 3.57855C12.7762 2.47398 14.5672 2.47398 15.6717 3.57855C16.7762 4.68312 16.7762 6.47398 15.6717 7.57855M15.6717 7.57855L9.49994 13.7501M15.6717 7.57855C16.7762 6.47398 18.5672 6.47398 19.6717 7.57855C20.7762 8.68312 20.7762 10.474 19.6717 11.5785L12.7072 18.543C12.3167 18.9335 12.3167 19.5667 12.7072 19.9572L13.9999 21.2499" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
|
|
||||||
<path d="M17.4999 9.74921L11.3282 15.921C10.2237 17.0255 8.43272 17.0255 7.32823 15.921C6.22373 14.8164 6.22373 13.0255 7.32823 11.921L13.4999 5.74939" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 831 B |
@@ -29,7 +29,7 @@ export default function Hero({
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Demo Buttons Section */}
|
{/* Demo Buttons Section */}
|
||||||
<div className="mb-3 w-full max-w-full md:mb-3">
|
<div className="mb-8 w-full max-w-full md:mb-16">
|
||||||
<div className="grid grid-cols-1 gap-3 text-xs md:grid-cols-1 md:gap-4 lg:grid-cols-2">
|
<div className="grid grid-cols-1 gap-3 text-xs md:grid-cols-1 md:gap-4 lg:grid-cols-2">
|
||||||
{demos?.map(
|
{demos?.map(
|
||||||
(demo: { header: string; query: string }, key: number) =>
|
(demo: { header: string; query: string }, key: number) =>
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import Add from './assets/add.svg';
|
|||||||
import DocsGPT3 from './assets/cute_docsgpt3.svg';
|
import DocsGPT3 from './assets/cute_docsgpt3.svg';
|
||||||
import Discord from './assets/discord.svg';
|
import Discord from './assets/discord.svg';
|
||||||
import Expand from './assets/expand.svg';
|
import Expand from './assets/expand.svg';
|
||||||
import Github from './assets/git_nav.svg';
|
import Github from './assets/github.svg';
|
||||||
import Hamburger from './assets/hamburger.svg';
|
import Hamburger from './assets/hamburger.svg';
|
||||||
import openNewChat from './assets/openNewChat.svg';
|
import openNewChat from './assets/openNewChat.svg';
|
||||||
import Pin from './assets/pin.svg';
|
import Pin from './assets/pin.svg';
|
||||||
@@ -568,8 +568,6 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
|||||||
>
|
>
|
||||||
<img
|
<img
|
||||||
src={Discord}
|
src={Discord}
|
||||||
width={24}
|
|
||||||
height={24}
|
|
||||||
alt="Join Discord community"
|
alt="Join Discord community"
|
||||||
className="m-2 w-6 self-center filter dark:invert"
|
className="m-2 w-6 self-center filter dark:invert"
|
||||||
/>
|
/>
|
||||||
@@ -583,10 +581,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
|||||||
>
|
>
|
||||||
<img
|
<img
|
||||||
src={Twitter}
|
src={Twitter}
|
||||||
width={20}
|
|
||||||
height={20}
|
|
||||||
alt="Follow us on Twitter"
|
alt="Follow us on Twitter"
|
||||||
className="m-2 self-center filter dark:invert"
|
className="m-2 w-5 self-center filter dark:invert"
|
||||||
/>
|
/>
|
||||||
</NavLink>
|
</NavLink>
|
||||||
<NavLink
|
<NavLink
|
||||||
@@ -599,9 +595,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
|
|||||||
<img
|
<img
|
||||||
src={Github}
|
src={Github}
|
||||||
alt="View on GitHub"
|
alt="View on GitHub"
|
||||||
width={28}
|
className="m-2 w-6 self-center filter dark:invert"
|
||||||
height={28}
|
|
||||||
className="m-2 self-center filter dark:invert"
|
|
||||||
/>
|
/>
|
||||||
</NavLink>
|
</NavLink>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -45,7 +45,6 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
description: '',
|
description: '',
|
||||||
image: '',
|
image: '',
|
||||||
source: '',
|
source: '',
|
||||||
sources: [],
|
|
||||||
chunks: '',
|
chunks: '',
|
||||||
retriever: '',
|
retriever: '',
|
||||||
prompt_id: 'default',
|
prompt_id: 'default',
|
||||||
@@ -151,41 +150,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
formData.append('name', agent.name);
|
formData.append('name', agent.name);
|
||||||
formData.append('description', agent.description);
|
formData.append('description', agent.description);
|
||||||
|
formData.append('source', agent.source);
|
||||||
if (selectedSourceIds.size > 1) {
|
|
||||||
const sourcesArray = Array.from(selectedSourceIds)
|
|
||||||
.map((id) => {
|
|
||||||
const sourceDoc = sourceDocs?.find(
|
|
||||||
(source) =>
|
|
||||||
source.id === id || source.retriever === id || source.name === id,
|
|
||||||
);
|
|
||||||
if (sourceDoc?.name === 'Default' && !sourceDoc?.id) {
|
|
||||||
return 'default';
|
|
||||||
}
|
|
||||||
return sourceDoc?.id || id;
|
|
||||||
})
|
|
||||||
.filter(Boolean);
|
|
||||||
formData.append('sources', JSON.stringify(sourcesArray));
|
|
||||||
formData.append('source', '');
|
|
||||||
} else if (selectedSourceIds.size === 1) {
|
|
||||||
const singleSourceId = Array.from(selectedSourceIds)[0];
|
|
||||||
const sourceDoc = sourceDocs?.find(
|
|
||||||
(source) =>
|
|
||||||
source.id === singleSourceId ||
|
|
||||||
source.retriever === singleSourceId ||
|
|
||||||
source.name === singleSourceId,
|
|
||||||
);
|
|
||||||
let finalSourceId;
|
|
||||||
if (sourceDoc?.name === 'Default' && !sourceDoc?.id)
|
|
||||||
finalSourceId = 'default';
|
|
||||||
else finalSourceId = sourceDoc?.id || singleSourceId;
|
|
||||||
formData.append('source', String(finalSourceId));
|
|
||||||
formData.append('sources', JSON.stringify([]));
|
|
||||||
} else {
|
|
||||||
formData.append('source', '');
|
|
||||||
formData.append('sources', JSON.stringify([]));
|
|
||||||
}
|
|
||||||
|
|
||||||
formData.append('chunks', agent.chunks);
|
formData.append('chunks', agent.chunks);
|
||||||
formData.append('retriever', agent.retriever);
|
formData.append('retriever', agent.retriever);
|
||||||
formData.append('prompt_id', agent.prompt_id);
|
formData.append('prompt_id', agent.prompt_id);
|
||||||
@@ -231,41 +196,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
formData.append('name', agent.name);
|
formData.append('name', agent.name);
|
||||||
formData.append('description', agent.description);
|
formData.append('description', agent.description);
|
||||||
|
formData.append('source', agent.source);
|
||||||
if (selectedSourceIds.size > 1) {
|
|
||||||
const sourcesArray = Array.from(selectedSourceIds)
|
|
||||||
.map((id) => {
|
|
||||||
const sourceDoc = sourceDocs?.find(
|
|
||||||
(source) =>
|
|
||||||
source.id === id || source.retriever === id || source.name === id,
|
|
||||||
);
|
|
||||||
if (sourceDoc?.name === 'Default' && !sourceDoc?.id) {
|
|
||||||
return 'default';
|
|
||||||
}
|
|
||||||
return sourceDoc?.id || id;
|
|
||||||
})
|
|
||||||
.filter(Boolean);
|
|
||||||
formData.append('sources', JSON.stringify(sourcesArray));
|
|
||||||
formData.append('source', '');
|
|
||||||
} else if (selectedSourceIds.size === 1) {
|
|
||||||
const singleSourceId = Array.from(selectedSourceIds)[0];
|
|
||||||
const sourceDoc = sourceDocs?.find(
|
|
||||||
(source) =>
|
|
||||||
source.id === singleSourceId ||
|
|
||||||
source.retriever === singleSourceId ||
|
|
||||||
source.name === singleSourceId,
|
|
||||||
);
|
|
||||||
let finalSourceId;
|
|
||||||
if (sourceDoc?.name === 'Default' && !sourceDoc?.id)
|
|
||||||
finalSourceId = 'default';
|
|
||||||
else finalSourceId = sourceDoc?.id || singleSourceId;
|
|
||||||
formData.append('source', String(finalSourceId));
|
|
||||||
formData.append('sources', JSON.stringify([]));
|
|
||||||
} else {
|
|
||||||
formData.append('source', '');
|
|
||||||
formData.append('sources', JSON.stringify([]));
|
|
||||||
}
|
|
||||||
|
|
||||||
formData.append('chunks', agent.chunks);
|
formData.append('chunks', agent.chunks);
|
||||||
formData.append('retriever', agent.retriever);
|
formData.append('retriever', agent.retriever);
|
||||||
formData.append('prompt_id', agent.prompt_id);
|
formData.append('prompt_id', agent.prompt_id);
|
||||||
@@ -362,33 +293,9 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
throw new Error('Failed to fetch agent');
|
throw new Error('Failed to fetch agent');
|
||||||
}
|
}
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
if (data.source) setSelectedSourceIds(new Set([data.source]));
|
||||||
if (data.sources && data.sources.length > 0) {
|
else if (data.retriever)
|
||||||
const mappedSources = data.sources.map((sourceId: string) => {
|
|
||||||
if (sourceId === 'default') {
|
|
||||||
const defaultSource = sourceDocs?.find(
|
|
||||||
(source) => source.name === 'Default',
|
|
||||||
);
|
|
||||||
return defaultSource?.retriever || 'classic';
|
|
||||||
}
|
|
||||||
return sourceId;
|
|
||||||
});
|
|
||||||
setSelectedSourceIds(new Set(mappedSources));
|
|
||||||
} else if (data.source) {
|
|
||||||
if (data.source === 'default') {
|
|
||||||
const defaultSource = sourceDocs?.find(
|
|
||||||
(source) => source.name === 'Default',
|
|
||||||
);
|
|
||||||
setSelectedSourceIds(
|
|
||||||
new Set([defaultSource?.retriever || 'classic']),
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
setSelectedSourceIds(new Set([data.source]));
|
|
||||||
}
|
|
||||||
} else if (data.retriever) {
|
|
||||||
setSelectedSourceIds(new Set([data.retriever]));
|
setSelectedSourceIds(new Set([data.retriever]));
|
||||||
}
|
|
||||||
|
|
||||||
if (data.tools) setSelectedToolIds(new Set(data.tools));
|
if (data.tools) setSelectedToolIds(new Set(data.tools));
|
||||||
if (data.status === 'draft') setEffectiveMode('draft');
|
if (data.status === 'draft') setEffectiveMode('draft');
|
||||||
if (data.json_schema) {
|
if (data.json_schema) {
|
||||||
@@ -404,57 +311,25 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
}, [agentId, mode, token]);
|
}, [agentId, mode, token]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const selectedSources = Array.from(selectedSourceIds)
|
const selectedSource = Array.from(selectedSourceIds).map((id) =>
|
||||||
.map((id) =>
|
sourceDocs?.find(
|
||||||
sourceDocs?.find(
|
(source) =>
|
||||||
(source) =>
|
source.id === id || source.retriever === id || source.name === id,
|
||||||
source.id === id || source.retriever === id || source.name === id,
|
),
|
||||||
),
|
);
|
||||||
)
|
if (selectedSource[0]?.model === embeddingsName) {
|
||||||
.filter(Boolean);
|
if (selectedSource[0] && 'id' in selectedSource[0]) {
|
||||||
|
|
||||||
if (selectedSources.length > 0) {
|
|
||||||
// Handle multiple sources
|
|
||||||
if (selectedSources.length > 1) {
|
|
||||||
// Multiple sources selected - store in sources array
|
|
||||||
const sourceIds = selectedSources
|
|
||||||
.map((source) => source?.id)
|
|
||||||
.filter((id): id is string => Boolean(id));
|
|
||||||
setAgent((prev) => ({
|
setAgent((prev) => ({
|
||||||
...prev,
|
...prev,
|
||||||
sources: sourceIds,
|
source: selectedSource[0]?.id || 'default',
|
||||||
source: '', // Clear single source for multiple sources
|
|
||||||
retriever: '',
|
retriever: '',
|
||||||
}));
|
}));
|
||||||
} else {
|
} else
|
||||||
// Single source selected - maintain backward compatibility
|
setAgent((prev) => ({
|
||||||
const selectedSource = selectedSources[0];
|
...prev,
|
||||||
if (selectedSource?.model === embeddingsName) {
|
source: '',
|
||||||
if (selectedSource && 'id' in selectedSource) {
|
retriever: selectedSource[0]?.retriever || 'classic',
|
||||||
setAgent((prev) => ({
|
}));
|
||||||
...prev,
|
|
||||||
source: selectedSource?.id || 'default',
|
|
||||||
sources: [], // Clear sources array for single source
|
|
||||||
retriever: '',
|
|
||||||
}));
|
|
||||||
} else {
|
|
||||||
setAgent((prev) => ({
|
|
||||||
...prev,
|
|
||||||
source: '',
|
|
||||||
sources: [], // Clear sources array
|
|
||||||
retriever: selectedSource?.retriever || 'classic',
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// No sources selected
|
|
||||||
setAgent((prev) => ({
|
|
||||||
...prev,
|
|
||||||
source: '',
|
|
||||||
sources: [],
|
|
||||||
retriever: '',
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
}, [selectedSourceIds]);
|
}, [selectedSourceIds]);
|
||||||
|
|
||||||
@@ -586,7 +461,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
onChange={(e) => setAgent({ ...agent, name: e.target.value })}
|
onChange={(e) => setAgent({ ...agent, name: e.target.value })}
|
||||||
/>
|
/>
|
||||||
<textarea
|
<textarea
|
||||||
className="border-silver text-jet dark:bg-raisin-black dark:text-bright-gray dark:placeholder:text-silver mt-3 h-32 w-full rounded-xl border bg-white px-5 py-4 text-sm outline-hidden placeholder:text-gray-400 dark:border-[#7E7E7E]"
|
className="border-silver text-jet dark:bg-raisin-black dark:text-bright-gray dark:placeholder:text-silver mt-3 h-32 w-full rounded-3xl border bg-white px-5 py-4 text-sm outline-hidden placeholder:text-gray-400 dark:border-[#7E7E7E]"
|
||||||
placeholder="Describe your agent"
|
placeholder="Describe your agent"
|
||||||
value={agent.description}
|
value={agent.description}
|
||||||
onChange={(e) =>
|
onChange={(e) =>
|
||||||
@@ -635,7 +510,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
)
|
)
|
||||||
.filter(Boolean)
|
.filter(Boolean)
|
||||||
.join(', ')
|
.join(', ')
|
||||||
: 'Select sources'}
|
: 'Select source'}
|
||||||
</button>
|
</button>
|
||||||
<MultiSelectPopup
|
<MultiSelectPopup
|
||||||
isOpen={isSourcePopupOpen}
|
isOpen={isSourcePopupOpen}
|
||||||
@@ -651,10 +526,12 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
|||||||
selectedIds={selectedSourceIds}
|
selectedIds={selectedSourceIds}
|
||||||
onSelectionChange={(newSelectedIds: Set<string | number>) => {
|
onSelectionChange={(newSelectedIds: Set<string | number>) => {
|
||||||
setSelectedSourceIds(newSelectedIds);
|
setSelectedSourceIds(newSelectedIds);
|
||||||
|
setIsSourcePopupOpen(false);
|
||||||
}}
|
}}
|
||||||
title="Select Sources"
|
title="Select Source"
|
||||||
searchPlaceholder="Search sources..."
|
searchPlaceholder="Search sources..."
|
||||||
noOptionsMessage="No sources available"
|
noOptionsMessage="No source available"
|
||||||
|
singleSelect={true}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div className="mt-3">
|
<div className="mt-3">
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ export type Agent = {
|
|||||||
description: string;
|
description: string;
|
||||||
image: string;
|
image: string;
|
||||||
source: string;
|
source: string;
|
||||||
sources?: string[];
|
|
||||||
chunks: string;
|
chunks: string;
|
||||||
retriever: string;
|
retriever: string;
|
||||||
prompt_id: string;
|
prompt_id: string;
|
||||||
|
|||||||
@@ -57,10 +57,6 @@ const endpoints = {
|
|||||||
DIRECTORY_STRUCTURE: (docId: string) =>
|
DIRECTORY_STRUCTURE: (docId: string) =>
|
||||||
`/api/directory_structure?id=${docId}`,
|
`/api/directory_structure?id=${docId}`,
|
||||||
MANAGE_SOURCE_FILES: '/api/manage_source_files',
|
MANAGE_SOURCE_FILES: '/api/manage_source_files',
|
||||||
MCP_TEST_CONNECTION: '/api/mcp_server/test',
|
|
||||||
MCP_SAVE_SERVER: '/api/mcp_server/save',
|
|
||||||
MCP_OAUTH_STATUS: (task_id: string) =>
|
|
||||||
`/api/mcp_server/oauth_status/${task_id}`,
|
|
||||||
},
|
},
|
||||||
CONVERSATION: {
|
CONVERSATION: {
|
||||||
ANSWER: '/api/answer',
|
ANSWER: '/api/answer',
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { getSessionToken } from '../../utils/providerUtils';
|
|
||||||
import apiClient from '../client';
|
import apiClient from '../client';
|
||||||
import endpoints from '../endpoints';
|
import endpoints from '../endpoints';
|
||||||
|
import { getSessionToken } from '../../utils/providerUtils';
|
||||||
|
|
||||||
const userService = {
|
const userService = {
|
||||||
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
|
getConfig: (): Promise<any> => apiClient.get(endpoints.USER.CONFIG, null),
|
||||||
@@ -90,10 +90,7 @@ const userService = {
|
|||||||
path?: string,
|
path?: string,
|
||||||
search?: string,
|
search?: string,
|
||||||
): Promise<any> =>
|
): Promise<any> =>
|
||||||
apiClient.get(
|
apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search), token),
|
||||||
endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search),
|
|
||||||
token,
|
|
||||||
),
|
|
||||||
addChunk: (data: any, token: string | null): Promise<any> =>
|
addChunk: (data: any, token: string | null): Promise<any> =>
|
||||||
apiClient.post(endpoints.USER.ADD_CHUNK, data, token),
|
apiClient.post(endpoints.USER.ADD_CHUNK, data, token),
|
||||||
deleteChunk: (
|
deleteChunk: (
|
||||||
@@ -108,26 +105,16 @@ const userService = {
|
|||||||
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
|
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
|
||||||
manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
|
manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
|
||||||
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
|
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
|
||||||
testMCPConnection: (data: any, token: string | null): Promise<any> =>
|
syncConnector: (docId: string, provider: string, token: string | null): Promise<any> => {
|
||||||
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
|
|
||||||
saveMCPServer: (data: any, token: string | null): Promise<any> =>
|
|
||||||
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
|
|
||||||
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
|
|
||||||
apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
|
|
||||||
syncConnector: (
|
|
||||||
docId: string,
|
|
||||||
provider: string,
|
|
||||||
token: string | null,
|
|
||||||
): Promise<any> => {
|
|
||||||
const sessionToken = getSessionToken(provider);
|
const sessionToken = getSessionToken(provider);
|
||||||
return apiClient.post(
|
return apiClient.post(
|
||||||
endpoints.USER.SYNC_CONNECTOR,
|
endpoints.USER.SYNC_CONNECTOR,
|
||||||
{
|
{
|
||||||
source_id: docId,
|
source_id: docId,
|
||||||
session_token: sessionToken,
|
session_token: sessionToken,
|
||||||
provider: provider,
|
provider: provider
|
||||||
},
|
},
|
||||||
token,
|
token
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
<svg width="16px" height="16px" viewBox="0 0 1024 1024" class="icon" version="1.1" xmlns="http://www.w3.org/2000/svg" fill="#11ee1c" stroke="#11ee1c" stroke-width="83.96799999999999"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path d="M866.133333 258.133333L362.666667 761.6l-204.8-204.8L98.133333 618.666667 362.666667 881.066667l563.2-563.2z" fill="#0C9D35"></path></g></svg>
|
<svg width="16px" height="16px" viewBox="0 0 1024 1024" class="icon" version="1.1" xmlns="http://www.w3.org/2000/svg" fill="#11ee1c" stroke="#11ee1c" stroke-width="83.96799999999999"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path d="M866.133333 258.133333L362.666667 761.6l-204.8-204.8L98.133333 618.666667 362.666667 881.066667l563.2-563.2z" fill="#11ee1c"></path></g></svg>
|
||||||
|
Before Width: | Height: | Size: 490 B After Width: | Height: | Size: 490 B |
@@ -1,3 +0,0 @@
|
|||||||
<svg width="22" height="22" viewBox="0 0 22 22" fill="none" xmlns="http://www.w3.org/2000/svg">
|
|
||||||
<path d="M20.2891 15.81L21.7091 14.39L18.4991 11.21L15.4991 10.36L17.4091 10.1L21.5991 6.89999L20.3991 5.29998L16.5891 8.14999L13.9091 8.59999L17.1091 5.40999L15.9991 0.859985L13.9991 1.33999L14.8591 4.78999L13.7591 5.92999C13.5285 5.38882 13.144 4.92736 12.6533 4.60302C12.1625 4.27867 11.5873 4.10574 10.9991 4.10574C10.4108 4.10574 9.83559 4.27867 9.34487 4.60302C8.85414 4.92736 8.4696 5.38882 8.23906 5.92999L7.10906 4.78999L7.99906 1.33999L5.99906 0.859985L4.88906 5.40999L8.08906 8.59999L5.39906 8.14999L1.59906 5.29998L0.399063 6.89999L4.59906 10.1L6.45906 10.41L3.45906 11.26L0.289062 14.39L1.70906 15.81L4.49906 12.99L6.86906 12.32L2.99906 15.64V21.1H4.99906V16.56L6.55906 15.22C6.73264 16.2723 7.27432 17.2287 8.08751 17.9188C8.90071 18.6088 9.93255 18.9876 10.9991 18.9876C12.0656 18.9876 13.0974 18.6088 13.9106 17.9188C14.7238 17.2287 15.2655 16.2723 15.4391 15.22L16.9991 16.56V21.1H18.9991V15.64L15.1291 12.32L17.4991 12.99L20.2891 15.81Z" fill="black"/>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 1.0 KiB |
@@ -1,3 +0,0 @@
|
|||||||
<svg width="24" height="22" viewBox="0 0 24 22" fill="none" xmlns="http://www.w3.org/2000/svg">
|
|
||||||
<path d="M12.01 0.784912C9.928 0.784912 8.256 0.804912 8.267 0.831912C8.277 0.851912 9.975 3.83291 12.041 7.45191L15.801 14.0259H19.561C21.642 14.0259 23.314 14.0059 23.303 13.9789C23.298 13.9589 21.595 10.9779 19.528 7.35891L15.768 0.784912H12.01ZM7.25 2.51491C6.03029 4.61565 4.82028 6.72201 3.62 8.83391L0 15.1679L1.89 18.4659L3.775 21.7629L7.395 15.4279L11.013 9.09791L9.133 5.81091C8.1 4.00391 7.255 2.52091 7.25 2.51491ZM9.509 15.1679L9.306 15.5159C9.192 15.7139 8.346 17.1879 7.426 18.8029C6.864 19.7952 6.29799 20.7852 5.728 21.7729C5.718 21.7989 8.968 21.8149 12.95 21.8149H20.194L21.99 18.6579C22.982 16.9239 23.84 15.4279 23.896 15.3349L24 15.1679H16.751H9.509Z" fill="black"/>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 792 B |
@@ -1,10 +1,3 @@
|
|||||||
<svg width="24" height="25" viewBox="0 0 24 25" fill="none" xmlns="http://www.w3.org/2000/svg">
|
<svg width="28" height="34" viewBox="0 0 28 34" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
<g clip-path="url(#clip0_9890_21170)">
|
<path d="M10 26.0003H18C19.1 26.0003 20 25.1003 20 24.0003V14.0003H23.18C24.96 14.0003 25.86 11.8403 24.6 10.5803L15.42 1.40032C15.235 1.21491 15.0152 1.06782 14.7732 0.967453C14.5313 0.86709 14.2719 0.81543 14.01 0.81543C13.7481 0.81543 13.4887 0.86709 13.2468 0.967453C13.0048 1.06782 12.785 1.21491 12.6 1.40032L3.42 10.5803C2.16 11.8403 3.04 14.0003 4.82 14.0003H8V24.0003C8 25.1003 8.9 26.0003 10 26.0003ZM2 30.0003H26C27.1 30.0003 28 30.9003 28 32.0003C28 33.1003 27.1 34.0003 26 34.0003H2C0.9 34.0003 0 33.1003 0 32.0003C0 30.9003 0.9 30.0003 2 30.0003Z" fill="#949494"/>
|
||||||
<path d="M12.75 19.1V8.91248L9.5 12.1625L7.75 10.35L14 4.09998L20.25 10.35L18.5 12.1625L15.25 8.91248V19.1H12.75ZM6.5 24.1C5.8125 24.1 5.22417 23.8554 4.735 23.3662C4.24583 22.8771 4.00083 22.2883 4 21.6V17.85H6.5V21.6H21.5V17.85H24V21.6C24 22.2875 23.7554 22.8762 23.2663 23.3662C22.7771 23.8562 22.1883 24.1008 21.5 24.1H6.5Z" fill="black"/>
|
</svg>
|
||||||
</g>
|
|
||||||
<defs>
|
|
||||||
<clipPath id="clip0_9890_21170">
|
|
||||||
<rect width="24" height="24" fill="white" transform="translate(0 0.0999756)"/>
|
|
||||||
</clipPath>
|
|
||||||
</defs>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 630 B After Width: | Height: | Size: 681 B |
@@ -1,3 +0,0 @@
|
|||||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
|
||||||
<path d="M20.8175 3.09139C20.0845 2.34392 19.2025 1.9707 18.1705 1.9707H5.6835C4.6515 1.9707 3.7695 2.34392 3.0365 3.09139C2.3035 3.83885 1.9375 4.73825 1.9375 5.79061V18.524C1.9375 19.5763 2.3035 20.4757 3.0365 21.2232C3.7695 21.9707 4.6515 22.3439 5.6835 22.3439H8.5975C8.7875 22.3439 8.9305 22.3368 9.0265 22.3235C9.13819 22.3007 9.23901 22.2399 9.3125 22.1512C9.4075 22.0492 9.4555 21.9013 9.4555 21.7076L9.4485 20.8051C9.4445 20.23 9.4425 19.7752 9.4425 19.4387L9.1425 19.4917C8.9525 19.5274 8.7125 19.5427 8.4215 19.5386C8.11819 19.5329 7.81584 19.5019 7.5175 19.4458C7.1999 19.386 6.90093 19.2497 6.6455 19.0481C6.37799 18.8418 6.17847 18.5572 6.0735 18.2323L5.9435 17.9264C5.83393 17.6851 5.69627 17.4581 5.5335 17.2503C5.3475 17.0025 5.1585 16.8353 4.9675 16.7466L4.8775 16.6803C4.81474 16.6345 4.75766 16.5811 4.7075 16.5212C4.65959 16.4657 4.62015 16.4031 4.5905 16.3356C4.5645 16.2734 4.5865 16.2224 4.6555 16.1827C4.7255 16.1419 4.8505 16.1225 5.0335 16.1225L5.2935 16.1633C5.4665 16.198 5.6815 16.304 5.9365 16.4804C6.19456 16.6598 6.41013 16.8957 6.5675 17.1708C6.7675 17.5328 7.0075 17.8091 7.2895 17.9998C7.5715 18.1895 7.8555 18.2854 8.1415 18.2854C8.4275 18.2854 8.6745 18.2629 8.8835 18.2191C9.08561 18.1765 9.28201 18.1094 9.4685 18.0192C9.5465 17.4278 9.7585 16.9709 10.1055 16.6528C9.65588 16.6078 9.21026 16.5281 8.7725 16.4142C8.34529 16.2945 7.93444 16.1208 7.5495 15.8972C7.14675 15.6736 6.79101 15.3714 6.5025 15.008C6.2255 14.6541 5.9975 14.1901 5.8195 13.616C5.6425 13.0409 5.5535 12.377 5.5535 11.6255C5.5535 10.5558 5.8955 9.64519 6.5805 8.89263C6.2605 8.08908 6.2905 7.18662 6.6715 6.18831C6.9235 6.10775 7.2965 6.16791 7.7905 6.36676C8.2845 6.56561 8.6465 6.7359 8.8765 6.87662C9.1065 7.01939 9.2905 7.13869 9.4295 7.23557C10.2425 7.00486 11.0826 6.88889 11.9265 6.8909C12.7855 6.8909 13.6175 7.00613 14.4245 7.23557L14.9185 6.91741C15.2985 6.68476 15.6993 6.48946 16.1155 6.33413C16.5755 6.15669 16.9255 6.10877 17.1695 6.18831C17.5595 7.18764 17.5935 8.08908 
17.2725 8.89365C17.9575 9.64519 18.3005 10.5558 18.3005 11.6265C18.3005 12.3781 18.2115 13.044 18.0335 13.6221C17.8565 14.2013 17.6265 14.6653 17.3445 15.0151C17.0509 15.3739 16.6937 15.6731 16.2915 15.8972C15.8715 16.1358 15.4635 16.3081 15.0685 16.4142C14.6308 16.5284 14.1852 16.6085 13.7355 16.6538C14.1855 17.0515 14.4115 17.6786 14.4115 18.5362V21.7076C14.4115 21.8575 14.4325 21.9788 14.4765 22.0716C14.4967 22.1163 14.5256 22.1564 14.5613 22.1895C14.597 22.2226 14.6389 22.2481 14.6845 22.2643C14.7805 22.299 14.8645 22.3215 14.9385 22.3296C15.0125 22.3398 15.1185 22.3429 15.2565 22.3429H18.1705C19.2025 22.3429 20.0845 21.9696 20.8175 21.2222C21.5495 20.4757 21.9165 19.5753 21.9165 18.523V5.79061C21.9165 4.73825 21.5505 3.83885 20.8175 3.09139Z" fill="#747474"/>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 2.8 KiB |
@@ -1,3 +1,5 @@
|
|||||||
<svg width="20" height="20" viewBox="0 0 20 20" fill="none" xmlns="http://www.w3.org/2000/svg">
|
<svg width="800px" height="800px" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
|
||||||
<path d="M10 0.299927C8.68678 0.299927 7.38642 0.558584 6.17317 1.06113C4.95991 1.56368 3.85752 2.30027 2.92893 3.22886C1.05357 5.10422 0 7.64776 0 10.2999C0 14.7199 2.87 18.4699 6.84 19.7999C7.34 19.8799 7.5 19.5699 7.5 19.2999V17.6099C4.73 18.2099 4.14 16.2699 4.14 16.2699C3.68 15.1099 3.03 14.7999 3.03 14.7999C2.12 14.1799 3.1 14.1999 3.1 14.1999C4.1 14.2699 4.63 15.2299 4.63 15.2299C5.5 16.7499 6.97 16.2999 7.54 16.0599C7.63 15.4099 7.89 14.9699 8.17 14.7199C5.95 14.4699 3.62 13.6099 3.62 9.79993C3.62 8.68993 4 7.79993 4.65 7.08993C4.55 6.83993 4.2 5.79993 4.75 4.44993C4.75 4.44993 5.59 4.17993 7.5 5.46993C8.29 5.24993 9.15 5.13993 10 5.13993C10.85 5.13993 11.71 5.24993 12.5 5.46993C14.41 4.17993 15.25 4.44993 15.25 4.44993C15.8 5.79993 15.45 6.83993 15.35 7.08993C16 7.79993 16.38 8.68993 16.38 9.79993C16.38 13.6199 14.04 14.4599 11.81 14.7099C12.17 15.0199 12.5 15.6299 12.5 16.5599V19.2999C12.5 19.5699 12.66 19.8899 13.17 19.7999C17.14 18.4599 20 14.7199 20 10.2999C20 8.98671 19.7413 7.68635 19.2388 6.47309C18.7362 5.25984 17.9997 4.15744 17.0711 3.22886C16.1425 2.30027 15.0401 1.56368 13.8268 1.06113C12.6136 0.558584 11.3132 0.299927 10 0.299927Z" fill="black"/>
|
<title>github</title>
|
||||||
|
<rect width="24" height="24" fill="none"/>
|
||||||
|
<path d="M12,2A10,10,0,0,0,8.84,21.5c.5.08.66-.23.66-.5V19.31C6.73,19.91,6.14,18,6.14,18A2.69,2.69,0,0,0,5,16.5c-.91-.62.07-.6.07-.6a2.1,2.1,0,0,1,1.53,1,2.15,2.15,0,0,0,2.91.83,2.16,2.16,0,0,1,.63-1.34C8,16.17,5.62,15.31,5.62,11.5a3.87,3.87,0,0,1,1-2.71,3.58,3.58,0,0,1,.1-2.64s.84-.27,2.75,1a9.63,9.63,0,0,1,5,0c1.91-1.29,2.75-1,2.75-1a3.58,3.58,0,0,1,.1,2.64,3.87,3.87,0,0,1,1,2.71c0,3.82-2.34,4.66-4.57,4.91a2.39,2.39,0,0,1,.69,1.85V21c0,.27.16.59.67.5A10,10,0,0,0,12,2Z" fill="black" fill-opacity="0.54"/>
|
||||||
</svg>
|
</svg>
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 1.3 KiB After Width: | Height: | Size: 679 B |
@@ -1,4 +0,0 @@
|
|||||||
<svg width="24" height="25" viewBox="0 0 24 25" fill="none" xmlns="http://www.w3.org/2000/svg">
|
|
||||||
<path d="M10.7519 13.3399C10.7519 12.7699 10.2819 12.2999 9.71187 12.2999C9.14187 12.2999 8.67188 12.7699 8.67188 13.3399C8.67188 13.6158 8.78145 13.8803 8.97648 14.0753C9.17152 14.2704 9.43605 14.3799 9.71187 14.3799C9.9877 14.3799 10.2522 14.2704 10.4473 14.0753C10.6423 13.8803 10.7519 13.6158 10.7519 13.3399ZM14.0919 15.7099C13.6419 16.1599 12.6819 16.3199 12.0019 16.3199C11.3219 16.3199 10.3619 16.1599 9.91187 15.7099C9.88755 15.6839 9.85813 15.6631 9.82545 15.6489C9.79276 15.6347 9.75751 15.6274 9.72188 15.6274C9.68624 15.6274 9.65099 15.6347 9.6183 15.6489C9.58562 15.6631 9.5562 15.6839 9.53187 15.7099C9.50583 15.7343 9.48507 15.7637 9.47088 15.7964C9.45668 15.829 9.44936 15.8643 9.44936 15.8999C9.44936 15.9356 9.45668 15.9708 9.47088 16.0035C9.48507 16.0362 9.50583 16.0656 9.53187 16.0899C10.2419 16.7999 11.6019 16.8599 12.0019 16.8599C12.4019 16.8599 13.7619 16.7999 14.4719 16.0899C14.4979 16.0656 14.5187 16.0362 14.5329 16.0035C14.5471 15.9708 14.5544 15.9356 14.5544 15.8999C14.5544 15.8643 14.5471 15.829 14.5329 15.7964C14.5187 15.7637 14.4979 15.7343 14.4719 15.7099C14.3719 15.6099 14.2019 15.6099 14.0919 15.7099ZM14.2919 12.2999C13.7219 12.2999 13.2519 12.7699 13.2519 13.3399C13.2519 13.9099 13.7219 14.3799 14.2919 14.3799C14.8619 14.3799 15.3319 13.9099 15.3319 13.3399C15.3319 12.7699 14.8719 12.2999 14.2919 12.2999Z" fill="black"/>
|
|
||||||
<path d="M12 2.29993C6.48 2.29993 2 6.77993 2 12.2999C2 17.8199 6.48 22.2999 12 22.2999C17.52 22.2999 22 17.8199 22 12.2999C22 6.77993 17.52 2.29993 12 2.29993ZM17.8 13.6299C17.82 13.7699 17.83 13.9199 17.83 14.0699C17.83 16.3099 15.22 18.1299 12 18.1299C8.78 18.1299 6.17 16.3099 6.17 14.0699C6.17 13.9199 6.18 13.7699 6.2 13.6299C5.69 13.3999 5.34 12.8899 5.34 12.2999C5.33852 12.0132 5.4218 11.7324 5.57939 11.4928C5.73698 11.2532 5.96185 11.0656 6.22576 10.9534C6.48966 10.8412 6.78083 10.8095 7.06269 10.8622C7.34456 10.915 7.60454 11.0499 7.81 11.2499C8.82 10.5199 10.22 10.0599 11.77 10.0099L12.51 6.51993C12.52 6.44993 12.56 6.38993 12.62 6.35993C12.68 6.31993 12.75 6.30993 12.82 6.31993L15.24 6.83993C15.3221 6.67351 15.4472 6.53207 15.6023 6.4303C15.7575 6.32853 15.9371 6.27013 16.1224 6.26115C16.3077 6.25217 16.4921 6.29294 16.6564 6.37924C16.8207 6.46553 16.9589 6.59421 17.0566 6.75191C17.1544 6.90962 17.2082 7.09062 17.2125 7.27613C17.2167 7.46164 17.1712 7.64491 17.0808 7.80692C16.9903 7.96894 16.8582 8.1038 16.698 8.19753C16.5379 8.29125 16.3556 8.34042 16.17 8.33993C15.61 8.33993 15.16 7.89993 15.13 7.34993L12.96 6.88993L12.3 10.0099C13.83 10.0599 15.2 10.5299 16.2 11.2499C16.3533 11.1035 16.5367 10.9924 16.7375 10.9243C16.9382 10.8562 17.1514 10.8328 17.3621 10.8557C17.5728 10.8787 17.776 10.9473 17.9574 11.057C18.1388 11.1666 18.2941 11.3145 18.4123 11.4905C18.5306 11.6664 18.609 11.866 18.642 12.0754C18.6751 12.2847 18.662 12.4988 18.6037 12.7026C18.5454 12.9064 18.4432 13.0949 18.3044 13.2551C18.1656 13.4153 17.9934 13.5432 17.8 13.6299Z" fill="black"/>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 3.0 KiB |
@@ -1 +1 @@
|
|||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24"><path fill="black" d="M10.72,19.9a8,8,0,0,1-6.5-9.79A7.77,7.77,0,0,1,10.4,4.16a8,8,0,0,1,9.49,6.52A1.54,1.54,0,0,0,21.38,12h.13a1.37,1.37,0,0,0,1.38-1.54,11,11,0,1,0-12.7,12.39A1.54,1.54,0,0,0,12,21.34h0A1.47,1.47,0,0,0,10.72,19.9Z"><animateTransform attributeName="transform" dur="0.75s" repeatCount="indefinite" type="rotate" values="0 12 12;360 12 12"/></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24"><path fill="white" d="M10.72,19.9a8,8,0,0,1-6.5-9.79A7.77,7.77,0,0,1,10.4,4.16a8,8,0,0,1,9.49,6.52A1.54,1.54,0,0,0,21.38,12h.13a1.37,1.37,0,0,0,1.38-1.54,11,11,0,1,0-12.7,12.39A1.54,1.54,0,0,0,12,21.34h0A1.47,1.47,0,0,0,10.72,19.9Z"><animateTransform attributeName="transform" dur="0.75s" repeatCount="indefinite" type="rotate" values="0 12 12;360 12 12"/></path></svg>
|
||||||
|
Before Width: | Height: | Size: 454 B After Width: | Height: | Size: 454 B |
@@ -1 +1 @@
|
|||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24"><path fill="white" d="M10.72,19.9a8,8,0,0,1-6.5-9.79A7.77,7.77,0,0,1,10.4,4.16a8,8,0,0,1,9.49,6.52A1.54,1.54,0,0,0,21.38,12h.13a1.37,1.37,0,0,0,1.38-1.54,11,11,0,1,0-12.7,12.39A1.54,1.54,0,0,0,12,21.34h0A1.47,1.47,0,0,0,10.72,19.9Z"><animateTransform attributeName="transform" dur="0.75s" repeatCount="indefinite" type="rotate" values="0 12 12;360 12 12"/></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24"><path fill="black" d="M10.72,19.9a8,8,0,0,1-6.5-9.79A7.77,7.77,0,0,1,10.4,4.16a8,8,0,0,1,9.49,6.52A1.54,1.54,0,0,0,21.38,12h.13a1.37,1.37,0,0,0,1.38-1.54,11,11,0,1,0-12.7,12.39A1.54,1.54,0,0,0,12,21.34h0A1.47,1.47,0,0,0,10.72,19.9Z"><animateTransform attributeName="transform" dur="0.75s" repeatCount="indefinite" type="rotate" values="0 12 12;360 12 12"/></path></svg>
|
||||||
|
Before Width: | Height: | Size: 454 B After Width: | Height: | Size: 454 B |
@@ -1,3 +0,0 @@
|
|||||||
<svg width="22" height="23" viewBox="0 0 22 23" fill="none" xmlns="http://www.w3.org/2000/svg">
|
|
||||||
<path d="M16.304 8.62401L19.486 11.806C20.0023 12.3155 20.4128 12.922 20.6937 13.5908C20.9747 14.2596 21.1206 14.9773 21.123 15.7026C21.1254 16.428 20.9843 17.1467 20.7079 17.8173C20.4314 18.4879 20.025 19.0972 19.5121 19.6102C18.9992 20.1231 18.3899 20.5295 17.7193 20.8059C17.0486 21.0824 16.33 21.2235 15.6046 21.221C14.8792 21.2186 14.1615 21.0727 13.4928 20.7918C12.824 20.5108 12.2174 20.1003 11.708 19.584L10.648 18.524C10.5046 18.3857 10.3903 18.2202 10.3116 18.0373C10.2329 17.8543 10.1914 17.6575 10.1896 17.4583C10.1878 17.2592 10.2256 17.0616 10.301 16.8772C10.3763 16.6929 10.4876 16.5253 10.6284 16.3844C10.7691 16.2435 10.9366 16.1321 11.1209 16.0566C11.3052 15.9811 11.5027 15.943 11.7019 15.9446C11.901 15.9463 12.0979 15.9876 12.2809 16.0661C12.464 16.1446 12.6295 16.2588 12.768 16.402L13.83 17.463C14.2996 17.9284 14.9344 18.1888 15.5955 18.1873C16.2567 18.1857 16.8903 17.9224 17.3577 17.4548C17.8252 16.9872 18.0884 16.3535 18.0897 15.6924C18.0911 15.0312 17.8305 14.3965 17.365 13.927L14.183 10.745C13.839 10.4009 13.4022 10.1647 12.926 10.0652C12.4498 9.96574 11.9549 10.0074 11.502 10.185C11.3406 10.249 11.1893 10.3143 11.048 10.381L10.584 10.598C9.96396 10.878 9.48696 10.998 8.87996 10.392C8.00796 9.52001 8.23396 8.71501 9.29696 7.98201C10.3559 7.25337 11.6365 6.91862 12.9166 7.0359C14.1966 7.15318 15.3951 7.71508 16.304 8.62401ZM10.294 2.61401L11.354 3.67401C11.6273 3.95678 11.7787 4.33562 11.7755 4.72891C11.7722 5.12221 11.6147 5.49851 11.3367 5.77675C11.0587 6.055 10.6826 6.21293 10.2893 6.21653C9.89597 6.22013 9.517 6.06912 9.23396 5.79601L8.17296 4.73601C7.94241 4.49717 7.66661 4.30664 7.36163 4.17553C7.05666 4.04442 6.72863 3.97536 6.39668 3.97239C6.06474 3.96941 5.73552 4.03257 5.42824 4.15818C5.12097 4.2838 4.84179 4.46935 4.60699 4.70402C4.37219 4.93868 4.18648 5.21776 4.06069 5.52496C3.9349 5.83217 3.87155 6.16135 3.87434 6.4933C3.87713 6.82525 3.94601 7.15332 4.07694 7.45836C4.20788 7.76341 4.39825 8.03933 4.63696 8.27001L7.81896 11.452C8.16289 
11.7961 8.59974 12.0323 9.07595 12.1318C9.55217 12.2313 10.0471 12.1896 10.5 12.012C10.6613 11.948 10.8126 11.8827 10.954 11.816L11.418 11.599C12.038 11.319 12.516 11.199 13.122 11.805C13.994 12.677 13.768 13.482 12.705 14.215C11.6461 14.9437 10.3654 15.2784 9.08537 15.1611C7.80535 15.0438 6.60683 14.4819 5.69796 13.573L2.51596 10.391C1.99962 9.88154 1.58916 9.27497 1.30821 8.60622C1.02726 7.93747 0.881367 7.21975 0.878937 6.49438C0.876507 5.76901 1.01759 5.05033 1.29405 4.37971C1.57052 3.70909 1.9769 3.09978 2.48982 2.58686C3.00273 2.07395 3.61204 1.66756 4.28266 1.3911C4.95328 1.11463 5.67196 0.973553 6.39733 0.975983C7.1227 0.978413 7.84042 1.1243 8.50917 1.40526C9.17793 1.68621 9.7845 2.09767 10.294 2.61401Z" fill="black"/>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 2.8 KiB |
@@ -1,6 +1,5 @@
|
|||||||
import React, { useRef } from 'react';
|
import React, { useRef } from 'react';
|
||||||
import { useSelector } from 'react-redux';
|
import { useSelector } from 'react-redux';
|
||||||
import { useDarkTheme } from '../hooks';
|
|
||||||
import { selectToken } from '../preferences/preferenceSlice';
|
import { selectToken } from '../preferences/preferenceSlice';
|
||||||
|
|
||||||
interface ConnectorAuthProps {
|
interface ConnectorAuthProps {
|
||||||
@@ -8,24 +7,17 @@ interface ConnectorAuthProps {
|
|||||||
onSuccess: (data: { session_token: string; user_email: string }) => void;
|
onSuccess: (data: { session_token: string; user_email: string }) => void;
|
||||||
onError: (error: string) => void;
|
onError: (error: string) => void;
|
||||||
label?: string;
|
label?: string;
|
||||||
isConnected?: boolean;
|
|
||||||
userEmail?: string;
|
|
||||||
onDisconnect?: () => void;
|
|
||||||
errorMessage?: string;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
const providerLabel = (provider: string) => {
|
||||||
provider,
|
const map: Record<string, string> = {
|
||||||
onSuccess,
|
google_drive: 'Google Drive',
|
||||||
onError,
|
};
|
||||||
label,
|
return map[provider] || provider.replace(/_/g, ' ');
|
||||||
isConnected = false,
|
};
|
||||||
userEmail = '',
|
|
||||||
onDisconnect,
|
const ConnectorAuth: React.FC<ConnectorAuthProps> = ({ provider, onSuccess, onError, label }) => {
|
||||||
errorMessage,
|
|
||||||
}) => {
|
|
||||||
const token = useSelector(selectToken);
|
const token = useSelector(selectToken);
|
||||||
const [isDarkTheme] = useDarkTheme();
|
|
||||||
const completedRef = useRef(false);
|
const completedRef = useRef(false);
|
||||||
const intervalRef = useRef<number | null>(null);
|
const intervalRef = useRef<number | null>(null);
|
||||||
|
|
||||||
@@ -39,8 +31,8 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
|||||||
|
|
||||||
const handleAuthMessage = (event: MessageEvent) => {
|
const handleAuthMessage = (event: MessageEvent) => {
|
||||||
const successGeneric = event.data?.type === 'connector_auth_success';
|
const successGeneric = event.data?.type === 'connector_auth_success';
|
||||||
const successProvider = event.data?.type === `${provider}_auth_success`;
|
const successProvider = event.data?.type === `${provider}_auth_success` || event.data?.type === 'google_drive_auth_success';
|
||||||
const errorProvider = event.data?.type === `${provider}_auth_error`;
|
const errorProvider = event.data?.type === `${provider}_auth_error` || event.data?.type === 'google_drive_auth_error';
|
||||||
|
|
||||||
if (successGeneric || successProvider) {
|
if (successGeneric || successProvider) {
|
||||||
completedRef.current = true;
|
completedRef.current = true;
|
||||||
@@ -62,17 +54,12 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
|||||||
cleanup();
|
cleanup();
|
||||||
|
|
||||||
const apiHost = import.meta.env.VITE_API_HOST;
|
const apiHost = import.meta.env.VITE_API_HOST;
|
||||||
const authResponse = await fetch(
|
const authResponse = await fetch(`${apiHost}/api/connectors/auth?provider=${provider}`, {
|
||||||
`${apiHost}/api/connectors/auth?provider=${provider}`,
|
headers: { Authorization: `Bearer ${token}` },
|
||||||
{
|
});
|
||||||
headers: { Authorization: `Bearer ${token}` },
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!authResponse.ok) {
|
if (!authResponse.ok) {
|
||||||
throw new Error(
|
throw new Error(`Failed to get authorization URL: ${authResponse.status}`);
|
||||||
`Failed to get authorization URL: ${authResponse.status}`,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const authData = await authResponse.json();
|
const authData = await authResponse.json();
|
||||||
@@ -83,12 +70,10 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
|||||||
const authWindow = window.open(
|
const authWindow = window.open(
|
||||||
authData.authorization_url,
|
authData.authorization_url,
|
||||||
`${provider}-auth`,
|
`${provider}-auth`,
|
||||||
'width=500,height=600,scrollbars=yes,resizable=yes',
|
'width=500,height=600,scrollbars=yes,resizable=yes'
|
||||||
);
|
);
|
||||||
if (!authWindow) {
|
if (!authWindow) {
|
||||||
throw new Error(
|
throw new Error('Failed to open authentication window. Please allow popups.');
|
||||||
'Failed to open authentication window. Please allow popups.',
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
window.addEventListener('message', handleAuthMessage as any);
|
window.addEventListener('message', handleAuthMessage as any);
|
||||||
@@ -108,58 +93,20 @@ const ConnectorAuth: React.FC<ConnectorAuthProps> = ({
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buttonLabel = label || `Connect ${providerLabel(provider)}`;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<button
|
||||||
{errorMessage && (
|
onClick={handleAuth}
|
||||||
<div className="mb-4 flex items-center gap-2 rounded-lg border border-[#E60000] dark:border-[#D42626] bg-transparent dark:bg-[#D426261A] p-2">
|
className="w-full flex items-center justify-center gap-2 rounded-lg bg-blue-500 px-4 py-3 text-white hover:bg-blue-600 transition-colors"
|
||||||
<svg width="30" height="30" viewBox="0 0 30 30" fill="none" xmlns="http://www.w3.org/2000/svg">
|
>
|
||||||
<path d="M7.09974 24.5422H22.9C24.5156 24.5422 25.5228 22.7901 24.715 21.3947L16.8149 7.74526C16.007 6.34989 13.9927 6.34989 13.1848 7.74526L5.28471 21.3947C4.47686 22.7901 5.48405 24.5422 7.09974 24.5422ZM14.9998 17.1981C14.4228 17.1981 13.9507 16.726 13.9507 16.149V14.0507C13.9507 13.4736 14.4228 13.0015 14.9998 13.0015C15.5769 13.0015 16.049 13.4736 16.049 14.0507V16.149C16.049 16.726 15.5769 17.1981 14.9998 17.1981ZM16.049 21.3947H13.9507V19.2964H16.049V21.3947Z" fill={isDarkTheme ? '#EECF56' : '#E60000'} />
|
<svg className="h-5 w-5" viewBox="0 0 24 24">
|
||||||
</svg>
|
<path fill="currentColor" d="M6.28 3l5.72 10H24l-5.72-10H6.28zm11.44 0L12 13l5.72 10H24L18.28 3h-.56zM0 13l5.72 10h5.72L5.72 13H0z"/>
|
||||||
|
</svg>
|
||||||
<span className='text-[#E60000] dark:text-[#E37064] text-sm' style={{
|
{buttonLabel}
|
||||||
fontFamily: 'Inter',
|
</button>
|
||||||
lineHeight: '100%'
|
|
||||||
}}>
|
|
||||||
{errorMessage}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{isConnected ? (
|
|
||||||
<div className="mb-4">
|
|
||||||
<div className="w-full flex items-center justify-between rounded-[10px] bg-[#8FDD51] px-4 py-2 text-[#212121] font-medium text-sm">
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<svg className="h-4 w-4" viewBox="0 0 24 24">
|
|
||||||
<path fill="currentColor" d="M9 16.17L4.83 12l-1.42 1.41L9 19 21 7l-1.41-1.41z" />
|
|
||||||
</svg>
|
|
||||||
<span>Connected as {userEmail}</span>
|
|
||||||
</div>
|
|
||||||
{onDisconnect && (
|
|
||||||
<button
|
|
||||||
onClick={onDisconnect}
|
|
||||||
className="text-[#212121] hover:text-gray-700 font-medium text-xs underline"
|
|
||||||
>
|
|
||||||
Disconnect
|
|
||||||
</button>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<button
|
|
||||||
onClick={handleAuth}
|
|
||||||
className="flex w-full items-center justify-center gap-2 rounded-lg bg-blue-500 px-4 py-3 text-white transition-colors hover:bg-blue-600"
|
|
||||||
>
|
|
||||||
<svg className="h-5 w-5" viewBox="0 0 24 24">
|
|
||||||
<path
|
|
||||||
fill="currentColor"
|
|
||||||
d="M6.28 3l5.72 10H24l-5.72-10H6.28zm11.44 0L12 13l5.72 10H24L18.28 3h-.56zM0 13l5.72 10h5.72L5.72 13H0z"
|
|
||||||
/>
|
|
||||||
</svg>
|
|
||||||
{label}
|
|
||||||
</button>
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default ConnectorAuth;
|
export default ConnectorAuth;
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,8 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import { useSelector } from 'react-redux';
|
import { useSelector } from 'react-redux';
|
||||||
import { formatBytes } from '../utils/stringUtils';
|
import { formatBytes } from '../utils/stringUtils';
|
||||||
import { selectToken } from '../preferences/preferenceSlice';
|
import { selectToken } from '../preferences/preferenceSlice';
|
||||||
import { ActiveState } from '../models/misc';
|
|
||||||
import Chunks from './Chunks';
|
import Chunks from './Chunks';
|
||||||
import ContextMenu, { MenuOption } from './ContextMenu';
|
import ContextMenu, { MenuOption } from './ContextMenu';
|
||||||
import ConfirmationModal from '../modals/ConfirmationModal';
|
|
||||||
import userService from '../api/services/userService';
|
import userService from '../api/services/userService';
|
||||||
import FileIcon from '../assets/file.svg';
|
import FileIcon from '../assets/file.svg';
|
||||||
import FolderIcon from '../assets/folder.svg';
|
import FolderIcon from '../assets/folder.svg';
|
||||||
@@ -14,17 +12,7 @@ import ArrowLeft from '../assets/arrow-left.svg';
|
|||||||
import ThreeDots from '../assets/three-dots.svg';
|
import ThreeDots from '../assets/three-dots.svg';
|
||||||
import EyeView from '../assets/eye-view.svg';
|
import EyeView from '../assets/eye-view.svg';
|
||||||
import SyncIcon from '../assets/sync.svg';
|
import SyncIcon from '../assets/sync.svg';
|
||||||
import CheckmarkIcon from '../assets/checkMark2.svg';
|
|
||||||
import { useOutsideAlerter } from '../hooks';
|
import { useOutsideAlerter } from '../hooks';
|
||||||
import {
|
|
||||||
Table,
|
|
||||||
TableContainer,
|
|
||||||
TableHead,
|
|
||||||
TableBody,
|
|
||||||
TableRow,
|
|
||||||
TableHeader,
|
|
||||||
TableCell,
|
|
||||||
} from './Table';
|
|
||||||
|
|
||||||
interface FileNode {
|
interface FileNode {
|
||||||
type?: string;
|
type?: string;
|
||||||
@@ -76,7 +64,6 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
|
|||||||
const [syncProgress, setSyncProgress] = useState<number>(0);
|
const [syncProgress, setSyncProgress] = useState<number>(0);
|
||||||
const [sourceProvider, setSourceProvider] = useState<string>('');
|
const [sourceProvider, setSourceProvider] = useState<string>('');
|
||||||
const [syncDone, setSyncDone] = useState<boolean>(false);
|
const [syncDone, setSyncDone] = useState<boolean>(false);
|
||||||
const [syncConfirmationModal, setSyncConfirmationModal] = useState<ActiveState>('INACTIVE');
|
|
||||||
|
|
||||||
useOutsideAlerter(
|
useOutsideAlerter(
|
||||||
searchDropdownRef,
|
searchDropdownRef,
|
||||||
@@ -240,6 +227,8 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
|
|||||||
return current;
|
return current;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const getMenuRef = (id: string) => {
|
const getMenuRef = (id: string) => {
|
||||||
if (!menuRefs.current[id]) {
|
if (!menuRefs.current[id]) {
|
||||||
menuRefs.current[id] = React.createRef();
|
menuRefs.current[id] = React.createRef();
|
||||||
@@ -356,7 +345,7 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
|
|||||||
|
|
||||||
{/* Sync button */}
|
{/* Sync button */}
|
||||||
<button
|
<button
|
||||||
onClick={() => setSyncConfirmationModal('ACTIVE')}
|
onClick={handleSync}
|
||||||
disabled={isSyncing}
|
disabled={isSyncing}
|
||||||
className={`flex h-[38px] min-w-[108px] items-center justify-center rounded-full px-4 text-[14px] font-medium whitespace-nowrap transition-colors ${
|
className={`flex h-[38px] min-w-[108px] items-center justify-center rounded-full px-4 text-[14px] font-medium whitespace-nowrap transition-colors ${
|
||||||
isSyncing
|
isSyncing
|
||||||
@@ -372,7 +361,7 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
|
|||||||
}
|
}
|
||||||
>
|
>
|
||||||
<img
|
<img
|
||||||
src={syncDone ? CheckmarkIcon : SyncIcon}
|
src={SyncIcon}
|
||||||
alt={t('settings.sources.sync')}
|
alt={t('settings.sources.sync')}
|
||||||
className={`mr-2 h-4 w-4 brightness-0 invert filter ${isSyncing ? 'animate-spin' : ''}`}
|
className={`mr-2 h-4 w-4 brightness-0 invert filter ${isSyncing ? 'animate-spin' : ''}`}
|
||||||
/>
|
/>
|
||||||
@@ -387,36 +376,39 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const renderFileTree = (directory: DirectoryStructure): React.ReactNode[] => {
|
const renderFileTree = (directory: DirectoryStructure) => {
|
||||||
|
if (!directory) return [];
|
||||||
|
|
||||||
// Create parent directory row
|
// Create parent directory row
|
||||||
const parentRow =
|
const parentRow =
|
||||||
currentPath.length > 0
|
currentPath.length > 0
|
||||||
? [
|
? [
|
||||||
<TableRow
|
<tr
|
||||||
key="parent-dir"
|
key="parent-dir"
|
||||||
onClick={navigateUp}
|
className="cursor-pointer border-b border-[#D1D9E0] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]"
|
||||||
>
|
onClick={navigateUp}
|
||||||
<TableCell width="40%" align="left">
|
>
|
||||||
<div className="flex items-center">
|
<td className="px-2 py-2 lg:px-4">
|
||||||
<img
|
<div className="flex items-center">
|
||||||
src={FolderIcon}
|
<img
|
||||||
alt={t('settings.sources.parentFolderAlt')}
|
src={FolderIcon}
|
||||||
className="mr-2 h-4 w-4 flex-shrink-0"
|
alt={t('settings.sources.parentFolderAlt')}
|
||||||
/>
|
className="mr-2 h-4 w-4 flex-shrink-0"
|
||||||
<span className="truncate">
|
/>
|
||||||
..
|
<span className="truncate text-sm dark:text-[#E0E0E0]">
|
||||||
</span>
|
..
|
||||||
</div>
|
</span>
|
||||||
</TableCell>
|
</div>
|
||||||
<TableCell width="30%" align="left">
|
</td>
|
||||||
-
|
<td className="px-2 py-2 text-sm lg:px-4 dark:text-[#E0E0E0]">
|
||||||
</TableCell>
|
-
|
||||||
<TableCell width="20%" align="left">
|
</td>
|
||||||
-
|
<td className="px-2 py-2 text-sm lg:px-4 dark:text-[#E0E0E0]">
|
||||||
</TableCell>
|
-
|
||||||
<TableCell width="10%" align="right"></TableCell>
|
</td>
|
||||||
</TableRow>,
|
<td className="w-10 px-2 py-2 text-sm lg:px-4"></td>
|
||||||
]
|
</tr>,
|
||||||
|
]
|
||||||
: [];
|
: [];
|
||||||
|
|
||||||
// Sort entries: directories first, then files, both alphabetically
|
// Sort entries: directories first, then files, both alphabetically
|
||||||
@@ -444,35 +436,36 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
|
|||||||
const dirStats = calculateDirectoryStats(node as DirectoryStructure);
|
const dirStats = calculateDirectoryStats(node as DirectoryStructure);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<TableRow
|
<tr
|
||||||
key={itemId}
|
key={itemId}
|
||||||
|
className="cursor-pointer border-b border-[#D1D9E0] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]"
|
||||||
onClick={() => navigateToDirectory(name)}
|
onClick={() => navigateToDirectory(name)}
|
||||||
>
|
>
|
||||||
<TableCell width="40%" align="left">
|
<td className="px-2 py-2 lg:px-4">
|
||||||
<div className="flex min-w-0 items-center">
|
<div className="flex min-w-0 items-center">
|
||||||
<img
|
<img
|
||||||
src={FolderIcon}
|
src={FolderIcon}
|
||||||
alt={t('settings.sources.folderAlt')}
|
alt={t('settings.sources.folderAlt')}
|
||||||
className="mr-2 h-4 w-4 flex-shrink-0"
|
className="mr-2 h-4 w-4 flex-shrink-0"
|
||||||
/>
|
/>
|
||||||
<span className="truncate">
|
<span className="truncate text-sm dark:text-[#E0E0E0]">
|
||||||
{name}
|
{name}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="30%" align="left">
|
<td className="px-2 py-2 text-sm lg:px-4 dark:text-[#E0E0E0]">
|
||||||
{dirStats.totalTokens > 0
|
{dirStats.totalTokens > 0
|
||||||
? dirStats.totalTokens.toLocaleString()
|
? dirStats.totalTokens.toLocaleString()
|
||||||
: '-'}
|
: '-'}
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="20%" align="left">
|
<td className="px-2 py-2 text-sm lg:px-4 dark:text-[#E0E0E0]">
|
||||||
{dirStats.totalSize > 0 ? formatBytes(dirStats.totalSize) : '-'}
|
{dirStats.totalSize > 0 ? formatBytes(dirStats.totalSize) : '-'}
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="10%" align="right">
|
<td className="w-10 px-2 py-2 text-sm lg:px-4">
|
||||||
<div ref={menuRef} className="relative">
|
<div ref={menuRef} className="relative">
|
||||||
<button
|
<button
|
||||||
onClick={(e) => handleMenuClick(e, itemId)}
|
onClick={(e) => handleMenuClick(e, itemId)}
|
||||||
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E] font-medium"
|
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md font-medium transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E]"
|
||||||
aria-label={t('settings.sources.menuAlt')}
|
aria-label={t('settings.sources.menuAlt')}
|
||||||
>
|
>
|
||||||
<img
|
<img
|
||||||
@@ -492,8 +485,8 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
|
|||||||
offset={{ x: -4, y: 4 }}
|
offset={{ x: -4, y: 4 }}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</td>
|
||||||
</TableRow>
|
</tr>
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -505,29 +498,30 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
|
|||||||
const menuRef = getMenuRef(itemId);
|
const menuRef = getMenuRef(itemId);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<TableRow
|
<tr
|
||||||
key={itemId}
|
key={itemId}
|
||||||
|
className="cursor-pointer border-b border-[#D1D9E0] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]"
|
||||||
onClick={() => handleFileClick(name)}
|
onClick={() => handleFileClick(name)}
|
||||||
>
|
>
|
||||||
<TableCell width="40%" align="left">
|
<td className="px-2 py-2 lg:px-4">
|
||||||
<div className="flex min-w-0 items-center">
|
<div className="flex min-w-0 items-center">
|
||||||
<img
|
<img
|
||||||
src={FileIcon}
|
src={FileIcon}
|
||||||
alt={t('settings.sources.fileAlt')}
|
alt={t('settings.sources.fileAlt')}
|
||||||
className="mr-2 h-4 w-4 flex-shrink-0"
|
className="mr-2 h-4 w-4 flex-shrink-0"
|
||||||
/>
|
/>
|
||||||
<span className="truncate">
|
<span className="truncate text-sm dark:text-[#E0E0E0]">
|
||||||
{name}
|
{name}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="30%" align="left">
|
<td className="px-2 py-2 text-sm lg:px-4 dark:text-[#E0E0E0]">
|
||||||
{node.token_count?.toLocaleString() || '-'}
|
{node.token_count?.toLocaleString() || '-'}
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="20%" align="left">
|
<td className="px-2 py-2 text-sm md:px-4 dark:text-[#E0E0E0]">
|
||||||
{node.size_bytes ? formatBytes(node.size_bytes) : '-'}
|
{node.size_bytes ? formatBytes(node.size_bytes) : '-'}
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="10%" align="right">
|
<td className="w-10 px-2 py-2 text-sm lg:px-4">
|
||||||
<div ref={menuRef} className="relative">
|
<div ref={menuRef} className="relative">
|
||||||
<button
|
<button
|
||||||
onClick={(e) => handleMenuClick(e, itemId)}
|
onClick={(e) => handleMenuClick(e, itemId)}
|
||||||
@@ -551,8 +545,8 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
|
|||||||
offset={{ x: -4, y: 4 }}
|
offset={{ x: -4, y: 4 }}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</td>
|
||||||
</TableRow>
|
</tr>
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -710,45 +704,28 @@ const ConnectorTreeComponent: React.FC<ConnectorTreeComponentProps> = ({
|
|||||||
<div className="mb-2">{renderPathNavigation()}</div>
|
<div className="mb-2">{renderPathNavigation()}</div>
|
||||||
|
|
||||||
<div className="w-full">
|
<div className="w-full">
|
||||||
<TableContainer>
|
<div className="overflow-x-auto rounded-[6px] border border-[#D1D9E0] dark:border-[#6A6A6A]">
|
||||||
<Table>
|
<table className="w-full min-w-[600px] table-auto bg-transparent">
|
||||||
<TableHead>
|
<thead className="bg-gray-100 dark:bg-[#27282D]">
|
||||||
<TableRow>
|
<tr className="border-b border-[#D1D9E0] dark:border-[#6A6A6A]">
|
||||||
<TableHeader width="40%" align="left">
|
<th className="min-w-[200px] px-2 py-3 text-left text-sm font-medium text-gray-700 lg:px-4 dark:text-[#59636E]">
|
||||||
{t('settings.sources.fileName')}
|
{t('settings.sources.fileName')}
|
||||||
</TableHeader>
|
</th>
|
||||||
<TableHeader width="30%" align="left">
|
<th className="min-w-[80px] px-2 py-3 text-left text-sm font-medium text-gray-700 lg:px-4 dark:text-[#59636E]">
|
||||||
{t('settings.sources.tokens')}
|
{t('settings.sources.tokens')}
|
||||||
</TableHeader>
|
</th>
|
||||||
<TableHeader width="20%" align="left">
|
<th className="min-w-[80px] px-2 py-3 text-left text-sm font-medium text-gray-700 lg:px-4 dark:text-[#59636E]">
|
||||||
{t('settings.sources.size')}
|
{t('settings.sources.size')}
|
||||||
</TableHeader>
|
</th>
|
||||||
<TableHeader width="10%" align="right">
|
<th className="w-10 px-2 py-3 text-left text-sm font-medium text-gray-700 lg:px-4 dark:text-[#59636E]"></th>
|
||||||
<span className="sr-only">
|
</tr>
|
||||||
{t('settings.sources.actions')}
|
</thead>
|
||||||
</span>
|
<tbody>{renderFileTree(getCurrentDirectory())}</tbody>
|
||||||
</TableHeader>
|
</table>
|
||||||
</TableRow>
|
</div>
|
||||||
</TableHead>
|
|
||||||
<TableBody>
|
|
||||||
{renderFileTree(getCurrentDirectory())}
|
|
||||||
</TableBody>
|
|
||||||
</Table>
|
|
||||||
</TableContainer>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<ConfirmationModal
|
|
||||||
message={t('settings.sources.syncConfirmation', {
|
|
||||||
sourceName,
|
|
||||||
})}
|
|
||||||
modalState={syncConfirmationModal}
|
|
||||||
setModalState={setSyncConfirmationModal}
|
|
||||||
handleSubmit={handleSync}
|
|
||||||
submitLabel={t('settings.sources.sync')}
|
|
||||||
cancelLabel={t('cancel')}
|
|
||||||
/>
|
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,469 +0,0 @@
|
|||||||
import React, { useState, useEffect, useCallback, useRef } from 'react';
|
|
||||||
import { formatBytes } from '../utils/stringUtils';
|
|
||||||
import { formatDate } from '../utils/dateTimeUtils';
|
|
||||||
import { getSessionToken, setSessionToken, removeSessionToken } from '../utils/providerUtils';
|
|
||||||
import ConnectorAuth from '../components/ConnectorAuth';
|
|
||||||
import FileIcon from '../assets/file.svg';
|
|
||||||
import FolderIcon from '../assets/folder.svg';
|
|
||||||
import CheckIcon from '../assets/checkmark.svg';
|
|
||||||
import SearchIcon from '../assets/search.svg';
|
|
||||||
import Input from './Input';
|
|
||||||
import {
|
|
||||||
Table,
|
|
||||||
TableContainer,
|
|
||||||
TableHead,
|
|
||||||
TableBody,
|
|
||||||
TableRow,
|
|
||||||
TableHeader,
|
|
||||||
TableCell,
|
|
||||||
} from './Table';
|
|
||||||
|
|
||||||
interface CloudFile {
|
|
||||||
id: string;
|
|
||||||
name: string;
|
|
||||||
type: string;
|
|
||||||
size?: number;
|
|
||||||
modifiedTime: string;
|
|
||||||
isFolder?: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface CloudFilePickerProps {
|
|
||||||
onSelectionChange: (selectedFileIds: string[], selectedFolderIds?: string[]) => void;
|
|
||||||
onDisconnect?: () => void;
|
|
||||||
provider: string;
|
|
||||||
token: string | null;
|
|
||||||
initialSelectedFiles?: string[];
|
|
||||||
initialSelectedFolders?: string[];
|
|
||||||
}
|
|
||||||
|
|
||||||
export const FilePicker: React.FC<CloudFilePickerProps> = ({
|
|
||||||
onSelectionChange,
|
|
||||||
onDisconnect,
|
|
||||||
provider,
|
|
||||||
token,
|
|
||||||
initialSelectedFiles = [],
|
|
||||||
}) => {
|
|
||||||
const PROVIDER_CONFIG = {
|
|
||||||
google_drive: {
|
|
||||||
displayName: 'Drive',
|
|
||||||
rootName: 'My Drive',
|
|
||||||
},
|
|
||||||
} as const;
|
|
||||||
|
|
||||||
const getProviderConfig = (provider: string) => {
|
|
||||||
return PROVIDER_CONFIG[provider as keyof typeof PROVIDER_CONFIG] || {
|
|
||||||
displayName: provider,
|
|
||||||
rootName: 'Root',
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
const [files, setFiles] = useState<CloudFile[]>([]);
|
|
||||||
const [selectedFiles, setSelectedFiles] = useState<string[]>(initialSelectedFiles);
|
|
||||||
const [selectedFolders, setSelectedFolders] = useState<string[]>([]);
|
|
||||||
const [isLoading, setIsLoading] = useState(false);
|
|
||||||
const [hasMoreFiles, setHasMoreFiles] = useState(false);
|
|
||||||
const [nextPageToken, setNextPageToken] = useState<string | null>(null);
|
|
||||||
const [currentFolderId, setCurrentFolderId] = useState<string | null>(null);
|
|
||||||
const [folderPath, setFolderPath] = useState<Array<{ id: string | null, name: string }>>([{
|
|
||||||
id: null,
|
|
||||||
name: getProviderConfig(provider).rootName
|
|
||||||
}]);
|
|
||||||
const [searchQuery, setSearchQuery] = useState<string>('');
|
|
||||||
const [authError, setAuthError] = useState<string>('');
|
|
||||||
const [isConnected, setIsConnected] = useState(false);
|
|
||||||
const [userEmail, setUserEmail] = useState<string>('');
|
|
||||||
|
|
||||||
const scrollContainerRef = useRef<HTMLDivElement>(null);
|
|
||||||
const searchTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
|
||||||
|
|
||||||
const isFolder = (file: CloudFile) => {
|
|
||||||
return file.isFolder ||
|
|
||||||
file.type === 'application/vnd.google-apps.folder' ||
|
|
||||||
file.type === 'folder';
|
|
||||||
};
|
|
||||||
|
|
||||||
const loadCloudFiles = useCallback(
|
|
||||||
async (
|
|
||||||
sessionToken: string,
|
|
||||||
folderId: string | null,
|
|
||||||
pageToken?: string,
|
|
||||||
searchQuery: string = ''
|
|
||||||
) => {
|
|
||||||
setIsLoading(true);
|
|
||||||
|
|
||||||
const apiHost = import.meta.env.VITE_API_HOST;
|
|
||||||
if (!pageToken) {
|
|
||||||
setFiles([]);
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const response = await fetch(`${apiHost}/api/connectors/files`, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': `Bearer ${token}`
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
|
||||||
provider: provider,
|
|
||||||
session_token: sessionToken,
|
|
||||||
folder_id: folderId,
|
|
||||||
limit: 10,
|
|
||||||
page_token: pageToken,
|
|
||||||
search_query: searchQuery
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
const data = await response.json();
|
|
||||||
if (data.success) {
|
|
||||||
setFiles(prev => pageToken ? [...prev, ...data.files] : data.files);
|
|
||||||
setNextPageToken(data.next_page_token);
|
|
||||||
setHasMoreFiles(!!data.next_page_token);
|
|
||||||
} else {
|
|
||||||
console.error('Error loading files:', data.error);
|
|
||||||
if (!pageToken) {
|
|
||||||
setFiles([]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Error loading files:', err);
|
|
||||||
if (!pageToken) {
|
|
||||||
setFiles([]);
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
setIsLoading(false);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[token, provider]
|
|
||||||
);
|
|
||||||
|
|
||||||
const validateAndLoadFiles = useCallback(async () => {
|
|
||||||
const sessionToken = getSessionToken(provider);
|
|
||||||
if (!sessionToken) {
|
|
||||||
setIsConnected(false);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const apiHost = import.meta.env.VITE_API_HOST;
|
|
||||||
const validateResponse = await fetch(`${apiHost}/api/connectors/validate-session`, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': `Bearer ${token}`
|
|
||||||
},
|
|
||||||
body: JSON.stringify({ provider: provider, session_token: sessionToken })
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!validateResponse.ok) {
|
|
||||||
removeSessionToken(provider);
|
|
||||||
setIsConnected(false);
|
|
||||||
setAuthError('Session expired. Please reconnect to Google Drive.');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const validateData = await validateResponse.json();
|
|
||||||
if (validateData.success) {
|
|
||||||
setUserEmail(validateData.user_email || 'Connected User');
|
|
||||||
setIsConnected(true);
|
|
||||||
setAuthError('');
|
|
||||||
|
|
||||||
setFiles([]);
|
|
||||||
setNextPageToken(null);
|
|
||||||
setHasMoreFiles(false);
|
|
||||||
setCurrentFolderId(null);
|
|
||||||
setFolderPath([{
|
|
||||||
id: null, name: getProviderConfig(provider).rootName
|
|
||||||
}]);
|
|
||||||
loadCloudFiles(sessionToken, null, undefined, '');
|
|
||||||
} else {
|
|
||||||
removeSessionToken(provider);
|
|
||||||
setIsConnected(false);
|
|
||||||
setAuthError(validateData.error || 'Session expired. Please reconnect your account.');
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Error validating session:', error);
|
|
||||||
setAuthError('Failed to validate session. Please reconnect.');
|
|
||||||
setIsConnected(false);
|
|
||||||
}
|
|
||||||
}, [provider, token, loadCloudFiles]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
validateAndLoadFiles();
|
|
||||||
}, [validateAndLoadFiles]);
|
|
||||||
|
|
||||||
const handleScroll = useCallback(() => {
|
|
||||||
const scrollContainer = scrollContainerRef.current;
|
|
||||||
if (!scrollContainer) return;
|
|
||||||
|
|
||||||
const { scrollTop, scrollHeight, clientHeight } = scrollContainer;
|
|
||||||
const isNearBottom = scrollHeight - scrollTop - clientHeight < 50;
|
|
||||||
|
|
||||||
if (isNearBottom && hasMoreFiles && !isLoading && nextPageToken) {
|
|
||||||
const sessionToken = getSessionToken(provider);
|
|
||||||
if (sessionToken) {
|
|
||||||
loadCloudFiles(sessionToken, currentFolderId, nextPageToken, searchQuery);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, [hasMoreFiles, isLoading, nextPageToken, currentFolderId, searchQuery, provider, loadCloudFiles]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const scrollContainer = scrollContainerRef.current;
|
|
||||||
if (scrollContainer) {
|
|
||||||
scrollContainer.addEventListener('scroll', handleScroll);
|
|
||||||
return () => scrollContainer.removeEventListener('scroll', handleScroll);
|
|
||||||
}
|
|
||||||
}, [handleScroll]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
return () => {
|
|
||||||
if (searchTimeoutRef.current) {
|
|
||||||
clearTimeout(searchTimeoutRef.current);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const handleSearchChange = (query: string) => {
|
|
||||||
setSearchQuery(query);
|
|
||||||
|
|
||||||
if (searchTimeoutRef.current) {
|
|
||||||
clearTimeout(searchTimeoutRef.current);
|
|
||||||
}
|
|
||||||
|
|
||||||
searchTimeoutRef.current = setTimeout(() => {
|
|
||||||
const sessionToken = getSessionToken(provider);
|
|
||||||
if (sessionToken) {
|
|
||||||
loadCloudFiles(sessionToken, currentFolderId, undefined, query);
|
|
||||||
}
|
|
||||||
}, 300);
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleFolderClick = (folderId: string, folderName: string) => {
|
|
||||||
if (folderId === currentFolderId) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
setIsLoading(true);
|
|
||||||
|
|
||||||
setCurrentFolderId(folderId);
|
|
||||||
setFolderPath(prev => [...prev, { id: folderId, name: folderName }]);
|
|
||||||
setSearchQuery('');
|
|
||||||
|
|
||||||
const sessionToken = getSessionToken(provider);
|
|
||||||
if (sessionToken) {
|
|
||||||
loadCloudFiles(sessionToken, folderId, undefined, '');
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const navigateBack = (index: number) => {
|
|
||||||
if (index >= folderPath.length - 1) return;
|
|
||||||
|
|
||||||
const newFolderPath = folderPath.slice(0, index + 1);
|
|
||||||
const newFolderId = newFolderPath[newFolderPath.length - 1].id;
|
|
||||||
|
|
||||||
setFolderPath(newFolderPath);
|
|
||||||
setCurrentFolderId(newFolderId);
|
|
||||||
setSearchQuery('');
|
|
||||||
|
|
||||||
const sessionToken = getSessionToken(provider);
|
|
||||||
if (sessionToken) {
|
|
||||||
loadCloudFiles(sessionToken, newFolderId, undefined, '');
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleFileSelect = (fileId: string, isFolder: boolean) => {
|
|
||||||
if (isFolder) {
|
|
||||||
const newSelectedFolders = selectedFolders.includes(fileId)
|
|
||||||
? selectedFolders.filter(id => id !== fileId)
|
|
||||||
: [...selectedFolders, fileId];
|
|
||||||
setSelectedFolders(newSelectedFolders);
|
|
||||||
onSelectionChange(selectedFiles, newSelectedFolders);
|
|
||||||
} else {
|
|
||||||
const newSelectedFiles = selectedFiles.includes(fileId)
|
|
||||||
? selectedFiles.filter(id => id !== fileId)
|
|
||||||
: [...selectedFiles, fileId];
|
|
||||||
setSelectedFiles(newSelectedFiles);
|
|
||||||
onSelectionChange(newSelectedFiles, selectedFolders);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className=''>
|
|
||||||
{authError && (
|
|
||||||
<div className="text-red-500 text-sm mb-4 text-center">{authError}</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<ConnectorAuth
|
|
||||||
provider={provider}
|
|
||||||
onSuccess={(data) => {
|
|
||||||
setUserEmail(data.user_email || 'Connected User');
|
|
||||||
setIsConnected(true);
|
|
||||||
setAuthError('');
|
|
||||||
|
|
||||||
if (data.session_token) {
|
|
||||||
setSessionToken(provider, data.session_token);
|
|
||||||
loadCloudFiles(data.session_token, null);
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
onError={(error) => {
|
|
||||||
setAuthError(error);
|
|
||||||
setIsConnected(false);
|
|
||||||
}}
|
|
||||||
isConnected={isConnected}
|
|
||||||
userEmail={userEmail}
|
|
||||||
onDisconnect={() => {
|
|
||||||
const sessionToken = getSessionToken(provider);
|
|
||||||
if (sessionToken) {
|
|
||||||
const apiHost = import.meta.env.VITE_API_HOST;
|
|
||||||
fetch(`${apiHost}/api/connectors/disconnect`, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': `Bearer ${token}`
|
|
||||||
},
|
|
||||||
body: JSON.stringify({ provider: provider, session_token: sessionToken })
|
|
||||||
}).catch(err => console.error(`Error disconnecting from ${getProviderConfig(provider).displayName}:`, err));
|
|
||||||
}
|
|
||||||
|
|
||||||
removeSessionToken(provider);
|
|
||||||
setIsConnected(false);
|
|
||||||
setFiles([]);
|
|
||||||
setSelectedFiles([]);
|
|
||||||
onSelectionChange([]);
|
|
||||||
|
|
||||||
if (onDisconnect) {
|
|
||||||
onDisconnect();
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
|
|
||||||
{isConnected && (
|
|
||||||
<div className="border border-[#D7D7D7] rounded-lg dark:border-[#6A6A6A] mt-3">
|
|
||||||
<div className="border-[#EEE6FF78] dark:border-[#6A6A6A] rounded-t-lg">
|
|
||||||
{/* Breadcrumb navigation */}
|
|
||||||
<div className="px-4 pt-4 bg-[#EEE6FF78] dark:bg-[#2A262E] rounded-t-lg">
|
|
||||||
<div className="flex items-center gap-1 mb-2">
|
|
||||||
{folderPath.map((path, index) => (
|
|
||||||
<div key={path.id || 'root'} className="flex items-center gap-1">
|
|
||||||
{index > 0 && <span className="text-gray-400">/</span>}
|
|
||||||
<button
|
|
||||||
onClick={() => navigateBack(index)}
|
|
||||||
className="text-sm text-[#A076F6] hover:text-[#8A5FD4] hover:underline"
|
|
||||||
disabled={index === folderPath.length - 1}
|
|
||||||
>
|
|
||||||
{path.name}
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="mb-3 text-sm text-gray-600 dark:text-gray-400">
|
|
||||||
Select Files from {getProviderConfig(provider).displayName}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="mb-3 max-w-md">
|
|
||||||
<Input
|
|
||||||
type="text"
|
|
||||||
placeholder="Search files and folders..."
|
|
||||||
value={searchQuery}
|
|
||||||
onChange={(e) => handleSearchChange(e.target.value)}
|
|
||||||
colorVariant="silver"
|
|
||||||
borderVariant="thin"
|
|
||||||
labelBgClassName="bg-[#EEE6FF78] dark:bg-[#2A262E]"
|
|
||||||
leftIcon={<img src={SearchIcon} alt="Search" width={16} height={16} />}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Selected Files Message */}
|
|
||||||
<div className="pb-3 text-sm text-gray-600 dark:text-gray-400">
|
|
||||||
{selectedFiles.length + selectedFolders.length} selected
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="h-72">
|
|
||||||
<TableContainer
|
|
||||||
ref={scrollContainerRef}
|
|
||||||
height="288px"
|
|
||||||
className="scrollbar-thin md:w-4xl lg:w-5xl"
|
|
||||||
bordered={false}
|
|
||||||
>
|
|
||||||
{(
|
|
||||||
<>
|
|
||||||
<Table minWidth="1200px">
|
|
||||||
<TableHead>
|
|
||||||
<TableRow>
|
|
||||||
<TableHeader width="40px"></TableHeader>
|
|
||||||
<TableHeader width="60%">Name</TableHeader>
|
|
||||||
<TableHeader width="20%">Last Modified</TableHeader>
|
|
||||||
<TableHeader width="20%">Size</TableHeader>
|
|
||||||
</TableRow>
|
|
||||||
</TableHead>
|
|
||||||
<TableBody>
|
|
||||||
{files.map((file, index) => (
|
|
||||||
<TableRow
|
|
||||||
key={`${file.id}-${index}`}
|
|
||||||
onClick={() => {
|
|
||||||
if (isFolder(file)) {
|
|
||||||
handleFolderClick(file.id, file.name);
|
|
||||||
} else {
|
|
||||||
handleFileSelect(file.id, false);
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<TableCell width="40px" align="center">
|
|
||||||
<div
|
|
||||||
className="flex h-5 w-5 text-sm shrink-0 items-center justify-center border border-[#EEE6FF78] p-[0.5px] dark:border-[#6A6A6A] cursor-pointer mx-auto"
|
|
||||||
onClick={(e) => {
|
|
||||||
e.stopPropagation();
|
|
||||||
handleFileSelect(file.id, isFolder(file));
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{(isFolder(file) ? selectedFolders : selectedFiles).includes(file.id) && (
|
|
||||||
<img
|
|
||||||
src={CheckIcon}
|
|
||||||
alt="Selected"
|
|
||||||
className="h-4 w-4"
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</TableCell>
|
|
||||||
<TableCell>
|
|
||||||
<div className="flex items-center gap-3 min-w-0">
|
|
||||||
<div className="flex-shrink-0">
|
|
||||||
<img
|
|
||||||
src={isFolder(file) ? FolderIcon : FileIcon}
|
|
||||||
alt={isFolder(file) ? "Folder" : "File"}
|
|
||||||
className="h-6 w-6"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<span className="truncate">{file.name}</span>
|
|
||||||
</div>
|
|
||||||
</TableCell>
|
|
||||||
<TableCell className='text-xs'>
|
|
||||||
{formatDate(file.modifiedTime)}
|
|
||||||
</TableCell>
|
|
||||||
<TableCell className='text-xs'>
|
|
||||||
{file.size ? formatBytes(file.size) : '-'}
|
|
||||||
</TableCell>
|
|
||||||
</TableRow>
|
|
||||||
))}
|
|
||||||
</TableBody>
|
|
||||||
</Table>
|
|
||||||
|
|
||||||
{isLoading && (
|
|
||||||
<div className="flex items-center justify-center p-4 border-t border-[#EEE6FF78] dark:border-[#6A6A6A]">
|
|
||||||
<div className="inline-flex items-center gap-2 text-sm text-gray-600 dark:text-gray-400">
|
|
||||||
<div className="h-4 w-4 animate-spin rounded-full border-2 border-blue-500 border-t-transparent"></div>
|
|
||||||
Loading more files...
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</TableContainer>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
@@ -11,18 +11,11 @@ import FolderIcon from '../assets/folder.svg';
|
|||||||
import ArrowLeft from '../assets/arrow-left.svg';
|
import ArrowLeft from '../assets/arrow-left.svg';
|
||||||
import ThreeDots from '../assets/three-dots.svg';
|
import ThreeDots from '../assets/three-dots.svg';
|
||||||
import EyeView from '../assets/eye-view.svg';
|
import EyeView from '../assets/eye-view.svg';
|
||||||
|
import OutlineSource from '../assets/outline-source.svg';
|
||||||
import Trash from '../assets/red-trash.svg';
|
import Trash from '../assets/red-trash.svg';
|
||||||
|
import SearchIcon from '../assets/search.svg';
|
||||||
import { useOutsideAlerter } from '../hooks';
|
import { useOutsideAlerter } from '../hooks';
|
||||||
import ConfirmationModal from '../modals/ConfirmationModal';
|
import ConfirmationModal from '../modals/ConfirmationModal';
|
||||||
import {
|
|
||||||
Table,
|
|
||||||
TableContainer,
|
|
||||||
TableHead,
|
|
||||||
TableBody,
|
|
||||||
TableRow,
|
|
||||||
TableHeader,
|
|
||||||
TableCell,
|
|
||||||
} from './Table';
|
|
||||||
|
|
||||||
interface FileNode {
|
interface FileNode {
|
||||||
type?: string;
|
type?: string;
|
||||||
@@ -136,6 +129,8 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
}
|
}
|
||||||
}, [docId, token]);
|
}, [docId, token]);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const navigateToDirectory = (dirName: string) => {
|
const navigateToDirectory = (dirName: string) => {
|
||||||
setCurrentPath((prev) => [...prev, dirName]);
|
setCurrentPath((prev) => [...prev, dirName]);
|
||||||
};
|
};
|
||||||
@@ -443,18 +438,18 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
|
|
||||||
const renderPathNavigation = () => {
|
const renderPathNavigation = () => {
|
||||||
return (
|
return (
|
||||||
<div className="mb-0 flex min-h-[38px] flex-col gap-2 text-base sm:flex-row sm:items-center sm:justify-between">
|
<div className="mb-0 min-h-[38px] flex flex-col gap-2 text-base sm:flex-row sm:items-center sm:justify-between">
|
||||||
{/* Left side with path navigation */}
|
{/* Left side with path navigation */}
|
||||||
<div className="flex w-full items-center sm:w-auto">
|
<div className="flex w-full items-center sm:w-auto">
|
||||||
<button
|
<button
|
||||||
className="mr-3 flex h-[29px] w-[29px] items-center justify-center rounded-full border p-2 text-sm font-medium text-gray-400 dark:border-0 dark:bg-[#28292D] dark:text-gray-500 dark:hover:bg-[#2E2F34]"
|
className="mr-3 flex h-[29px] w-[29px] items-center justify-center rounded-full border p-2 text-sm text-gray-400 dark:border-0 dark:bg-[#28292D] dark:text-gray-500 dark:hover:bg-[#2E2F34] font-medium"
|
||||||
onClick={handleBackNavigation}
|
onClick={handleBackNavigation}
|
||||||
>
|
>
|
||||||
<img src={ArrowLeft} alt="left-arrow" className="h-3 w-3" />
|
<img src={ArrowLeft} alt="left-arrow" className="h-3 w-3" />
|
||||||
</button>
|
</button>
|
||||||
|
|
||||||
<div className="flex flex-wrap items-center">
|
<div className="flex flex-wrap items-center">
|
||||||
<span className="font-semibold break-words text-[#7D54D1]">
|
<span className="text-[#7D54D1] font-semibold break-words">
|
||||||
{sourceName}
|
{sourceName}
|
||||||
</span>
|
</span>
|
||||||
{currentPath.length > 0 && (
|
{currentPath.length > 0 && (
|
||||||
@@ -485,7 +480,8 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="relative mt-2 flex w-full flex-row flex-nowrap items-center justify-end gap-2 sm:mt-0 sm:w-auto">
|
<div className="flex relative flex-row flex-nowrap items-center gap-2 w-full sm:w-auto justify-end mt-2 sm:mt-0">
|
||||||
|
|
||||||
{processingRef.current && (
|
{processingRef.current && (
|
||||||
<div className="text-sm text-gray-500">
|
<div className="text-sm text-gray-500">
|
||||||
{currentOpRef.current === 'add'
|
{currentOpRef.current === 'add'
|
||||||
@@ -494,13 +490,13 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{renderFileSearch()}
|
{renderFileSearch()}
|
||||||
|
|
||||||
{/* Add file button */}
|
{/* Add file button */}
|
||||||
{!processingRef.current && (
|
{!processingRef.current && (
|
||||||
<button
|
<button
|
||||||
onClick={handleAddFile}
|
onClick={handleAddFile}
|
||||||
className="bg-purple-30 hover:bg-violets-are-blue flex h-[38px] min-w-[108px] items-center justify-center rounded-full px-4 text-[14px] font-medium whitespace-nowrap text-white"
|
className="bg-purple-30 hover:bg-violets-are-blue flex h-[38px] min-w-[108px] items-center justify-center rounded-full px-4 text-[14px] whitespace-nowrap text-white font-medium"
|
||||||
title={t('settings.sources.addFile')}
|
title={t('settings.sources.addFile')}
|
||||||
>
|
>
|
||||||
{t('settings.sources.addFile')}
|
{t('settings.sources.addFile')}
|
||||||
@@ -542,30 +538,31 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
const parentRow =
|
const parentRow =
|
||||||
currentPath.length > 0
|
currentPath.length > 0
|
||||||
? [
|
? [
|
||||||
<TableRow
|
<tr
|
||||||
key="parent-dir"
|
key="parent-dir"
|
||||||
|
className="cursor-pointer border-b border-[#D1D9E0] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]"
|
||||||
onClick={navigateUp}
|
onClick={navigateUp}
|
||||||
>
|
>
|
||||||
<TableCell width="40%" align="left">
|
<td className="px-2 py-2 lg:px-4">
|
||||||
<div className="flex items-center">
|
<div className="flex items-center">
|
||||||
<img
|
<img
|
||||||
src={FolderIcon}
|
src={FolderIcon}
|
||||||
alt={t('settings.sources.parentFolderAlt')}
|
alt={t('settings.sources.parentFolderAlt')}
|
||||||
className="mr-2 h-4 w-4 flex-shrink-0"
|
className="mr-2 h-4 w-4 flex-shrink-0"
|
||||||
/>
|
/>
|
||||||
<span className="truncate">
|
<span className="truncate text-sm dark:text-[#E0E0E0]">
|
||||||
..
|
..
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="30%" align="left">
|
<td className="px-2 py-2 text-sm lg:px-4 dark:text-[#E0E0E0]">
|
||||||
-
|
-
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="20%" align="right">
|
<td className="px-2 py-2 text-sm lg:px-4 dark:text-[#E0E0E0]">
|
||||||
-
|
-
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="10%" align="right"></TableCell>
|
<td className="w-10 px-2 py-2 text-sm lg:px-4"></td>
|
||||||
</TableRow>,
|
</tr>,
|
||||||
]
|
]
|
||||||
: [];
|
: [];
|
||||||
|
|
||||||
@@ -578,35 +575,36 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
const dirStats = calculateDirectoryStats(node as DirectoryStructure);
|
const dirStats = calculateDirectoryStats(node as DirectoryStructure);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<TableRow
|
<tr
|
||||||
key={itemId}
|
key={itemId}
|
||||||
|
className="cursor-pointer border-b border-[#D1D9E0] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]"
|
||||||
onClick={() => navigateToDirectory(name)}
|
onClick={() => navigateToDirectory(name)}
|
||||||
>
|
>
|
||||||
<TableCell width="40%" align="left">
|
<td className="px-2 py-2 lg:px-4">
|
||||||
<div className="flex min-w-0 items-center">
|
<div className="flex min-w-0 items-center">
|
||||||
<img
|
<img
|
||||||
src={FolderIcon}
|
src={FolderIcon}
|
||||||
alt={t('settings.sources.folderAlt')}
|
alt={t('settings.sources.folderAlt')}
|
||||||
className="mr-2 h-4 w-4 flex-shrink-0"
|
className="mr-2 h-4 w-4 flex-shrink-0"
|
||||||
/>
|
/>
|
||||||
<span className="truncate">
|
<span className="truncate text-sm dark:text-[#E0E0E0]">
|
||||||
{name}
|
{name}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="30%" align="left">
|
<td className="px-2 py-2 text-sm lg:px-4 dark:text-[#E0E0E0]">
|
||||||
{dirStats.totalSize > 0 ? formatBytes(dirStats.totalSize) : '-'}
|
|
||||||
</TableCell>
|
|
||||||
<TableCell width="20%" align="right">
|
|
||||||
{dirStats.totalTokens > 0
|
{dirStats.totalTokens > 0
|
||||||
? dirStats.totalTokens.toLocaleString()
|
? dirStats.totalTokens.toLocaleString()
|
||||||
: '-'}
|
: '-'}
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="10%" align="right">
|
<td className="px-2 py-2 text-sm lg:px-4 dark:text-[#E0E0E0]">
|
||||||
|
{dirStats.totalSize > 0 ? formatBytes(dirStats.totalSize) : '-'}
|
||||||
|
</td>
|
||||||
|
<td className="w-10 px-2 py-2 text-sm lg:px-4">
|
||||||
<div ref={menuRef} className="relative">
|
<div ref={menuRef} className="relative">
|
||||||
<button
|
<button
|
||||||
onClick={(e) => handleMenuClick(e, itemId)}
|
onClick={(e) => handleMenuClick(e, itemId)}
|
||||||
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md font-medium transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E]"
|
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E] font-medium"
|
||||||
aria-label={t('settings.sources.menuAlt')}
|
aria-label={t('settings.sources.menuAlt')}
|
||||||
>
|
>
|
||||||
<img
|
<img
|
||||||
@@ -626,8 +624,8 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
offset={{ x: -4, y: 4 }}
|
offset={{ x: -4, y: 4 }}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</td>
|
||||||
</TableRow>
|
</tr>
|
||||||
);
|
);
|
||||||
}),
|
}),
|
||||||
...files.map(([name, node]) => {
|
...files.map(([name, node]) => {
|
||||||
@@ -635,33 +633,34 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
const menuRef = getMenuRef(itemId);
|
const menuRef = getMenuRef(itemId);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<TableRow
|
<tr
|
||||||
key={itemId}
|
key={itemId}
|
||||||
|
className="cursor-pointer border-b border-[#D1D9E0] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]"
|
||||||
onClick={() => handleFileClick(name)}
|
onClick={() => handleFileClick(name)}
|
||||||
>
|
>
|
||||||
<TableCell width="40%" align="left">
|
<td className="px-2 py-2 lg:px-4">
|
||||||
<div className="flex min-w-0 items-center">
|
<div className="flex min-w-0 items-center">
|
||||||
<img
|
<img
|
||||||
src={FileIcon}
|
src={FileIcon}
|
||||||
alt={t('settings.sources.fileAlt')}
|
alt={t('settings.sources.fileAlt')}
|
||||||
className="mr-2 h-4 w-4 flex-shrink-0"
|
className="mr-2 h-4 w-4 flex-shrink-0"
|
||||||
/>
|
/>
|
||||||
<span className="truncate">
|
<span className="truncate text-sm dark:text-[#E0E0E0]">
|
||||||
{name}
|
{name}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="30%" align="left">
|
<td className="px-2 py-2 text-sm lg:px-4 dark:text-[#E0E0E0]">
|
||||||
{node.size_bytes ? formatBytes(node.size_bytes) : '-'}
|
|
||||||
</TableCell>
|
|
||||||
<TableCell width="20%" align="right">
|
|
||||||
{node.token_count?.toLocaleString() || '-'}
|
{node.token_count?.toLocaleString() || '-'}
|
||||||
</TableCell>
|
</td>
|
||||||
<TableCell width="10%" align="right">
|
<td className="px-2 py-2 text-sm md:px-4 dark:text-[#E0E0E0]">
|
||||||
|
{node.size_bytes ? formatBytes(node.size_bytes) : '-'}
|
||||||
|
</td>
|
||||||
|
<td className="w-10 px-2 py-2 text-sm lg:px-4">
|
||||||
<div ref={menuRef} className="relative">
|
<div ref={menuRef} className="relative">
|
||||||
<button
|
<button
|
||||||
onClick={(e) => handleMenuClick(e, itemId)}
|
onClick={(e) => handleMenuClick(e, itemId)}
|
||||||
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md font-medium transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E]"
|
className="inline-flex h-[35px] w-[24px] shrink-0 items-center justify-center rounded-md transition-colors hover:bg-[#EBEBEB] dark:hover:bg-[#26272E] font-medium"
|
||||||
aria-label={t('settings.sources.menuAlt')}
|
aria-label={t('settings.sources.menuAlt')}
|
||||||
>
|
>
|
||||||
<img
|
<img
|
||||||
@@ -681,8 +680,8 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
offset={{ x: -4, y: 4 }}
|
offset={{ x: -4, y: 4 }}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</td>
|
||||||
</TableRow>
|
</tr>
|
||||||
);
|
);
|
||||||
}),
|
}),
|
||||||
];
|
];
|
||||||
@@ -753,12 +752,14 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
placeholder={t('settings.sources.searchFiles')}
|
placeholder={t('settings.sources.searchFiles')}
|
||||||
className={`h-[38px] w-full border border-[#D1D9E0] px-4 py-2 dark:border-[#6A6A6A] ${searchQuery ? 'rounded-t-[24px]' : 'rounded-[24px]'} bg-transparent focus:outline-none dark:text-[#E0E0E0]`}
|
className={`w-full h-[38px] border border-[#D1D9E0] px-4 py-2 dark:border-[#6A6A6A]
|
||||||
|
${searchQuery ? 'rounded-t-[24px]' : 'rounded-[24px]'}
|
||||||
|
bg-transparent focus:outline-none dark:text-[#E0E0E0]`}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{searchQuery && (
|
{searchQuery && (
|
||||||
<div className="absolute top-full right-0 left-0 z-10 max-h-[calc(100vh-200px)] w-full overflow-hidden rounded-b-[12px] border border-t-0 border-[#D1D9E0] bg-white shadow-lg transition-all duration-200 dark:border-[#6A6A6A] dark:bg-[#1F2023]">
|
<div className="absolute top-full left-0 right-0 z-10 max-h-[calc(100vh-200px)] w-full overflow-hidden rounded-b-[12px] border border-t-0 border-[#D1D9E0] bg-white shadow-lg dark:border-[#6A6A6A] dark:bg-[#1F2023] transition-all duration-200">
|
||||||
<div className="max-h-[calc(100vh-200px)] overflow-x-hidden overflow-y-auto overscroll-contain">
|
<div className="max-h-[calc(100vh-200px)] overflow-y-auto overflow-x-hidden overscroll-contain">
|
||||||
{searchResults.length === 0 ? (
|
{searchResults.length === 0 ? (
|
||||||
<div className="py-2 text-center text-sm text-gray-500 dark:text-gray-400">
|
<div className="py-2 text-center text-sm text-gray-500 dark:text-gray-400">
|
||||||
{t('settings.sources.noResults')}
|
{t('settings.sources.noResults')}
|
||||||
@@ -769,11 +770,10 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
key={index}
|
key={index}
|
||||||
onClick={() => handleSearchSelect(result)}
|
onClick={() => handleSearchSelect(result)}
|
||||||
title={result.path}
|
title={result.path}
|
||||||
className={`flex min-w-0 cursor-pointer items-center px-3 py-2 hover:bg-[#ECEEEF] dark:hover:bg-[#27282D] ${
|
className={`flex min-w-0 cursor-pointer items-center px-3 py-2 hover:bg-[#ECEEEF] dark:hover:bg-[#27282D] ${index !== searchResults.length - 1
|
||||||
index !== searchResults.length - 1
|
|
||||||
? 'border-b border-[#D1D9E0] dark:border-[#6A6A6A]'
|
? 'border-b border-[#D1D9E0] dark:border-[#6A6A6A]'
|
||||||
: ''
|
: ''
|
||||||
}`}
|
}`}
|
||||||
>
|
>
|
||||||
<img
|
<img
|
||||||
src={result.isFile ? FileIcon : FolderIcon}
|
src={result.isFile ? FileIcon : FolderIcon}
|
||||||
@@ -784,7 +784,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
}
|
}
|
||||||
className="mr-2 h-4 w-4 flex-shrink-0"
|
className="mr-2 h-4 w-4 flex-shrink-0"
|
||||||
/>
|
/>
|
||||||
<span className="flex-1 truncate text-sm dark:text-[#E0E0E0]">
|
<span className="text-sm dark:text-[#E0E0E0] truncate flex-1">
|
||||||
{result.path.split('/').pop() || result.path}
|
{result.path.split('/').pop() || result.path}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
@@ -834,31 +834,31 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
<div className="mb-2">{renderPathNavigation()}</div>
|
<div className="mb-2">{renderPathNavigation()}</div>
|
||||||
|
|
||||||
<div className="w-full">
|
<div className="w-full">
|
||||||
<TableContainer>
|
<div className="overflow-x-auto rounded-[6px] border border-[#D1D9E0] dark:border-[#6A6A6A]">
|
||||||
<Table>
|
<table className="w-full min-w-[600px] table-auto bg-transparent">
|
||||||
<TableHead>
|
<thead className="bg-gray-100 dark:bg-[#27282D]">
|
||||||
<TableRow>
|
<tr className="border-b border-[#D1D9E0] dark:border-[#6A6A6A]">
|
||||||
<TableHeader width="40%" align="left">
|
<th className="min-w-[200px] px-2 py-3 text-left text-sm font-medium text-gray-700 lg:px-4 dark:text-[#59636E]">
|
||||||
{t('settings.sources.fileName')}
|
{t('settings.sources.fileName')}
|
||||||
</TableHeader>
|
</th>
|
||||||
<TableHeader width="30%" align="left">
|
<th className="min-w-[80px] px-2 py-3 text-left text-sm font-medium text-gray-700 lg:px-4 dark:text-[#59636E]">
|
||||||
{t('settings.sources.size')}
|
|
||||||
</TableHeader>
|
|
||||||
<TableHeader width="20%" align="right">
|
|
||||||
{t('settings.sources.tokens')}
|
{t('settings.sources.tokens')}
|
||||||
</TableHeader>
|
</th>
|
||||||
<TableHeader width="10%" align="right">
|
<th className="min-w-[80px] px-2 py-3 text-left text-sm font-medium text-gray-700 lg:px-4 dark:text-[#59636E]">
|
||||||
|
{t('settings.sources.size')}
|
||||||
|
</th>
|
||||||
|
<th className="w-[60px] px-2 py-3 text-left text-sm font-medium text-gray-700 lg:px-4 dark:text-[#59636E]">
|
||||||
<span className="sr-only">
|
<span className="sr-only">
|
||||||
{t('settings.sources.actions')}
|
{t('settings.sources.actions')}
|
||||||
</span>
|
</span>
|
||||||
</TableHeader>
|
</th>
|
||||||
</TableRow>
|
</tr>
|
||||||
</TableHead>
|
</thead>
|
||||||
<TableBody>
|
<tbody className="[&>tr:last-child]:border-b-0">
|
||||||
{renderFileTree(currentDirectory)}
|
{renderFileTree(currentDirectory)}
|
||||||
</TableBody>
|
</tbody>
|
||||||
</Table>
|
</table>
|
||||||
</TableContainer>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
@@ -866,9 +866,7 @@ const FileTreeComponent: React.FC<FileTreeComponentProps> = ({
|
|||||||
message={
|
message={
|
||||||
itemToDelete?.isFile
|
itemToDelete?.isFile
|
||||||
? t('settings.sources.confirmDelete')
|
? t('settings.sources.confirmDelete')
|
||||||
: t('settings.sources.deleteDirectoryWarning', {
|
: t('settings.sources.deleteDirectoryWarning', { name: itemToDelete?.name })
|
||||||
name: itemToDelete?.name,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
modalState={deleteModalState}
|
modalState={deleteModalState}
|
||||||
setModalState={setDeleteModalState}
|
setModalState={setDeleteModalState}
|
||||||
|
|||||||
@@ -1,342 +0,0 @@
|
|||||||
import React, { useState, useEffect } from 'react';
|
|
||||||
import useDrivePicker from 'react-google-drive-picker';
|
|
||||||
|
|
||||||
import ConnectorAuth from './ConnectorAuth';
|
|
||||||
import { getSessionToken, setSessionToken, removeSessionToken } from '../utils/providerUtils';
|
|
||||||
|
|
||||||
|
|
||||||
interface PickerFile {
|
|
||||||
id: string;
|
|
||||||
name: string;
|
|
||||||
mimeType: string;
|
|
||||||
iconUrl: string;
|
|
||||||
description?: string;
|
|
||||||
sizeBytes?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface GoogleDrivePickerProps {
|
|
||||||
token: string | null;
|
|
||||||
onSelectionChange: (fileIds: string[], folderIds?: string[]) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
const GoogleDrivePicker: React.FC<GoogleDrivePickerProps> = ({
|
|
||||||
token,
|
|
||||||
onSelectionChange,
|
|
||||||
}) => {
|
|
||||||
const [selectedFiles, setSelectedFiles] = useState<PickerFile[]>([]);
|
|
||||||
const [selectedFolders, setSelectedFolders] = useState<PickerFile[]>([]);
|
|
||||||
const [isLoading, setIsLoading] = useState(false);
|
|
||||||
const [userEmail, setUserEmail] = useState<string>('');
|
|
||||||
const [isConnected, setIsConnected] = useState(false);
|
|
||||||
const [authError, setAuthError] = useState<string>('');
|
|
||||||
const [accessToken, setAccessToken] = useState<string | null>(null);
|
|
||||||
const [isValidating, setIsValidating] = useState(false);
|
|
||||||
|
|
||||||
const [openPicker] = useDrivePicker();
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const sessionToken = getSessionToken('google_drive');
|
|
||||||
if (sessionToken) {
|
|
||||||
setIsValidating(true);
|
|
||||||
setIsConnected(true); // Optimistically set as connected for skeleton
|
|
||||||
validateSession(sessionToken);
|
|
||||||
}
|
|
||||||
}, [token]);
|
|
||||||
|
|
||||||
const validateSession = async (sessionToken: string) => {
|
|
||||||
try {
|
|
||||||
const apiHost = import.meta.env.VITE_API_HOST;
|
|
||||||
const validateResponse = await fetch(`${apiHost}/api/connectors/validate-session`, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': `Bearer ${token}`
|
|
||||||
},
|
|
||||||
body: JSON.stringify({ provider: 'google_drive', session_token: sessionToken })
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!validateResponse.ok) {
|
|
||||||
setIsConnected(false);
|
|
||||||
setAuthError('Session expired. Please reconnect to Google Drive.');
|
|
||||||
setIsValidating(false);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const validateData = await validateResponse.json();
|
|
||||||
if (validateData.success) {
|
|
||||||
setUserEmail(validateData.user_email || 'Connected User');
|
|
||||||
setIsConnected(true);
|
|
||||||
setAuthError('');
|
|
||||||
setAccessToken(validateData.access_token || null);
|
|
||||||
setIsValidating(false);
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
setIsConnected(false);
|
|
||||||
setAuthError(validateData.error || 'Session expired. Please reconnect your account.');
|
|
||||||
setIsValidating(false);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Error validating session:', error);
|
|
||||||
setAuthError('Failed to validate session. Please reconnect.');
|
|
||||||
setIsConnected(false);
|
|
||||||
setIsValidating(false);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleOpenPicker = async () => {
|
|
||||||
setIsLoading(true);
|
|
||||||
|
|
||||||
const sessionToken = getSessionToken('google_drive');
|
|
||||||
|
|
||||||
if (!sessionToken) {
|
|
||||||
setAuthError('No valid session found. Please reconnect to Google Drive.');
|
|
||||||
setIsLoading(false);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!accessToken) {
|
|
||||||
setAuthError('No access token available. Please reconnect to Google Drive.');
|
|
||||||
setIsLoading(false);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const clientId: string = import.meta.env.VITE_GOOGLE_CLIENT_ID;
|
|
||||||
|
|
||||||
// Derive appId from clientId (extract numeric part before first dash)
|
|
||||||
const appId = clientId ? clientId.split('-')[0] : null;
|
|
||||||
|
|
||||||
if (!clientId || !appId) {
|
|
||||||
console.error('Missing Google Drive configuration');
|
|
||||||
|
|
||||||
setIsLoading(false);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
openPicker({
|
|
||||||
clientId: clientId,
|
|
||||||
developerKey: "",
|
|
||||||
appId: appId,
|
|
||||||
setSelectFolderEnabled: false,
|
|
||||||
viewId: "DOCS",
|
|
||||||
showUploadView: false,
|
|
||||||
showUploadFolders: false,
|
|
||||||
supportDrives: false,
|
|
||||||
multiselect: true,
|
|
||||||
token: accessToken,
|
|
||||||
viewMimeTypes: 'application/vnd.google-apps.document,application/vnd.google-apps.presentation,application/vnd.google-apps.spreadsheet,application/pdf,application/vnd.openxmlformats-officedocument.wordprocessingml.document,application/vnd.openxmlformats-officedocument.presentationml.presentation,application/vnd.openxmlformats-officedocument.spreadsheetml.sheet,application/msword,application/vnd.ms-powerpoint,application/vnd.ms-excel,text/plain,text/csv,text/html,text/markdown,text/x-rst,application/json,application/epub+zip,application/rtf,image/jpeg,image/jpg,image/png',
|
|
||||||
callbackFunction: (data:any) => {
|
|
||||||
setIsLoading(false);
|
|
||||||
if (data.action === 'picked') {
|
|
||||||
const docs = data.docs;
|
|
||||||
|
|
||||||
const newFiles: PickerFile[] = [];
|
|
||||||
const newFolders: PickerFile[] = [];
|
|
||||||
|
|
||||||
docs.forEach((doc: any) => {
|
|
||||||
const item = {
|
|
||||||
id: doc.id,
|
|
||||||
name: doc.name,
|
|
||||||
mimeType: doc.mimeType,
|
|
||||||
iconUrl: doc.iconUrl || '',
|
|
||||||
description: doc.description,
|
|
||||||
sizeBytes: doc.sizeBytes
|
|
||||||
};
|
|
||||||
|
|
||||||
if (doc.mimeType === 'application/vnd.google-apps.folder') {
|
|
||||||
newFolders.push(item);
|
|
||||||
} else {
|
|
||||||
newFiles.push(item);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
setSelectedFiles(prevFiles => {
|
|
||||||
const existingFileIds = new Set(prevFiles.map(file => file.id));
|
|
||||||
const uniqueNewFiles = newFiles.filter(file => !existingFileIds.has(file.id));
|
|
||||||
return [...prevFiles, ...uniqueNewFiles];
|
|
||||||
});
|
|
||||||
|
|
||||||
setSelectedFolders(prevFolders => {
|
|
||||||
const existingFolderIds = new Set(prevFolders.map(folder => folder.id));
|
|
||||||
const uniqueNewFolders = newFolders.filter(folder => !existingFolderIds.has(folder.id));
|
|
||||||
return [...prevFolders, ...uniqueNewFolders];
|
|
||||||
});
|
|
||||||
onSelectionChange(
|
|
||||||
[...selectedFiles, ...newFiles].map(file => file.id),
|
|
||||||
[...selectedFolders, ...newFolders].map(folder => folder.id)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
});
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Error opening picker:', error);
|
|
||||||
setAuthError('Failed to open file picker. Please try again.');
|
|
||||||
setIsLoading(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleDisconnect = async () => {
|
|
||||||
const sessionToken = getSessionToken('google_drive');
|
|
||||||
if (sessionToken) {
|
|
||||||
try {
|
|
||||||
const apiHost = import.meta.env.VITE_API_HOST;
|
|
||||||
await fetch(`${apiHost}/api/connectors/disconnect`, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': `Bearer ${token}`
|
|
||||||
},
|
|
||||||
body: JSON.stringify({ provider: 'google_drive', session_token: sessionToken })
|
|
||||||
});
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Error disconnecting from Google Drive:', err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
removeSessionToken('google_drive');
|
|
||||||
setIsConnected(false);
|
|
||||||
setSelectedFiles([]);
|
|
||||||
setSelectedFolders([]);
|
|
||||||
setAccessToken(null);
|
|
||||||
setUserEmail('');
|
|
||||||
setAuthError('');
|
|
||||||
onSelectionChange([], []);
|
|
||||||
};
|
|
||||||
|
|
||||||
const ConnectedStateSkeleton = () => (
|
|
||||||
<div className="mb-4">
|
|
||||||
<div className="w-full flex items-center justify-between rounded-[10px] bg-gray-200 dark:bg-gray-700 px-4 py-2 animate-pulse">
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<div className="h-4 w-4 bg-gray-300 dark:bg-gray-600 rounded"></div>
|
|
||||||
<div className="h-4 w-32 bg-gray-300 dark:bg-gray-600 rounded"></div>
|
|
||||||
</div>
|
|
||||||
<div className="h-4 w-16 bg-gray-300 dark:bg-gray-600 rounded"></div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
|
|
||||||
const FilesSectionSkeleton = () => (
|
|
||||||
<div className="border border-[#EEE6FF78] rounded-lg dark:border-[#6A6A6A]">
|
|
||||||
<div className="p-4">
|
|
||||||
<div className="flex justify-between items-center mb-4">
|
|
||||||
<div className="h-5 w-24 bg-gray-200 dark:bg-gray-700 rounded animate-pulse"></div>
|
|
||||||
<div className="h-8 w-24 bg-gray-200 dark:bg-gray-700 rounded animate-pulse"></div>
|
|
||||||
</div>
|
|
||||||
<div className="h-4 w-40 bg-gray-200 dark:bg-gray-700 rounded animate-pulse"></div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div>
|
|
||||||
{isValidating ? (
|
|
||||||
<>
|
|
||||||
<ConnectedStateSkeleton />
|
|
||||||
<FilesSectionSkeleton />
|
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
<>
|
|
||||||
<ConnectorAuth
|
|
||||||
provider="google_drive"
|
|
||||||
label="Connect to Google Drive"
|
|
||||||
onSuccess={(data) => {
|
|
||||||
setUserEmail(data.user_email || 'Connected User');
|
|
||||||
setIsConnected(true);
|
|
||||||
setAuthError('');
|
|
||||||
|
|
||||||
if (data.session_token) {
|
|
||||||
setSessionToken('google_drive', data.session_token);
|
|
||||||
validateSession(data.session_token);
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
onError={(error) => {
|
|
||||||
setAuthError(error);
|
|
||||||
setIsConnected(false);
|
|
||||||
}}
|
|
||||||
isConnected={isConnected}
|
|
||||||
userEmail={userEmail}
|
|
||||||
onDisconnect={handleDisconnect}
|
|
||||||
errorMessage={authError}
|
|
||||||
/>
|
|
||||||
|
|
||||||
{isConnected && (
|
|
||||||
<div className="border border-[#EEE6FF78] rounded-lg dark:border-[#6A6A6A]">
|
|
||||||
<div className="p-4">
|
|
||||||
<div className="flex justify-between items-center mb-4">
|
|
||||||
<h3 className="text-sm font-medium">Selected Files</h3>
|
|
||||||
<button
|
|
||||||
onClick={() => handleOpenPicker()}
|
|
||||||
className="bg-[#A076F6] hover:bg-[#8A5FD4] text-white text-sm py-1 px-3 rounded-md"
|
|
||||||
disabled={isLoading}
|
|
||||||
>
|
|
||||||
{isLoading ? 'Loading...' : 'Select Files'}
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{selectedFiles.length === 0 && selectedFolders.length === 0 ? (
|
|
||||||
<p className="text-gray-600 dark:text-gray-400 text-sm">No files or folders selected</p>
|
|
||||||
) : (
|
|
||||||
<div className="max-h-60 overflow-y-auto">
|
|
||||||
{selectedFolders.length > 0 && (
|
|
||||||
<div className="mb-2">
|
|
||||||
<h4 className="text-xs font-medium text-gray-500 mb-1">Folders</h4>
|
|
||||||
{selectedFolders.map((folder) => (
|
|
||||||
<div key={folder.id} className="flex items-center p-2 border-b border-gray-200 dark:border-gray-700">
|
|
||||||
<img src={folder.iconUrl} alt="Folder" className="w-5 h-5 mr-2" />
|
|
||||||
<span className="text-sm truncate flex-1">{folder.name}</span>
|
|
||||||
<button
|
|
||||||
onClick={() => {
|
|
||||||
const newSelectedFolders = selectedFolders.filter(f => f.id !== folder.id);
|
|
||||||
setSelectedFolders(newSelectedFolders);
|
|
||||||
onSelectionChange(
|
|
||||||
selectedFiles.map(f => f.id),
|
|
||||||
newSelectedFolders.map(f => f.id)
|
|
||||||
);
|
|
||||||
}}
|
|
||||||
className="text-red-500 hover:text-red-700 text-sm ml-2"
|
|
||||||
>
|
|
||||||
Remove
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{selectedFiles.length > 0 && (
|
|
||||||
<div>
|
|
||||||
<h4 className="text-xs font-medium text-gray-500 mb-1">Files</h4>
|
|
||||||
{selectedFiles.map((file) => (
|
|
||||||
<div key={file.id} className="flex items-center p-2 border-b border-gray-200 dark:border-gray-700">
|
|
||||||
<img src={file.iconUrl} alt="File" className="w-5 h-5 mr-2" />
|
|
||||||
<span className="text-sm truncate flex-1">{file.name}</span>
|
|
||||||
<button
|
|
||||||
onClick={() => {
|
|
||||||
const newSelectedFiles = selectedFiles.filter(f => f.id !== file.id);
|
|
||||||
setSelectedFiles(newSelectedFiles);
|
|
||||||
onSelectionChange(
|
|
||||||
newSelectedFiles.map(f => f.id),
|
|
||||||
selectedFolders.map(f => f.id)
|
|
||||||
);
|
|
||||||
}}
|
|
||||||
className="text-red-500 hover:text-red-700 text-sm ml-2"
|
|
||||||
>
|
|
||||||
Remove
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default GoogleDrivePicker;
|
|
||||||
@@ -16,7 +16,6 @@ const Input = ({
|
|||||||
textSize = 'medium',
|
textSize = 'medium',
|
||||||
children,
|
children,
|
||||||
labelBgClassName = 'bg-white dark:bg-raisin-black',
|
labelBgClassName = 'bg-white dark:bg-raisin-black',
|
||||||
leftIcon,
|
|
||||||
onChange,
|
onChange,
|
||||||
onPaste,
|
onPaste,
|
||||||
onKeyDown,
|
onKeyDown,
|
||||||
@@ -43,7 +42,7 @@ const Input = ({
|
|||||||
<div className={`relative ${className}`}>
|
<div className={`relative ${className}`}>
|
||||||
<input
|
<input
|
||||||
ref={inputRef}
|
ref={inputRef}
|
||||||
className={`peer text-jet dark:text-bright-gray h-[42px] w-full rounded-full bg-transparent ${leftIcon ? 'pl-10' : 'px-3'} py-1 placeholder-transparent outline-hidden ${colorStyles[colorVariant]} ${borderStyles[borderVariant]} ${textSizeStyles[textSize]} [&:-webkit-autofill]:appearance-none [&:-webkit-autofill]:bg-transparent [&:-webkit-autofill_selected]:bg-transparent`}
|
className={`peer text-jet dark:text-bright-gray h-[42px] w-full rounded-full bg-transparent px-3 py-1 placeholder-transparent outline-hidden ${colorStyles[colorVariant]} ${borderStyles[borderVariant]} ${textSizeStyles[textSize]} [&:-webkit-autofill]:appearance-none [&:-webkit-autofill]:bg-transparent [&:-webkit-autofill_selected]:bg-transparent`}
|
||||||
type={type}
|
type={type}
|
||||||
id={id}
|
id={id}
|
||||||
name={name}
|
name={name}
|
||||||
@@ -58,19 +57,12 @@ const Input = ({
|
|||||||
>
|
>
|
||||||
{children}
|
{children}
|
||||||
</input>
|
</input>
|
||||||
{leftIcon && (
|
|
||||||
<div className="absolute left-3 top-1/2 transform -translate-y-1/2 flex items-center justify-center">
|
|
||||||
{leftIcon}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
{placeholder && (
|
{placeholder && (
|
||||||
<label
|
<label
|
||||||
htmlFor={id}
|
htmlFor={id}
|
||||||
className={`absolute select-none ${
|
className={`absolute select-none ${
|
||||||
hasValue ? '-top-2.5 left-3 text-xs' : ''
|
hasValue ? '-top-2.5 left-3 text-xs' : ''
|
||||||
} px-2 transition-all peer-placeholder-shown:top-2.5 ${
|
} px-2 transition-all peer-placeholder-shown:top-2.5 peer-placeholder-shown:left-3 peer-placeholder-shown:${
|
||||||
leftIcon ? 'peer-placeholder-shown:left-7' : 'peer-placeholder-shown:left-3'
|
|
||||||
} peer-placeholder-shown:${
|
|
||||||
textSizeStyles[textSize]
|
textSizeStyles[textSize]
|
||||||
} text-gray-4000 pointer-events-none cursor-none peer-focus:-top-2.5 peer-focus:left-3 peer-focus:text-xs dark:text-gray-400 ${labelBgClassName} max-w-[calc(100%-24px)] overflow-hidden text-ellipsis whitespace-nowrap`}
|
} text-gray-4000 pointer-events-none cursor-none peer-focus:-top-2.5 peer-focus:left-3 peer-focus:text-xs dark:text-gray-400 ${labelBgClassName} max-w-[calc(100%-24px)] overflow-hidden text-ellipsis whitespace-nowrap`}
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -259,7 +259,7 @@ export default function MessageInput({
|
|||||||
return (
|
return (
|
||||||
<div className="mx-2 flex w-full flex-col">
|
<div className="mx-2 flex w-full flex-col">
|
||||||
<div className="border-dark-gray bg-lotion dark:border-grey relative flex w-full flex-col rounded-[23px] border dark:bg-transparent">
|
<div className="border-dark-gray bg-lotion dark:border-grey relative flex w-full flex-col rounded-[23px] border dark:bg-transparent">
|
||||||
<div className="flex flex-wrap gap-1.5 px-2 py-2 sm:gap-2 sm:px-3">
|
<div className="flex flex-wrap gap-1.5 px-4 pt-3 pb-0 sm:gap-2 sm:px-6">
|
||||||
{attachments.map((attachment, index) => (
|
{attachments.map((attachment, index) => (
|
||||||
<div
|
<div
|
||||||
key={index}
|
key={index}
|
||||||
@@ -353,14 +353,14 @@ export default function MessageInput({
|
|||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
tabIndex={1}
|
tabIndex={1}
|
||||||
placeholder={t('inputPlaceholder')}
|
placeholder={t('inputPlaceholder')}
|
||||||
className="inputbox-style no-scrollbar bg-lotion dark:text-bright-gray dark:placeholder:text-bright-gray/50 w-full overflow-x-hidden overflow-y-auto rounded-t-[23px] px-2 text-base leading-tight whitespace-pre-wrap opacity-100 placeholder:text-gray-500 focus:outline-hidden sm:px-3 dark:bg-transparent"
|
className="inputbox-style no-scrollbar bg-lotion dark:text-bright-gray dark:placeholder:text-bright-gray/50 w-full overflow-x-hidden overflow-y-auto rounded-t-[23px] px-4 py-3 text-base leading-tight whitespace-pre-wrap opacity-100 placeholder:text-gray-500 focus:outline-hidden sm:px-6 sm:py-5 dark:bg-transparent"
|
||||||
onInput={handleInput}
|
onInput={handleInput}
|
||||||
onKeyDown={handleKeyDown}
|
onKeyDown={handleKeyDown}
|
||||||
aria-label={t('inputPlaceholder')}
|
aria-label={t('inputPlaceholder')}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="flex items-center px-2 pb-1.5 sm:px-3 sm:pb-2">
|
<div className="flex items-center px-3 py-1.5 sm:px-4 sm:py-2">
|
||||||
<div className="flex grow flex-wrap gap-1 sm:gap-2">
|
<div className="flex grow flex-wrap gap-1 sm:gap-2">
|
||||||
{showSourceButton && (
|
{showSourceButton && (
|
||||||
<button
|
<button
|
||||||
@@ -368,8 +368,8 @@ export default function MessageInput({
|
|||||||
className="xs:px-3 xs:py-1.5 dark:border-purple-taupe flex max-w-[130px] items-center rounded-[32px] border border-[#AAAAAA] px-2 py-1 transition-colors hover:bg-gray-100 sm:max-w-[150px] dark:hover:bg-[#2C2E3C]"
|
className="xs:px-3 xs:py-1.5 dark:border-purple-taupe flex max-w-[130px] items-center rounded-[32px] border border-[#AAAAAA] px-2 py-1 transition-colors hover:bg-gray-100 sm:max-w-[150px] dark:hover:bg-[#2C2E3C]"
|
||||||
onClick={() => setIsSourcesPopupOpen(!isSourcesPopupOpen)}
|
onClick={() => setIsSourcesPopupOpen(!isSourcesPopupOpen)}
|
||||||
title={
|
title={
|
||||||
selectedDocs && selectedDocs.length > 0
|
selectedDocs
|
||||||
? selectedDocs.map((doc) => doc.name).join(', ')
|
? selectedDocs.name
|
||||||
: t('conversation.sources.title')
|
: t('conversation.sources.title')
|
||||||
}
|
}
|
||||||
>
|
>
|
||||||
@@ -379,10 +379,8 @@ export default function MessageInput({
|
|||||||
className="mr-1 h-3.5 w-3.5 shrink-0 sm:mr-1.5 sm:h-4"
|
className="mr-1 h-3.5 w-3.5 shrink-0 sm:mr-1.5 sm:h-4"
|
||||||
/>
|
/>
|
||||||
<span className="xs:text-[12px] dark:text-bright-gray truncate overflow-hidden text-[10px] font-medium text-[#5D5D5D] sm:text-[14px]">
|
<span className="xs:text-[12px] dark:text-bright-gray truncate overflow-hidden text-[10px] font-medium text-[#5D5D5D] sm:text-[14px]">
|
||||||
{selectedDocs && selectedDocs.length > 0
|
{selectedDocs
|
||||||
? selectedDocs.length === 1
|
? selectedDocs.name
|
||||||
? selectedDocs[0].name
|
|
||||||
: `${selectedDocs.length} sources selected`
|
|
||||||
: t('conversation.sources.title')}
|
: t('conversation.sources.title')}
|
||||||
</span>
|
</span>
|
||||||
{!isTouch && (
|
{!isTouch && (
|
||||||
@@ -430,18 +428,18 @@ export default function MessageInput({
|
|||||||
<button
|
<button
|
||||||
onClick={loading ? undefined : handleSubmit}
|
onClick={loading ? undefined : handleSubmit}
|
||||||
aria-label={loading ? t('loading') : t('send')}
|
aria-label={loading ? t('loading') : t('send')}
|
||||||
className={`flex h-7 w-7 items-center justify-center rounded-full sm:h-9 sm:w-9 ${loading || !value.trim() ? 'bg-black opacity-60 dark:bg-[#F0F3F4] dark:opacity-80' : 'bg-black opacity-100 dark:bg-[#F0F3F4]'} ml-auto shrink-0`}
|
className={`flex items-center justify-center rounded-full p-2 sm:p-2.5 ${loading ? 'bg-gray-300 dark:bg-gray-600' : 'bg-black dark:bg-white'} ml-auto shrink-0`}
|
||||||
disabled={loading}
|
disabled={loading}
|
||||||
>
|
>
|
||||||
{loading ? (
|
{loading ? (
|
||||||
<img
|
<img
|
||||||
src={isDarkTheme ? SpinnerDark : Spinner}
|
src={isDarkTheme ? SpinnerDark : Spinner}
|
||||||
className="mx-auto my-auto block h-3.5 w-3.5 animate-spin sm:h-4 sm:w-4"
|
className="h-3.5 w-3.5 animate-spin sm:h-4 sm:w-4"
|
||||||
alt={t('loading')}
|
alt={t('loading')}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<img
|
<img
|
||||||
className={`mx-auto my-auto block h-3.5 w-3.5 translate-x-[-0.9px] translate-y-[1.1px] sm:h-4 sm:w-4 ${isDarkTheme ? 'invert filter' : ''}`}
|
className={`h-3.5 w-3.5 sm:h-4 sm:w-4 ${isDarkTheme ? 'invert filter' : ''}`}
|
||||||
src={PaperPlane}
|
src={PaperPlane}
|
||||||
alt={t('send')}
|
alt={t('send')}
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -248,7 +248,7 @@ export default function MultiSelectPopup({
|
|||||||
</div>
|
</div>
|
||||||
<div className="shrink-0">
|
<div className="shrink-0">
|
||||||
<div
|
<div
|
||||||
className={`dark:bg-charleston-green-2 flex h-4 w-4 items-center justify-center rounded-xs border-2 border-[#C6C6C6] bg-white dark:border-[#757783]`}
|
className={`dark:bg-charleston-green-2 flex h-4 w-4 items-center justify-center rounded-xs border border-[#C6C6C6] bg-white dark:border-[#757783]`}
|
||||||
aria-hidden="true"
|
aria-hidden="true"
|
||||||
>
|
>
|
||||||
{isSelected && (
|
{isSelected && (
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ type SourcesPopupProps = {
|
|||||||
isOpen: boolean;
|
isOpen: boolean;
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
anchorRef: React.RefObject<HTMLButtonElement | null>;
|
anchorRef: React.RefObject<HTMLButtonElement | null>;
|
||||||
handlePostDocumentSelect: (doc: Doc[] | null) => void;
|
handlePostDocumentSelect: (doc: Doc | null) => void;
|
||||||
setUploadModalState: React.Dispatch<React.SetStateAction<ActiveState>>;
|
setUploadModalState: React.Dispatch<React.SetStateAction<ActiveState>>;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -149,13 +149,9 @@ export default function SourcesPopup({
|
|||||||
if (option.model === embeddingsName) {
|
if (option.model === embeddingsName) {
|
||||||
const isSelected =
|
const isSelected =
|
||||||
selectedDocs &&
|
selectedDocs &&
|
||||||
Array.isArray(selectedDocs) &&
|
(option.id
|
||||||
selectedDocs.length > 0 &&
|
? selectedDocs.id === option.id
|
||||||
selectedDocs.some((doc) =>
|
: selectedDocs.date === option.date);
|
||||||
option.id
|
|
||||||
? doc.id === option.id
|
|
||||||
: doc.date === option.date,
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
@@ -163,29 +159,11 @@ export default function SourcesPopup({
|
|||||||
className="border-opacity-80 dark:border-dim-gray flex cursor-pointer items-center border-b border-[#D9D9D9] p-3 transition-colors hover:bg-gray-100 dark:text-[14px] dark:hover:bg-[#2C2E3C]"
|
className="border-opacity-80 dark:border-dim-gray flex cursor-pointer items-center border-b border-[#D9D9D9] p-3 transition-colors hover:bg-gray-100 dark:text-[14px] dark:hover:bg-[#2C2E3C]"
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
if (isSelected) {
|
if (isSelected) {
|
||||||
const updatedDocs =
|
dispatch(setSelectedDocs(null));
|
||||||
selectedDocs && Array.isArray(selectedDocs)
|
handlePostDocumentSelect(null);
|
||||||
? selectedDocs.filter((doc) =>
|
|
||||||
option.id
|
|
||||||
? doc.id !== option.id
|
|
||||||
: doc.date !== option.date,
|
|
||||||
)
|
|
||||||
: [];
|
|
||||||
dispatch(
|
|
||||||
setSelectedDocs(
|
|
||||||
updatedDocs.length > 0 ? updatedDocs : null,
|
|
||||||
),
|
|
||||||
);
|
|
||||||
handlePostDocumentSelect(
|
|
||||||
updatedDocs.length > 0 ? updatedDocs : null,
|
|
||||||
);
|
|
||||||
} else {
|
} else {
|
||||||
const updatedDocs =
|
dispatch(setSelectedDocs(option));
|
||||||
selectedDocs && Array.isArray(selectedDocs)
|
handlePostDocumentSelect(option);
|
||||||
? [...selectedDocs, option]
|
|
||||||
: [option];
|
|
||||||
dispatch(setSelectedDocs(updatedDocs));
|
|
||||||
handlePostDocumentSelect(updatedDocs);
|
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
@@ -200,7 +178,7 @@ export default function SourcesPopup({
|
|||||||
{option.name}
|
{option.name}
|
||||||
</span>
|
</span>
|
||||||
<div
|
<div
|
||||||
className={`flex h-4 w-4 shrink-0 items-center justify-center rounded-xs border-2 border-[#C6C6C6] p-[0.5px] dark:border-[#757783]`}
|
className={`flex h-4 w-4 shrink-0 items-center justify-center border border-[#C6C6C6] p-[0.5px] dark:border-[#757783]`}
|
||||||
>
|
>
|
||||||
{isSelected && (
|
{isSelected && (
|
||||||
<img
|
<img
|
||||||
|
|||||||
@@ -1,172 +0,0 @@
|
|||||||
import React from 'react';
|
|
||||||
|
|
||||||
interface TableProps {
|
|
||||||
children: React.ReactNode;
|
|
||||||
className?: string;
|
|
||||||
minWidth?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
interface TableContainerProps {
|
|
||||||
children: React.ReactNode;
|
|
||||||
className?: string;
|
|
||||||
ref?: React.Ref<HTMLDivElement>;
|
|
||||||
height?: string;
|
|
||||||
bordered?: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface TableHeadProps {
|
|
||||||
children: React.ReactNode;
|
|
||||||
className?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface TableRowProps {
|
|
||||||
children: React.ReactNode;
|
|
||||||
className?: string;
|
|
||||||
onClick?: () => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface TableCellProps {
|
|
||||||
children?: React.ReactNode;
|
|
||||||
className?: string;
|
|
||||||
minWidth?: string;
|
|
||||||
width?: string;
|
|
||||||
align?: 'left' | 'right' | 'center';
|
|
||||||
}
|
|
||||||
|
|
||||||
const TableContainer = React.forwardRef<HTMLDivElement, TableContainerProps>(({
|
|
||||||
children,
|
|
||||||
className = '',
|
|
||||||
height = 'auto',
|
|
||||||
bordered = true
|
|
||||||
}, ref) => {
|
|
||||||
return (
|
|
||||||
<div className={`relative rounded-[6px] ${className}`}>
|
|
||||||
<div
|
|
||||||
ref={ref}
|
|
||||||
className={`w-full overflow-x-auto rounded-[6px] bg-transparent ${bordered ? 'border border-[#D7D7D7] dark:border-[#6A6A6A]' : ''}`}
|
|
||||||
style={{
|
|
||||||
maxHeight: height === 'auto' ? undefined : height,
|
|
||||||
overflowY: height === 'auto' ? 'hidden' : 'auto'
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{children}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
});;
|
|
||||||
const Table: React.FC<TableProps> = ({
|
|
||||||
children,
|
|
||||||
className = '',
|
|
||||||
minWidth = 'min-w-[600px]'
|
|
||||||
}) => {
|
|
||||||
return (
|
|
||||||
<table className={`w-full table-auto border-collapse bg-transparent ${minWidth} ${className}`}>
|
|
||||||
{children}
|
|
||||||
</table>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
const TableHead: React.FC<TableHeadProps> = ({ children, className = '' }) => {
|
|
||||||
return (
|
|
||||||
<thead className={`
|
|
||||||
sticky top-0 z-10
|
|
||||||
bg-gray-100 dark:bg-[#27282D]
|
|
||||||
${className}
|
|
||||||
`}>
|
|
||||||
{children}
|
|
||||||
</thead>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const TableBody: React.FC<TableHeadProps> = ({ children, className = '' }) => {
|
|
||||||
return (
|
|
||||||
<tbody className={`[&>tr:last-child]:border-b-0 ${className}`}>
|
|
||||||
{children}
|
|
||||||
</tbody>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const TableRow: React.FC<TableRowProps> = ({ children, className = '', onClick }) => {
|
|
||||||
const baseClasses = "border-b border-[#D7D7D7] hover:bg-[#ECEEEF] dark:border-[#6A6A6A] dark:hover:bg-[#27282D]";
|
|
||||||
const cursorClass = onClick ? "cursor-pointer" : "";
|
|
||||||
|
|
||||||
return (
|
|
||||||
<tr className={`${baseClasses} ${cursorClass} ${className}`} onClick={onClick}>
|
|
||||||
{children}
|
|
||||||
</tr>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const TableHeader: React.FC<TableCellProps> = ({
|
|
||||||
children,
|
|
||||||
className = '',
|
|
||||||
minWidth,
|
|
||||||
width,
|
|
||||||
align = 'left'
|
|
||||||
}) => {
|
|
||||||
const getAlignmentClass = () => {
|
|
||||||
switch (align) {
|
|
||||||
case 'right':
|
|
||||||
return 'text-right';
|
|
||||||
case 'center':
|
|
||||||
return 'text-center';
|
|
||||||
default:
|
|
||||||
return 'text-left';
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const baseClasses = `px-2 py-3 text-sm font-medium text-gray-700 lg:px-3 dark:text-[#59636E] border-b border-[#D7D7D7] dark:border-[#6A6A6A] relative box-border ${getAlignmentClass()}`;
|
|
||||||
const widthClasses = minWidth ? minWidth : '';
|
|
||||||
|
|
||||||
return (
|
|
||||||
<th
|
|
||||||
className={`${baseClasses} ${widthClasses} ${className}`}
|
|
||||||
style={width ? { width, minWidth: width, maxWidth: width } : {}}
|
|
||||||
>
|
|
||||||
{children}
|
|
||||||
</th>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const TableCell: React.FC<TableCellProps> = ({
|
|
||||||
children,
|
|
||||||
className = '',
|
|
||||||
minWidth,
|
|
||||||
width,
|
|
||||||
align = 'left'
|
|
||||||
}) => {
|
|
||||||
const getAlignmentClass = () => {
|
|
||||||
switch (align) {
|
|
||||||
case 'right':
|
|
||||||
return 'text-right';
|
|
||||||
case 'center':
|
|
||||||
return 'text-center';
|
|
||||||
default:
|
|
||||||
return 'text-left';
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const baseClasses = `px-2 py-2 text-sm lg:px-3 dark:text-[#E0E0E0] box-border ${getAlignmentClass()}`;
|
|
||||||
const widthClasses = minWidth ? minWidth : '';
|
|
||||||
|
|
||||||
return (
|
|
||||||
<td
|
|
||||||
className={`${baseClasses} ${widthClasses} ${className}`}
|
|
||||||
style={width ? { width, minWidth: width, maxWidth: width } : {}}
|
|
||||||
>
|
|
||||||
{children}
|
|
||||||
</td>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export {
|
|
||||||
Table,
|
|
||||||
TableContainer,
|
|
||||||
TableHead,
|
|
||||||
TableBody,
|
|
||||||
TableRow,
|
|
||||||
TableHeader,
|
|
||||||
TableCell,
|
|
||||||
};
|
|
||||||
|
|
||||||
export default Table;
|
|
||||||
@@ -46,14 +46,14 @@ const ToggleSwitch: React.FC<ToggleSwitchProps> = ({
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<label
|
<label
|
||||||
className={`flex cursor-pointer flex-row items-center select-none ${
|
className={`flex cursor-pointer select-none flex-row items-center ${
|
||||||
labelPosition === 'right' ? 'flex-row-reverse' : ''
|
labelPosition === 'right' ? 'flex-row-reverse' : ''
|
||||||
} ${disabled ? 'cursor-not-allowed opacity-50' : ''} ${className}`}
|
} ${disabled ? 'cursor-not-allowed opacity-50' : ''} ${className}`}
|
||||||
>
|
>
|
||||||
{label && (
|
{label && (
|
||||||
<span
|
<span
|
||||||
className={`text-eerie-black dark:text-white ${
|
className={`text-eerie-black dark:text-white ${
|
||||||
labelPosition === 'left' ? 'mr-3' : 'ml-3'
|
labelPosition === 'left' ? 'mr-1' : 'ml-1'
|
||||||
}`}
|
}`}
|
||||||
>
|
>
|
||||||
{label}
|
{label}
|
||||||
@@ -75,7 +75,7 @@ const ToggleSwitch: React.FC<ToggleSwitchProps> = ({
|
|||||||
}`}
|
}`}
|
||||||
></div>
|
></div>
|
||||||
<div
|
<div
|
||||||
className={`absolute ${toggle} flex items-center justify-center rounded-full bg-white transition ${
|
className={`absolute ${toggle} flex items-center justify-center rounded-full bg-white opacity-80 transition ${
|
||||||
checked ? `${translate} bg-silver` : ''
|
checked ? `${translate} bg-silver` : ''
|
||||||
}`}
|
}`}
|
||||||
></div>
|
></div>
|
||||||
|
|||||||
@@ -207,7 +207,7 @@ export default function ToolsPopup({
|
|||||||
</div>
|
</div>
|
||||||
<div className="flex shrink-0 items-center">
|
<div className="flex shrink-0 items-center">
|
||||||
<div
|
<div
|
||||||
className={`flex h-4 w-4 items-center justify-center rounded-xs border-2 border-[#C6C6C6] p-[0.5px] dark:border-[#757783]`}
|
className={`flex h-4 w-4 items-center justify-center border border-[#C6C6C6] p-[0.5px] dark:border-[#757783]`}
|
||||||
>
|
>
|
||||||
{tool.status && (
|
{tool.status && (
|
||||||
<img
|
<img
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ export type InputProps = {
|
|||||||
onKeyDown?: (
|
onKeyDown?: (
|
||||||
e: React.KeyboardEvent<HTMLTextAreaElement | HTMLInputElement>,
|
e: React.KeyboardEvent<HTMLTextAreaElement | HTMLInputElement>,
|
||||||
) => void;
|
) => void;
|
||||||
leftIcon?: React.ReactNode;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export type MermaidRendererProps = {
|
export type MermaidRendererProps = {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import 'katex/dist/katex.min.css';
|
import 'katex/dist/katex.min.css';
|
||||||
|
|
||||||
import { forwardRef, Fragment, useEffect, useRef, useState } from 'react';
|
import { forwardRef, Fragment, useRef, useState, useEffect } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import ReactMarkdown from 'react-markdown';
|
import ReactMarkdown from 'react-markdown';
|
||||||
import { useSelector } from 'react-redux';
|
import { useSelector } from 'react-redux';
|
||||||
@@ -12,13 +12,12 @@ import {
|
|||||||
import rehypeKatex from 'rehype-katex';
|
import rehypeKatex from 'rehype-katex';
|
||||||
import remarkGfm from 'remark-gfm';
|
import remarkGfm from 'remark-gfm';
|
||||||
import remarkMath from 'remark-math';
|
import remarkMath from 'remark-math';
|
||||||
|
import DocumentationDark from '../assets/documentation-dark.svg';
|
||||||
import ChevronDown from '../assets/chevron-down.svg';
|
import ChevronDown from '../assets/chevron-down.svg';
|
||||||
import Cloud from '../assets/cloud.svg';
|
import Cloud from '../assets/cloud.svg';
|
||||||
import DocsGPT3 from '../assets/cute_docsgpt3.svg';
|
import DocsGPT3 from '../assets/cute_docsgpt3.svg';
|
||||||
import Dislike from '../assets/dislike.svg?react';
|
import Dislike from '../assets/dislike.svg?react';
|
||||||
import Document from '../assets/document.svg';
|
import Document from '../assets/document.svg';
|
||||||
import DocumentationDark from '../assets/documentation-dark.svg';
|
|
||||||
import Edit from '../assets/edit.svg';
|
import Edit from '../assets/edit.svg';
|
||||||
import Like from '../assets/like.svg?react';
|
import Like from '../assets/like.svg?react';
|
||||||
import Link from '../assets/link.svg';
|
import Link from '../assets/link.svg';
|
||||||
@@ -762,11 +761,7 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
|
|||||||
Response
|
Response
|
||||||
</span>{' '}
|
</span>{' '}
|
||||||
<CopyButton
|
<CopyButton
|
||||||
textToCopy={
|
textToCopy={JSON.stringify(toolCall.result, null, 2)}
|
||||||
toolCall.status === 'error'
|
|
||||||
? toolCall.error || 'Unknown error'
|
|
||||||
: JSON.stringify(toolCall.result, null, 2)
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
</p>
|
</p>
|
||||||
{toolCall.status === 'pending' && (
|
{toolCall.status === 'pending' && (
|
||||||
@@ -784,16 +779,6 @@ function ToolCalls({ toolCalls }: { toolCalls: ToolCallsType[] }) {
|
|||||||
</span>
|
</span>
|
||||||
</p>
|
</p>
|
||||||
)}
|
)}
|
||||||
{toolCall.status === 'error' && (
|
|
||||||
<p className="dark:bg-raisin-black rounded-b-2xl p-2 font-mono text-sm break-words">
|
|
||||||
<span
|
|
||||||
className="leading-[23px] text-red-500 dark:text-red-400"
|
|
||||||
style={{ fontFamily: 'IBMPlexMono-Medium' }}
|
|
||||||
>
|
|
||||||
{toolCall.error}
|
|
||||||
</span>
|
|
||||||
</p>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Accordion>
|
</Accordion>
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ export default function ConversationMessages({
|
|||||||
)}
|
)}
|
||||||
|
|
||||||
<div className="w-full max-w-[1300px] px-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
|
<div className="w-full max-w-[1300px] px-2 md:w-9/12 lg:w-8/12 xl:w-8/12 2xl:w-6/12">
|
||||||
{headerContent}
|
{headerContent && headerContent}
|
||||||
|
|
||||||
{queries.length > 0 ? (
|
{queries.length > 0 ? (
|
||||||
queries.map((query, index) => (
|
queries.map((query, index) => (
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ export function handleFetchAnswer(
|
|||||||
question: string,
|
question: string,
|
||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
token: string | null,
|
token: string | null,
|
||||||
selectedDocs: Doc[] | null,
|
selectedDocs: Doc | null,
|
||||||
conversationId: string | null,
|
conversationId: string | null,
|
||||||
promptId: string | null,
|
promptId: string | null,
|
||||||
chunks: string,
|
chunks: string,
|
||||||
@@ -52,17 +52,10 @@ export function handleFetchAnswer(
|
|||||||
payload.attachments = attachments;
|
payload.attachments = attachments;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (selectedDocs && Array.isArray(selectedDocs)) {
|
if (selectedDocs && 'id' in selectedDocs) {
|
||||||
if (selectedDocs.length > 1) {
|
payload.active_docs = selectedDocs.id as string;
|
||||||
// Handle multiple documents
|
|
||||||
payload.active_docs = selectedDocs.map((doc) => doc.id!);
|
|
||||||
payload.retriever = selectedDocs[0]?.retriever as string;
|
|
||||||
} else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) {
|
|
||||||
// Handle single document (backward compatibility)
|
|
||||||
payload.active_docs = selectedDocs[0].id as string;
|
|
||||||
payload.retriever = selectedDocs[0].retriever as string;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
payload.retriever = selectedDocs?.retriever as string;
|
||||||
return conversationService
|
return conversationService
|
||||||
.answer(payload, token, signal)
|
.answer(payload, token, signal)
|
||||||
.then((response) => {
|
.then((response) => {
|
||||||
@@ -91,7 +84,7 @@ export function handleFetchAnswerSteaming(
|
|||||||
question: string,
|
question: string,
|
||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
token: string | null,
|
token: string | null,
|
||||||
selectedDocs: Doc[] | null,
|
selectedDocs: Doc | null,
|
||||||
conversationId: string | null,
|
conversationId: string | null,
|
||||||
promptId: string | null,
|
promptId: string | null,
|
||||||
chunks: string,
|
chunks: string,
|
||||||
@@ -119,17 +112,10 @@ export function handleFetchAnswerSteaming(
|
|||||||
payload.attachments = attachments;
|
payload.attachments = attachments;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (selectedDocs && Array.isArray(selectedDocs)) {
|
if (selectedDocs && 'id' in selectedDocs) {
|
||||||
if (selectedDocs.length > 1) {
|
payload.active_docs = selectedDocs.id as string;
|
||||||
// Handle multiple documents
|
|
||||||
payload.active_docs = selectedDocs.map((doc) => doc.id!);
|
|
||||||
payload.retriever = selectedDocs[0]?.retriever as string;
|
|
||||||
} else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) {
|
|
||||||
// Handle single document (backward compatibility)
|
|
||||||
payload.active_docs = selectedDocs[0].id as string;
|
|
||||||
payload.retriever = selectedDocs[0].retriever as string;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
payload.retriever = selectedDocs?.retriever as string;
|
||||||
|
|
||||||
return new Promise<Answer>((resolve, reject) => {
|
return new Promise<Answer>((resolve, reject) => {
|
||||||
conversationService
|
conversationService
|
||||||
@@ -185,7 +171,7 @@ export function handleFetchAnswerSteaming(
|
|||||||
export function handleSearch(
|
export function handleSearch(
|
||||||
question: string,
|
question: string,
|
||||||
token: string | null,
|
token: string | null,
|
||||||
selectedDocs: Doc[] | null,
|
selectedDocs: Doc | null,
|
||||||
conversation_id: string | null,
|
conversation_id: string | null,
|
||||||
chunks: string,
|
chunks: string,
|
||||||
token_limit: number,
|
token_limit: number,
|
||||||
@@ -197,17 +183,9 @@ export function handleSearch(
|
|||||||
token_limit: token_limit,
|
token_limit: token_limit,
|
||||||
isNoneDoc: selectedDocs === null,
|
isNoneDoc: selectedDocs === null,
|
||||||
};
|
};
|
||||||
if (selectedDocs && Array.isArray(selectedDocs)) {
|
if (selectedDocs && 'id' in selectedDocs)
|
||||||
if (selectedDocs.length > 1) {
|
payload.active_docs = selectedDocs.id as string;
|
||||||
// Handle multiple documents
|
payload.retriever = selectedDocs?.retriever as string;
|
||||||
payload.active_docs = selectedDocs.map((doc) => doc.id!);
|
|
||||||
payload.retriever = selectedDocs[0]?.retriever as string;
|
|
||||||
} else if (selectedDocs.length === 1 && 'id' in selectedDocs[0]) {
|
|
||||||
// Handle single document (backward compatibility)
|
|
||||||
payload.active_docs = selectedDocs[0].id as string;
|
|
||||||
payload.retriever = selectedDocs[0].retriever as string;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return conversationService
|
return conversationService
|
||||||
.search(payload, token)
|
.search(payload, token)
|
||||||
.then((response) => response.json())
|
.then((response) => response.json())
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ export interface Query {
|
|||||||
|
|
||||||
export interface RetrievalPayload {
|
export interface RetrievalPayload {
|
||||||
question: string;
|
question: string;
|
||||||
active_docs?: string | string[];
|
active_docs?: string;
|
||||||
retriever?: string;
|
retriever?: string;
|
||||||
conversation_id: string | null;
|
conversation_id: string | null;
|
||||||
prompt_id?: string | null;
|
prompt_id?: string | null;
|
||||||
|
|||||||
@@ -4,6 +4,5 @@ export type ToolCallsType = {
|
|||||||
call_id: string;
|
call_id: string;
|
||||||
arguments: Record<string, any>;
|
arguments: Record<string, any>;
|
||||||
result?: Record<string, any>;
|
result?: Record<string, any>;
|
||||||
error?: string;
|
status?: 'pending' | 'completed';
|
||||||
status?: 'pending' | 'completed' | 'error';
|
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -18,14 +18,11 @@ export default function useDefaultDocument() {
|
|||||||
const fetchDocs = () => {
|
const fetchDocs = () => {
|
||||||
getDocs(token).then((data) => {
|
getDocs(token).then((data) => {
|
||||||
dispatch(setSourceDocs(data));
|
dispatch(setSourceDocs(data));
|
||||||
if (
|
if (!selectedDoc)
|
||||||
!selectedDoc ||
|
|
||||||
(Array.isArray(selectedDoc) && selectedDoc.length === 0)
|
|
||||||
)
|
|
||||||
Array.isArray(data) &&
|
Array.isArray(data) &&
|
||||||
data?.forEach((doc: Doc) => {
|
data?.forEach((doc: Doc) => {
|
||||||
if (doc.model && doc.name === 'default') {
|
if (doc.model && doc.name === 'default') {
|
||||||
dispatch(setSelectedDocs([doc]));
|
dispatch(setSelectedDocs(doc));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -68,7 +68,6 @@
|
|||||||
"private": "Private",
|
"private": "Private",
|
||||||
"sync": "Sync",
|
"sync": "Sync",
|
||||||
"syncing": "Syncing...",
|
"syncing": "Syncing...",
|
||||||
"syncConfirmation": "Are you sure you want to sync \"{{sourceName}}\"? This will update the content with your cloud storage and may override any edits you made to individual chunks.",
|
|
||||||
"syncFrequency": {
|
"syncFrequency": {
|
||||||
"never": "Never",
|
"never": "Never",
|
||||||
"daily": "Daily",
|
"daily": "Daily",
|
||||||
@@ -185,44 +184,7 @@
|
|||||||
"cancel": "Cancel",
|
"cancel": "Cancel",
|
||||||
"addNew": "Add New",
|
"addNew": "Add New",
|
||||||
"name": "Name",
|
"name": "Name",
|
||||||
"type": "Type",
|
"type": "Type"
|
||||||
"mcp": {
|
|
||||||
"addServer": "Add MCP Server",
|
|
||||||
"editServer": "Edit Server",
|
|
||||||
"serverName": "Server Name",
|
|
||||||
"serverUrl": "Server URL",
|
|
||||||
"headerName": "Header Name",
|
|
||||||
"timeout": "Timeout (seconds)",
|
|
||||||
"testConnection": "Test Connection",
|
|
||||||
"testing": "Testing",
|
|
||||||
"saving": "Saving",
|
|
||||||
"save": "Save",
|
|
||||||
"cancel": "Cancel",
|
|
||||||
"noAuth": "No Authentication",
|
|
||||||
"oauthInProgress": "Waiting for OAuth completion...",
|
|
||||||
"oauthCompleted": "OAuth completed successfully",
|
|
||||||
"placeholders": {
|
|
||||||
"serverUrl": "https://api.example.com",
|
|
||||||
"apiKey": "Your secret API key",
|
|
||||||
"bearerToken": "Your secret token",
|
|
||||||
"username": "Your username",
|
|
||||||
"password": "Your password",
|
|
||||||
"oauthScopes": "OAuth scopes (comma separated)"
|
|
||||||
},
|
|
||||||
"errors": {
|
|
||||||
"nameRequired": "Server name is required",
|
|
||||||
"urlRequired": "Server URL is required",
|
|
||||||
"invalidUrl": "Please enter a valid URL",
|
|
||||||
"apiKeyRequired": "API key is required",
|
|
||||||
"tokenRequired": "Bearer token is required",
|
|
||||||
"usernameRequired": "Username is required",
|
|
||||||
"passwordRequired": "Password is required",
|
|
||||||
"testFailed": "Connection test failed",
|
|
||||||
"saveFailed": "Failed to save MCP server",
|
|
||||||
"oauthFailed": "OAuth process failed or was cancelled",
|
|
||||||
"oauthTimeout": "OAuth process timed out, please try again"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"modals": {
|
"modals": {
|
||||||
|
|||||||
@@ -68,7 +68,6 @@
|
|||||||
"private": "Privado",
|
"private": "Privado",
|
||||||
"sync": "Sincronizar",
|
"sync": "Sincronizar",
|
||||||
"syncing": "Sincronizando...",
|
"syncing": "Sincronizando...",
|
||||||
"syncConfirmation": "¿Estás seguro de que deseas sincronizar \"{{sourceName}}\"? Esto actualizará el contenido con tu almacenamiento en la nube y puede anular cualquier edición que hayas realizado en fragmentos individuales.",
|
|
||||||
"syncFrequency": {
|
"syncFrequency": {
|
||||||
"never": "Nunca",
|
"never": "Nunca",
|
||||||
"daily": "Diario",
|
"daily": "Diario",
|
||||||
|
|||||||
@@ -68,7 +68,6 @@
|
|||||||
"private": "プライベート",
|
"private": "プライベート",
|
||||||
"sync": "同期",
|
"sync": "同期",
|
||||||
"syncing": "同期中...",
|
"syncing": "同期中...",
|
||||||
"syncConfirmation": "\"{{sourceName}}\"を同期してもよろしいですか?これにより、コンテンツがクラウドストレージで更新され、個々のチャンクに加えた編集が上書きされる可能性があります。",
|
|
||||||
"syncFrequency": {
|
"syncFrequency": {
|
||||||
"never": "なし",
|
"never": "なし",
|
||||||
"daily": "毎日",
|
"daily": "毎日",
|
||||||
|
|||||||
@@ -68,7 +68,6 @@
|
|||||||
"private": "Частный",
|
"private": "Частный",
|
||||||
"sync": "Синхронизация",
|
"sync": "Синхронизация",
|
||||||
"syncing": "Синхронизация...",
|
"syncing": "Синхронизация...",
|
||||||
"syncConfirmation": "Вы уверены, что хотите синхронизировать \"{{sourceName}}\"? Это обновит содержимое с вашим облачным хранилищем и может перезаписать любые изменения, внесенные вами в отдельные фрагменты.",
|
|
||||||
"syncFrequency": {
|
"syncFrequency": {
|
||||||
"never": "Никогда",
|
"never": "Никогда",
|
||||||
"daily": "Ежедневно",
|
"daily": "Ежедневно",
|
||||||
|
|||||||
@@ -68,7 +68,6 @@
|
|||||||
"private": "私人",
|
"private": "私人",
|
||||||
"sync": "同步",
|
"sync": "同步",
|
||||||
"syncing": "同步中...",
|
"syncing": "同步中...",
|
||||||
"syncConfirmation": "您確定要同步 \"{{sourceName}}\" 嗎?這將使用您的雲端儲存更新內容,並可能覆蓋您對個別文本塊所做的任何編輯。",
|
|
||||||
"syncFrequency": {
|
"syncFrequency": {
|
||||||
"never": "從不",
|
"never": "從不",
|
||||||
"daily": "每天",
|
"daily": "每天",
|
||||||
|
|||||||
@@ -68,7 +68,6 @@
|
|||||||
"private": "私有",
|
"private": "私有",
|
||||||
"sync": "同步",
|
"sync": "同步",
|
||||||
"syncing": "同步中...",
|
"syncing": "同步中...",
|
||||||
"syncConfirmation": "您确定要同步 \"{{sourceName}}\" 吗?这将使用您的云存储更新内容,并可能覆盖您对单个文本块所做的任何编辑。",
|
|
||||||
"syncFrequency": {
|
"syncFrequency": {
|
||||||
"never": "从不",
|
"never": "从不",
|
||||||
"daily": "每天",
|
"daily": "每天",
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import { useOutsideAlerter } from '../hooks';
|
|||||||
import { ActiveState } from '../models/misc';
|
import { ActiveState } from '../models/misc';
|
||||||
import { selectToken } from '../preferences/preferenceSlice';
|
import { selectToken } from '../preferences/preferenceSlice';
|
||||||
import ConfigToolModal from './ConfigToolModal';
|
import ConfigToolModal from './ConfigToolModal';
|
||||||
import MCPServerModal from './MCPServerModal';
|
|
||||||
import { AvailableToolType } from './types';
|
import { AvailableToolType } from './types';
|
||||||
import WrapperComponent from './WrapperModal';
|
import WrapperComponent from './WrapperModal';
|
||||||
|
|
||||||
@@ -35,8 +34,6 @@ export default function AddToolModal({
|
|||||||
React.useState<AvailableToolType | null>(null);
|
React.useState<AvailableToolType | null>(null);
|
||||||
const [configModalState, setConfigModalState] =
|
const [configModalState, setConfigModalState] =
|
||||||
React.useState<ActiveState>('INACTIVE');
|
React.useState<ActiveState>('INACTIVE');
|
||||||
const [mcpModalState, setMcpModalState] =
|
|
||||||
React.useState<ActiveState>('INACTIVE');
|
|
||||||
const [loading, setLoading] = React.useState(false);
|
const [loading, setLoading] = React.useState(false);
|
||||||
|
|
||||||
useOutsideAlerter(modalRef, () => {
|
useOutsideAlerter(modalRef, () => {
|
||||||
@@ -89,9 +86,6 @@ export default function AddToolModal({
|
|||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
console.error('Failed to create tool:', error);
|
console.error('Failed to create tool:', error);
|
||||||
});
|
});
|
||||||
} else if (tool.name === 'mcp_tool') {
|
|
||||||
setModalState('INACTIVE');
|
|
||||||
setMcpModalState('ACTIVE');
|
|
||||||
} else {
|
} else {
|
||||||
setModalState('INACTIVE');
|
setModalState('INACTIVE');
|
||||||
setConfigModalState('ACTIVE');
|
setConfigModalState('ACTIVE');
|
||||||
@@ -101,12 +95,6 @@ export default function AddToolModal({
|
|||||||
React.useEffect(() => {
|
React.useEffect(() => {
|
||||||
if (modalState === 'ACTIVE') getAvailableTools();
|
if (modalState === 'ACTIVE') getAvailableTools();
|
||||||
}, [modalState]);
|
}, [modalState]);
|
||||||
|
|
||||||
const handleMcpServerAdded = () => {
|
|
||||||
getUserTools();
|
|
||||||
setMcpModalState('INACTIVE');
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
{modalState === 'ACTIVE' && (
|
{modalState === 'ACTIVE' && (
|
||||||
@@ -178,11 +166,6 @@ export default function AddToolModal({
|
|||||||
tool={selectedTool}
|
tool={selectedTool}
|
||||||
getUserTools={getUserTools}
|
getUserTools={getUserTools}
|
||||||
/>
|
/>
|
||||||
<MCPServerModal
|
|
||||||
modalState={mcpModalState}
|
|
||||||
setModalState={setMcpModalState}
|
|
||||||
onServerSaved={handleMcpServerAdded}
|
|
||||||
/>
|
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,609 +0,0 @@
|
|||||||
import { useRef, useState } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useSelector } from 'react-redux';
|
|
||||||
|
|
||||||
import userService from '../api/services/userService';
|
|
||||||
import Dropdown from '../components/Dropdown';
|
|
||||||
import Input from '../components/Input';
|
|
||||||
import Spinner from '../components/Spinner';
|
|
||||||
import { useOutsideAlerter } from '../hooks';
|
|
||||||
import { ActiveState } from '../models/misc';
|
|
||||||
import { selectToken } from '../preferences/preferenceSlice';
|
|
||||||
import WrapperComponent from './WrapperModal';
|
|
||||||
|
|
||||||
interface MCPServerModalProps {
|
|
||||||
modalState: ActiveState;
|
|
||||||
setModalState: (state: ActiveState) => void;
|
|
||||||
server?: any;
|
|
||||||
onServerSaved: () => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
const authTypes = [
|
|
||||||
{ label: 'No Authentication', value: 'none' },
|
|
||||||
{ label: 'API Key', value: 'api_key' },
|
|
||||||
{ label: 'Bearer Token', value: 'bearer' },
|
|
||||||
{ label: 'OAuth', value: 'oauth' },
|
|
||||||
// { label: 'Basic Authentication', value: 'basic' },
|
|
||||||
];
|
|
||||||
|
|
||||||
export default function MCPServerModal({
|
|
||||||
modalState,
|
|
||||||
setModalState,
|
|
||||||
server,
|
|
||||||
onServerSaved,
|
|
||||||
}: MCPServerModalProps) {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const token = useSelector(selectToken);
|
|
||||||
const modalRef = useRef<HTMLDivElement>(null);
|
|
||||||
|
|
||||||
const [formData, setFormData] = useState({
|
|
||||||
name: server?.displayName || 'My MCP Server',
|
|
||||||
server_url: server?.server_url || '',
|
|
||||||
auth_type: server?.auth_type || 'none',
|
|
||||||
api_key: '',
|
|
||||||
header_name: 'X-API-Key',
|
|
||||||
bearer_token: '',
|
|
||||||
username: '',
|
|
||||||
password: '',
|
|
||||||
timeout: server?.timeout || 30,
|
|
||||||
oauth_scopes: '',
|
|
||||||
oauth_task_id: '',
|
|
||||||
});
|
|
||||||
|
|
||||||
const [loading, setLoading] = useState(false);
|
|
||||||
const [testing, setTesting] = useState(false);
|
|
||||||
const [testResult, setTestResult] = useState<{
|
|
||||||
success: boolean;
|
|
||||||
message: string;
|
|
||||||
status?: string;
|
|
||||||
authorization_url?: string;
|
|
||||||
} | null>(null);
|
|
||||||
const [errors, setErrors] = useState<{ [key: string]: string }>({});
|
|
||||||
const oauthPopupRef = useRef<Window | null>(null);
|
|
||||||
const [oauthCompleted, setOAuthCompleted] = useState(false);
|
|
||||||
const [saveActive, setSaveActive] = useState(false);
|
|
||||||
|
|
||||||
useOutsideAlerter(modalRef, () => {
|
|
||||||
if (modalState === 'ACTIVE') {
|
|
||||||
setModalState('INACTIVE');
|
|
||||||
resetForm();
|
|
||||||
}
|
|
||||||
}, [modalState]);
|
|
||||||
|
|
||||||
const resetForm = () => {
|
|
||||||
setFormData({
|
|
||||||
name: 'My MCP Server',
|
|
||||||
server_url: '',
|
|
||||||
auth_type: 'none',
|
|
||||||
api_key: '',
|
|
||||||
header_name: 'X-API-Key',
|
|
||||||
bearer_token: '',
|
|
||||||
username: '',
|
|
||||||
password: '',
|
|
||||||
timeout: 30,
|
|
||||||
oauth_scopes: '',
|
|
||||||
oauth_task_id: '',
|
|
||||||
});
|
|
||||||
setErrors({});
|
|
||||||
setTestResult(null);
|
|
||||||
setSaveActive(false);
|
|
||||||
};
|
|
||||||
|
|
||||||
const validateForm = () => {
|
|
||||||
const requiredFields: { [key: string]: boolean } = {
|
|
||||||
name: !formData.name.trim(),
|
|
||||||
server_url: !formData.server_url.trim(),
|
|
||||||
};
|
|
||||||
|
|
||||||
const authFieldChecks: { [key: string]: () => void } = {
|
|
||||||
api_key: () => {
|
|
||||||
if (!formData.api_key.trim())
|
|
||||||
newErrors.api_key = t('settings.tools.mcp.errors.apiKeyRequired');
|
|
||||||
},
|
|
||||||
bearer: () => {
|
|
||||||
if (!formData.bearer_token.trim())
|
|
||||||
newErrors.bearer_token = t('settings.tools.mcp.errors.tokenRequired');
|
|
||||||
},
|
|
||||||
basic: () => {
|
|
||||||
if (!formData.username.trim())
|
|
||||||
newErrors.username = t('settings.tools.mcp.errors.usernameRequired');
|
|
||||||
if (!formData.password.trim())
|
|
||||||
newErrors.password = t('settings.tools.mcp.errors.passwordRequired');
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
const newErrors: { [key: string]: string } = {};
|
|
||||||
Object.entries(requiredFields).forEach(([field, isEmpty]) => {
|
|
||||||
if (isEmpty)
|
|
||||||
newErrors[field] = t(
|
|
||||||
`settings.tools.mcp.errors.${field === 'name' ? 'nameRequired' : 'urlRequired'}`,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
if (formData.server_url.trim()) {
|
|
||||||
try {
|
|
||||||
new URL(formData.server_url);
|
|
||||||
} catch {
|
|
||||||
newErrors.server_url = t('settings.tools.mcp.errors.invalidUrl');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const timeoutValue = formData.timeout === '' ? 30 : formData.timeout;
|
|
||||||
if (
|
|
||||||
typeof timeoutValue === 'number' &&
|
|
||||||
(timeoutValue < 1 || timeoutValue > 300)
|
|
||||||
)
|
|
||||||
newErrors.timeout = 'Timeout must be between 1 and 300 seconds';
|
|
||||||
|
|
||||||
if (authFieldChecks[formData.auth_type])
|
|
||||||
authFieldChecks[formData.auth_type]();
|
|
||||||
|
|
||||||
setErrors(newErrors);
|
|
||||||
return Object.keys(newErrors).length === 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleInputChange = (name: string, value: string | number) => {
|
|
||||||
setFormData((prev) => ({ ...prev, [name]: value }));
|
|
||||||
if (errors[name]) {
|
|
||||||
setErrors((prev) => ({ ...prev, [name]: '' }));
|
|
||||||
}
|
|
||||||
setTestResult(null);
|
|
||||||
};
|
|
||||||
|
|
||||||
const buildToolConfig = () => {
|
|
||||||
const config: any = {
|
|
||||||
server_url: formData.server_url.trim(),
|
|
||||||
auth_type: formData.auth_type,
|
|
||||||
timeout: formData.timeout === '' ? 30 : formData.timeout,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (formData.auth_type === 'api_key') {
|
|
||||||
config.api_key = formData.api_key.trim();
|
|
||||||
config.api_key_header = formData.header_name.trim() || 'X-API-Key';
|
|
||||||
} else if (formData.auth_type === 'bearer') {
|
|
||||||
config.bearer_token = formData.bearer_token.trim();
|
|
||||||
} else if (formData.auth_type === 'basic') {
|
|
||||||
config.username = formData.username.trim();
|
|
||||||
config.password = formData.password.trim();
|
|
||||||
} else if (formData.auth_type === 'oauth') {
|
|
||||||
config.oauth_scopes = formData.oauth_scopes
|
|
||||||
.split(',')
|
|
||||||
.map((s) => s.trim())
|
|
||||||
.filter(Boolean);
|
|
||||||
config.oauth_task_id = formData.oauth_task_id.trim();
|
|
||||||
}
|
|
||||||
return config;
|
|
||||||
};
|
|
||||||
|
|
||||||
const pollOAuthStatus = async (
|
|
||||||
taskId: string,
|
|
||||||
onComplete: (result: any) => void,
|
|
||||||
) => {
|
|
||||||
let attempts = 0;
|
|
||||||
const maxAttempts = 60;
|
|
||||||
let popupOpened = false;
|
|
||||||
const poll = async () => {
|
|
||||||
try {
|
|
||||||
const resp = await userService.getMCPOAuthStatus(taskId, token);
|
|
||||||
const data = await resp.json();
|
|
||||||
if (data.authorization_url && !popupOpened) {
|
|
||||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
|
||||||
oauthPopupRef.current.close();
|
|
||||||
}
|
|
||||||
oauthPopupRef.current = window.open(
|
|
||||||
data.authorization_url,
|
|
||||||
'oauthPopup',
|
|
||||||
'width=600,height=700',
|
|
||||||
);
|
|
||||||
popupOpened = true;
|
|
||||||
}
|
|
||||||
if (data.status === 'completed') {
|
|
||||||
setOAuthCompleted(true);
|
|
||||||
setSaveActive(true);
|
|
||||||
onComplete({
|
|
||||||
...data,
|
|
||||||
success: true,
|
|
||||||
message: t('settings.tools.mcp.oauthCompleted'),
|
|
||||||
});
|
|
||||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
|
||||||
oauthPopupRef.current.close();
|
|
||||||
}
|
|
||||||
} else if (data.status === 'error' || data.success === false) {
|
|
||||||
setSaveActive(false);
|
|
||||||
onComplete({
|
|
||||||
...data,
|
|
||||||
success: false,
|
|
||||||
message: t('settings.tools.mcp.errors.oauthFailed'),
|
|
||||||
});
|
|
||||||
if (oauthPopupRef.current && !oauthPopupRef.current.closed) {
|
|
||||||
oauthPopupRef.current.close();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (++attempts < maxAttempts) setTimeout(poll, 1000);
|
|
||||||
else {
|
|
||||||
setSaveActive(false);
|
|
||||||
onComplete({
|
|
||||||
success: false,
|
|
||||||
message: t('settings.tools.mcp.errors.oauthTimeout'),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch {
|
|
||||||
if (++attempts < maxAttempts) setTimeout(poll, 1000);
|
|
||||||
else
|
|
||||||
onComplete({
|
|
||||||
success: false,
|
|
||||||
message: t('settings.tools.mcp.errors.oauthTimeout'),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
poll();
|
|
||||||
};
|
|
||||||
|
|
||||||
const testConnection = async () => {
|
|
||||||
if (!validateForm()) return;
|
|
||||||
setTesting(true);
|
|
||||||
setTestResult(null);
|
|
||||||
try {
|
|
||||||
const config = buildToolConfig();
|
|
||||||
const response = await userService.testMCPConnection({ config }, token);
|
|
||||||
const result = await response.json();
|
|
||||||
|
|
||||||
if (
|
|
||||||
formData.auth_type === 'oauth' &&
|
|
||||||
result.requires_oauth &&
|
|
||||||
result.task_id
|
|
||||||
) {
|
|
||||||
setTestResult({
|
|
||||||
success: true,
|
|
||||||
message: t('settings.tools.mcp.oauthInProgress'),
|
|
||||||
});
|
|
||||||
setOAuthCompleted(false);
|
|
||||||
setSaveActive(false);
|
|
||||||
pollOAuthStatus(result.task_id, (finalResult) => {
|
|
||||||
setTestResult(finalResult);
|
|
||||||
setFormData((prev) => ({
|
|
||||||
...prev,
|
|
||||||
oauth_task_id: result.task_id || '',
|
|
||||||
}));
|
|
||||||
setTesting(false);
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
setTestResult(result);
|
|
||||||
setSaveActive(result.success === true);
|
|
||||||
setTesting(false);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
setTestResult({
|
|
||||||
success: false,
|
|
||||||
message: t('settings.tools.mcp.errors.testFailed'),
|
|
||||||
});
|
|
||||||
setOAuthCompleted(false);
|
|
||||||
setSaveActive(false);
|
|
||||||
setTesting(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleSave = async () => {
|
|
||||||
if (!validateForm()) return;
|
|
||||||
setLoading(true);
|
|
||||||
try {
|
|
||||||
const config = buildToolConfig();
|
|
||||||
const serverData = {
|
|
||||||
displayName: formData.name,
|
|
||||||
config,
|
|
||||||
status: true,
|
|
||||||
...(server?.id && { id: server.id }),
|
|
||||||
};
|
|
||||||
|
|
||||||
const response = await userService.saveMCPServer(serverData, token);
|
|
||||||
const result = await response.json();
|
|
||||||
|
|
||||||
if (response.ok && result.success) {
|
|
||||||
setTestResult({
|
|
||||||
success: true,
|
|
||||||
message: result.message,
|
|
||||||
});
|
|
||||||
onServerSaved();
|
|
||||||
setModalState('INACTIVE');
|
|
||||||
resetForm();
|
|
||||||
} else {
|
|
||||||
setErrors({
|
|
||||||
general: result.error || t('settings.tools.mcp.errors.saveFailed'),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Error saving MCP server:', error);
|
|
||||||
setErrors({ general: t('settings.tools.mcp.errors.saveFailed') });
|
|
||||||
} finally {
|
|
||||||
setLoading(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const renderAuthFields = () => {
|
|
||||||
switch (formData.auth_type) {
|
|
||||||
case 'api_key':
|
|
||||||
return (
|
|
||||||
<div className="mb-10">
|
|
||||||
<div className="mt-6">
|
|
||||||
<Input
|
|
||||||
name="api_key"
|
|
||||||
type="text"
|
|
||||||
className="rounded-md"
|
|
||||||
value={formData.api_key}
|
|
||||||
onChange={(e) => handleInputChange('api_key', e.target.value)}
|
|
||||||
placeholder={t('settings.tools.mcp.placeholders.apiKey')}
|
|
||||||
borderVariant="thin"
|
|
||||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
|
||||||
/>
|
|
||||||
{errors.api_key && (
|
|
||||||
<p className="mt-1 text-sm text-red-600">{errors.api_key}</p>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
<div className="mt-5">
|
|
||||||
<Input
|
|
||||||
name="header_name"
|
|
||||||
type="text"
|
|
||||||
className="rounded-md"
|
|
||||||
value={formData.header_name}
|
|
||||||
onChange={(e) =>
|
|
||||||
handleInputChange('header_name', e.target.value)
|
|
||||||
}
|
|
||||||
placeholder={t('settings.tools.mcp.headerName')}
|
|
||||||
borderVariant="thin"
|
|
||||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
case 'bearer':
|
|
||||||
return (
|
|
||||||
<div className="mb-10">
|
|
||||||
<Input
|
|
||||||
name="bearer_token"
|
|
||||||
type="text"
|
|
||||||
className="rounded-md"
|
|
||||||
value={formData.bearer_token}
|
|
||||||
onChange={(e) =>
|
|
||||||
handleInputChange('bearer_token', e.target.value)
|
|
||||||
}
|
|
||||||
placeholder={t('settings.tools.mcp.placeholders.bearerToken')}
|
|
||||||
borderVariant="thin"
|
|
||||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
|
||||||
/>
|
|
||||||
{errors.bearer_token && (
|
|
||||||
<p className="mt-1 text-sm text-red-600">{errors.bearer_token}</p>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
case 'basic':
|
|
||||||
return (
|
|
||||||
<div className="mb-10">
|
|
||||||
<div className="mt-6">
|
|
||||||
<Input
|
|
||||||
name="username"
|
|
||||||
type="text"
|
|
||||||
className="rounded-md"
|
|
||||||
value={formData.username}
|
|
||||||
onChange={(e) => handleInputChange('username', e.target.value)}
|
|
||||||
placeholder={t('settings.tools.mcp.username')}
|
|
||||||
borderVariant="thin"
|
|
||||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
|
||||||
/>
|
|
||||||
{errors.username && (
|
|
||||||
<p className="mt-1 text-sm text-red-600">{errors.username}</p>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
<div className="mt-5">
|
|
||||||
<Input
|
|
||||||
name="password"
|
|
||||||
type="text"
|
|
||||||
className="rounded-md"
|
|
||||||
value={formData.password}
|
|
||||||
onChange={(e) => handleInputChange('password', e.target.value)}
|
|
||||||
placeholder={t('settings.tools.mcp.password')}
|
|
||||||
borderVariant="thin"
|
|
||||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
|
||||||
/>
|
|
||||||
{errors.password && (
|
|
||||||
<p className="mt-1 text-sm text-red-600">{errors.password}</p>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
case 'oauth':
|
|
||||||
return (
|
|
||||||
<div className="mb-10">
|
|
||||||
<div className="mt-6">
|
|
||||||
<Input
|
|
||||||
name="oauth_scopes"
|
|
||||||
type="text"
|
|
||||||
className="rounded-md"
|
|
||||||
value={formData.oauth_scopes}
|
|
||||||
onChange={(e) =>
|
|
||||||
handleInputChange('oauth_scopes', e.target.value)
|
|
||||||
}
|
|
||||||
placeholder={
|
|
||||||
t('settings.tools.mcp.placeholders.oauthScopes') ||
|
|
||||||
'Scopes (comma separated)'
|
|
||||||
}
|
|
||||||
borderVariant="thin"
|
|
||||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
default:
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
modalState === 'ACTIVE' && (
|
|
||||||
<WrapperComponent
|
|
||||||
close={() => {
|
|
||||||
setModalState('INACTIVE');
|
|
||||||
resetForm();
|
|
||||||
}}
|
|
||||||
className="max-w-[600px] md:w-[80vw] lg:w-[60vw]"
|
|
||||||
>
|
|
||||||
<div className="flex h-full flex-col">
|
|
||||||
<div className="px-6 py-4">
|
|
||||||
<h2 className="text-jet dark:text-bright-gray text-xl font-semibold">
|
|
||||||
{server
|
|
||||||
? t('settings.tools.mcp.editServer')
|
|
||||||
: t('settings.tools.mcp.addServer')}
|
|
||||||
</h2>
|
|
||||||
</div>
|
|
||||||
<div className="flex-1 px-6">
|
|
||||||
<div className="space-y-6 py-6">
|
|
||||||
<div>
|
|
||||||
<Input
|
|
||||||
type="text"
|
|
||||||
className="rounded-md"
|
|
||||||
value={formData.name}
|
|
||||||
onChange={(e) => handleInputChange('name', e.target.value)}
|
|
||||||
borderVariant="thin"
|
|
||||||
placeholder={t('settings.tools.mcp.serverName')}
|
|
||||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
|
||||||
/>
|
|
||||||
{errors.name && (
|
|
||||||
<p className="mt-1 text-sm text-red-600">{errors.name}</p>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div>
|
|
||||||
<Input
|
|
||||||
name="server_url"
|
|
||||||
type="text"
|
|
||||||
className="rounded-md"
|
|
||||||
value={formData.server_url}
|
|
||||||
onChange={(e) =>
|
|
||||||
handleInputChange('server_url', e.target.value)
|
|
||||||
}
|
|
||||||
placeholder={t('settings.tools.mcp.serverUrl')}
|
|
||||||
borderVariant="thin"
|
|
||||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
|
||||||
/>
|
|
||||||
{errors.server_url && (
|
|
||||||
<p className="mt-1 text-sm text-red-600">
|
|
||||||
{errors.server_url}
|
|
||||||
</p>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<Dropdown
|
|
||||||
placeholder={t('settings.tools.mcp.authType')}
|
|
||||||
selectedValue={
|
|
||||||
authTypes.find((type) => type.value === formData.auth_type)
|
|
||||||
?.label || null
|
|
||||||
}
|
|
||||||
onSelect={(selection: { label: string; value: string }) => {
|
|
||||||
handleInputChange('auth_type', selection.value);
|
|
||||||
}}
|
|
||||||
options={authTypes}
|
|
||||||
size="w-full"
|
|
||||||
rounded="3xl"
|
|
||||||
border="border"
|
|
||||||
/>
|
|
||||||
|
|
||||||
{renderAuthFields()}
|
|
||||||
|
|
||||||
<div>
|
|
||||||
<Input
|
|
||||||
name="timeout"
|
|
||||||
type="number"
|
|
||||||
className="rounded-md"
|
|
||||||
value={formData.timeout}
|
|
||||||
onChange={(e) => {
|
|
||||||
const value = e.target.value;
|
|
||||||
if (value === '') {
|
|
||||||
handleInputChange('timeout', '');
|
|
||||||
} else {
|
|
||||||
const numValue = parseInt(value);
|
|
||||||
if (!isNaN(numValue) && numValue >= 1) {
|
|
||||||
handleInputChange('timeout', numValue);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
placeholder={t('settings.tools.mcp.timeout')}
|
|
||||||
borderVariant="thin"
|
|
||||||
labelBgClassName="bg-white dark:bg-charleston-green-2"
|
|
||||||
/>
|
|
||||||
{errors.timeout && (
|
|
||||||
<p className="mt-2 text-sm text-red-600">{errors.timeout}</p>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{testResult && (
|
|
||||||
<div
|
|
||||||
className={`rounded-2xl p-5 ${
|
|
||||||
testResult.success
|
|
||||||
? 'bg-green-50 text-green-700 dark:bg-green-900/40 dark:text-green-300'
|
|
||||||
: 'bg-red-50 text-red-700 dark:bg-red-900 dark:text-red-300'
|
|
||||||
}`}
|
|
||||||
>
|
|
||||||
{testResult.message}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
{errors.general && (
|
|
||||||
<div className="rounded-2xl bg-red-50 p-5 text-red-700 dark:bg-red-900 dark:text-red-300">
|
|
||||||
{errors.general}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="px-6 py-2">
|
|
||||||
<div className="flex flex-col gap-4 sm:flex-row sm:justify-between">
|
|
||||||
<button
|
|
||||||
onClick={testConnection}
|
|
||||||
disabled={testing}
|
|
||||||
className="border-silver dark:border-dim-gray dark:text-light-gray w-full rounded-3xl border px-6 py-2 text-sm font-medium transition-all hover:bg-gray-100 disabled:opacity-50 sm:w-auto dark:hover:bg-[#767183]/50"
|
|
||||||
>
|
|
||||||
{testing ? (
|
|
||||||
<div className="flex items-center justify-center">
|
|
||||||
<Spinner size="small" />
|
|
||||||
<span className="ml-2">
|
|
||||||
{t('settings.tools.mcp.testing')}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
t('settings.tools.mcp.testConnection')
|
|
||||||
)}
|
|
||||||
</button>
|
|
||||||
|
|
||||||
<div className="flex flex-col-reverse gap-3 sm:flex-row sm:gap-3">
|
|
||||||
<button
|
|
||||||
onClick={() => {
|
|
||||||
setModalState('INACTIVE');
|
|
||||||
resetForm();
|
|
||||||
}}
|
|
||||||
className="dark:text-light-gray w-full cursor-pointer rounded-3xl px-6 py-2 text-sm font-medium hover:bg-gray-100 sm:w-auto dark:bg-transparent dark:hover:bg-[#767183]/50"
|
|
||||||
>
|
|
||||||
{t('settings.tools.mcp.cancel')}
|
|
||||||
</button>
|
|
||||||
<button
|
|
||||||
onClick={handleSave}
|
|
||||||
disabled={loading || !saveActive}
|
|
||||||
className="bg-purple-30 hover:bg-violets-are-blue w-full rounded-3xl px-6 py-2 text-sm font-medium text-white transition-all disabled:opacity-50 sm:w-auto"
|
|
||||||
>
|
|
||||||
{loading ? (
|
|
||||||
<div className="flex items-center justify-center">
|
|
||||||
<Spinner size="small" />
|
|
||||||
<span className="ml-2">
|
|
||||||
{t('settings.tools.mcp.saving')}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
t('settings.tools.mcp.save')
|
|
||||||
)}
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</WrapperComponent>
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -60,7 +60,7 @@ export const ShareConversationModal = ({
|
|||||||
const [sourcePath, setSourcePath] = useState<{
|
const [sourcePath, setSourcePath] = useState<{
|
||||||
label: string;
|
label: string;
|
||||||
value: string;
|
value: string;
|
||||||
} | null>(preSelectedDoc ? extractDocPaths(preSelectedDoc)[0] : null);
|
} | null>(preSelectedDoc ? extractDocPaths([preSelectedDoc])[0] : null);
|
||||||
|
|
||||||
const handleCopyKey = (url: string) => {
|
const handleCopyKey = (url: string) => {
|
||||||
navigator.clipboard.writeText(url);
|
navigator.clipboard.writeText(url);
|
||||||
@@ -105,14 +105,14 @@ export const ShareConversationModal = ({
|
|||||||
return (
|
return (
|
||||||
<WrapperModal close={close}>
|
<WrapperModal close={close}>
|
||||||
<div className="flex flex-col gap-2">
|
<div className="flex flex-col gap-2">
|
||||||
<h2 className="text-eerie-black dark:text-chinese-white text-xl font-medium">
|
<h2 className="text-xl font-medium text-eerie-black dark:text-chinese-white">
|
||||||
{t('modals.shareConv.label')}
|
{t('modals.shareConv.label')}
|
||||||
</h2>
|
</h2>
|
||||||
<p className="text-eerie-black dark:text-silver/60 text-sm">
|
<p className="text-sm text-eerie-black dark:text-silver/60">
|
||||||
{t('modals.shareConv.note')}
|
{t('modals.shareConv.note')}
|
||||||
</p>
|
</p>
|
||||||
<div className="flex items-center justify-between">
|
<div className="flex items-center justify-between">
|
||||||
<span className="text-eerie-black text-lg dark:text-white">
|
<span className="text-lg text-eerie-black dark:text-white">
|
||||||
{t('modals.shareConv.option')}
|
{t('modals.shareConv.option')}
|
||||||
</span>
|
</span>
|
||||||
<ToggleSwitch
|
<ToggleSwitch
|
||||||
@@ -136,19 +136,19 @@ export const ShareConversationModal = ({
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
<div className="flex items-baseline justify-between gap-2">
|
<div className="flex items-baseline justify-between gap-2">
|
||||||
<span className="no-scrollbar border-silver text-eerie-black dark:border-silver/40 w-full overflow-x-auto rounded-full border-2 px-4 py-3 whitespace-nowrap dark:text-white">
|
<span className="no-scrollbar w-full overflow-x-auto whitespace-nowrap rounded-full border-2 border-silver px-4 py-3 text-eerie-black dark:border-silver/40 dark:text-white">
|
||||||
{`${domain}/share/${identifier ?? '....'}`}
|
{`${domain}/share/${identifier ?? '....'}`}
|
||||||
</span>
|
</span>
|
||||||
{status === 'fetched' ? (
|
{status === 'fetched' ? (
|
||||||
<button
|
<button
|
||||||
className="bg-purple-30 hover:bg-violets-are-blue my-1 h-10 w-28 rounded-full p-2 text-sm text-white"
|
className="my-1 h-10 w-28 rounded-full bg-purple-30 p-2 text-sm text-white hover:bg-violets-are-blue"
|
||||||
onClick={() => handleCopyKey(`${domain}/share/${identifier}`)}
|
onClick={() => handleCopyKey(`${domain}/share/${identifier}`)}
|
||||||
>
|
>
|
||||||
{isCopied ? t('modals.saveKey.copied') : t('modals.saveKey.copy')}
|
{isCopied ? t('modals.saveKey.copied') : t('modals.saveKey.copy')}
|
||||||
</button>
|
</button>
|
||||||
) : (
|
) : (
|
||||||
<button
|
<button
|
||||||
className="bg-purple-30 hover:bg-violets-are-blue my-1 flex h-10 w-28 items-center justify-evenly rounded-full p-2 text-center text-sm font-normal text-white"
|
className="my-1 flex h-10 w-28 items-center justify-evenly rounded-full bg-purple-30 p-2 text-center text-sm font-normal text-white hover:bg-violets-are-blue"
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
shareCoversationPublicly(allowPrompt);
|
shareCoversationPublicly(allowPrompt);
|
||||||
}}
|
}}
|
||||||
|
|||||||
@@ -42,10 +42,10 @@ export default function WrapperModal({
|
|||||||
}, [close, isPerformingTask]);
|
}, [close, isPerformingTask]);
|
||||||
|
|
||||||
const modalContent = (
|
const modalContent = (
|
||||||
<div className="fixed top-0 left-0 z-30 flex h-screen w-screen items-center justify-center">
|
<div className="bg-gray-alpha bg-opacity-50 fixed top-0 left-0 z-30 flex h-screen w-screen items-center justify-center">
|
||||||
<div
|
<div
|
||||||
ref={modalRef}
|
ref={modalRef}
|
||||||
className={`relative rounded-2xl bg-white dark:bg-[#26272E] p-8 shadow-[0px_4px_40px_-3px_#0000001A] ${className}`}
|
className={`relative w-11/12 rounded-2xl bg-white p-8 sm:w-[512px] dark:bg-[#26272E] ${className}`}
|
||||||
>
|
>
|
||||||
{!isPerformingTask && (
|
{!isPerformingTask && (
|
||||||
<button
|
<button
|
||||||
@@ -55,7 +55,7 @@ export default function WrapperModal({
|
|||||||
<img className="filter dark:invert" src={Exit} alt="Close" />
|
<img className="filter dark:invert" src={Exit} alt="Close" />
|
||||||
</button>
|
</button>
|
||||||
)}
|
)}
|
||||||
<div className={`overflow-y-auto no-scrollbar text-[#18181B] dark:text-[#ECECF1] ${contentClassName}`}>{children}</div>
|
<div className={`${contentClassName}`}>{children}</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -90,9 +90,9 @@ export function getLocalApiKey(): string | null {
|
|||||||
return key;
|
return key;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getLocalRecentDocs(): Doc[] | null {
|
export function getLocalRecentDocs(): string | null {
|
||||||
const docs = localStorage.getItem('DocsGPTRecentDocs');
|
const doc = localStorage.getItem('DocsGPTRecentDocs');
|
||||||
return docs ? (JSON.parse(docs) as Doc[]) : null;
|
return doc;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getLocalPrompt(): string | null {
|
export function getLocalPrompt(): string | null {
|
||||||
@@ -108,20 +108,19 @@ export function setLocalPrompt(prompt: string): void {
|
|||||||
localStorage.setItem('DocsGPTPrompt', prompt);
|
localStorage.setItem('DocsGPTPrompt', prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function setLocalRecentDocs(docs: Doc[] | null): void {
|
export function setLocalRecentDocs(doc: Doc | null): void {
|
||||||
if (docs && docs.length > 0) {
|
localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(doc));
|
||||||
localStorage.setItem('DocsGPTRecentDocs', JSON.stringify(docs));
|
|
||||||
|
|
||||||
docs.forEach((doc) => {
|
let docPath = 'default';
|
||||||
let docPath = 'default';
|
if (doc?.type === 'local') {
|
||||||
if (doc.type === 'local') {
|
docPath = 'local' + '/' + doc.name + '/';
|
||||||
docPath = 'local' + '/' + doc.name + '/';
|
|
||||||
}
|
|
||||||
userService
|
|
||||||
.checkDocs({ docs: docPath }, null)
|
|
||||||
.then((response) => response.json());
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
localStorage.removeItem('DocsGPTRecentDocs');
|
|
||||||
}
|
}
|
||||||
|
userService
|
||||||
|
.checkDocs(
|
||||||
|
{
|
||||||
|
docs: docPath,
|
||||||
|
},
|
||||||
|
null,
|
||||||
|
)
|
||||||
|
.then((response) => response.json());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ export interface Preference {
|
|||||||
prompt: { name: string; id: string; type: string };
|
prompt: { name: string; id: string; type: string };
|
||||||
chunks: string;
|
chunks: string;
|
||||||
token_limit: number;
|
token_limit: number;
|
||||||
selectedDocs: Doc[] | null;
|
selectedDocs: Doc | null;
|
||||||
sourceDocs: Doc[] | null;
|
sourceDocs: Doc[] | null;
|
||||||
conversations: {
|
conversations: {
|
||||||
data: { name: string; id: string }[] | null;
|
data: { name: string; id: string }[] | null;
|
||||||
@@ -34,16 +34,15 @@ const initialState: Preference = {
|
|||||||
prompt: { name: 'default', id: 'default', type: 'public' },
|
prompt: { name: 'default', id: 'default', type: 'public' },
|
||||||
chunks: '2',
|
chunks: '2',
|
||||||
token_limit: 2000,
|
token_limit: 2000,
|
||||||
selectedDocs: [
|
selectedDocs: {
|
||||||
{
|
id: 'default',
|
||||||
id: 'default',
|
name: 'default',
|
||||||
name: 'default',
|
type: 'remote',
|
||||||
type: 'remote',
|
date: 'default',
|
||||||
date: 'default',
|
docLink: 'default',
|
||||||
model: 'openai_text-embedding-ada-002',
|
model: 'openai_text-embedding-ada-002',
|
||||||
retriever: 'classic',
|
retriever: 'classic',
|
||||||
},
|
} as Doc,
|
||||||
] as Doc[],
|
|
||||||
sourceDocs: null,
|
sourceDocs: null,
|
||||||
conversations: {
|
conversations: {
|
||||||
data: null,
|
data: null,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
|
||||||
import React, { useCallback, useEffect, useRef, useState } from 'react';
|
import React, { useCallback, useEffect, useRef, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useDispatch, useSelector } from 'react-redux';
|
import { useDispatch, useSelector } from 'react-redux';
|
||||||
@@ -271,27 +272,27 @@ export default function Sources({
|
|||||||
|
|
||||||
return documentToView ? (
|
return documentToView ? (
|
||||||
<div className="mt-8 flex flex-col">
|
<div className="mt-8 flex flex-col">
|
||||||
{documentToView.isNested ? (
|
{documentToView.isNested ? (
|
||||||
documentToView.type === 'connector:file' ? (
|
documentToView.type === 'connector' ? (
|
||||||
<ConnectorTreeComponent
|
<ConnectorTreeComponent
|
||||||
docId={documentToView.id || ''}
|
docId={documentToView.id || ''}
|
||||||
sourceName={documentToView.name}
|
sourceName={documentToView.name}
|
||||||
onBackToDocuments={() => setDocumentToView(undefined)}
|
onBackToDocuments={() => setDocumentToView(undefined)}
|
||||||
/>
|
|
||||||
) : (
|
|
||||||
<FileTreeComponent
|
|
||||||
docId={documentToView.id || ''}
|
|
||||||
sourceName={documentToView.name}
|
|
||||||
onBackToDocuments={() => setDocumentToView(undefined)}
|
|
||||||
/>
|
|
||||||
)
|
|
||||||
) : (
|
|
||||||
<Chunks
|
|
||||||
documentId={documentToView.id || ''}
|
|
||||||
documentName={documentToView.name}
|
|
||||||
handleGoBack={() => setDocumentToView(undefined)}
|
|
||||||
/>
|
/>
|
||||||
)}
|
) : (
|
||||||
|
<FileTreeComponent
|
||||||
|
docId={documentToView.id || ''}
|
||||||
|
sourceName={documentToView.name}
|
||||||
|
onBackToDocuments={() => setDocumentToView(undefined)}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
) : (
|
||||||
|
<Chunks
|
||||||
|
documentId={documentToView.id || ''}
|
||||||
|
documentName={documentToView.name}
|
||||||
|
handleGoBack={() => setDocumentToView(undefined)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<div className="mt-8 flex w-full max-w-full flex-col overflow-hidden">
|
<div className="mt-8 flex w-full max-w-full flex-col overflow-hidden">
|
||||||
@@ -318,7 +319,7 @@ export default function Sources({
|
|||||||
setSearchTerm(e.target.value);
|
setSearchTerm(e.target.value);
|
||||||
setCurrentPage(1);
|
setCurrentPage(1);
|
||||||
}}
|
}}
|
||||||
className="border-silver dark:border-silver/40 text-jet dark:text-bright-gray focus:border-silver dark:focus:border-silver/60 h-[32px] w-full rounded-full border bg-transparent px-3 text-sm outline-none placeholder:text-gray-400 dark:placeholder:text-gray-500"
|
className="w-full h-[32px] rounded-full border border-silver dark:border-silver/40 bg-transparent px-3 text-sm text-jet dark:text-bright-gray placeholder:text-gray-400 dark:placeholder:text-gray-500 outline-none focus:border-silver dark:focus:border-silver/60"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -335,7 +336,7 @@ export default function Sources({
|
|||||||
</div>
|
</div>
|
||||||
<div className="relative w-full">
|
<div className="relative w-full">
|
||||||
{loading ? (
|
{loading ? (
|
||||||
<div className="grid w-full grid-cols-1 gap-6 px-2 py-4 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4">
|
<div className="w-full grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6 px-2 py-4">
|
||||||
<SkeletonLoader component="sourceCards" count={rowsPerPage} />
|
<SkeletonLoader component="sourceCards" count={rowsPerPage} />
|
||||||
</div>
|
</div>
|
||||||
) : !currentDocuments?.length ? (
|
) : !currentDocuments?.length ? (
|
||||||
@@ -350,19 +351,19 @@ export default function Sources({
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<div className="grid w-full grid-cols-1 gap-6 px-2 py-4 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4">
|
<div className="w-full grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-6 px-2 py-4">
|
||||||
{currentDocuments.map((document, index) => {
|
{currentDocuments.map((document, index) => {
|
||||||
const docId = document.id ? document.id.toString() : '';
|
const docId = document.id ? document.id.toString() : '';
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div key={docId} className="relative">
|
<div key={docId} className="relative">
|
||||||
<div
|
<div
|
||||||
className={`flex h-[130px] w-full flex-col rounded-2xl bg-[#F9F9F9] p-3 transition-all duration-200 dark:bg-[#383838] ${
|
className={`flex h-[130px] w-full flex-col rounded-2xl bg-[#F9F9F9] p-3 transition-all duration-200 dark:bg-[#383838] ${
|
||||||
activeMenuId === docId || syncMenuState.docId === docId
|
activeMenuId === docId || syncMenuState.docId === docId
|
||||||
? 'scale-[1.05]'
|
? 'scale-[1.05]'
|
||||||
: 'hover:scale-[1.05]'
|
: 'hover:scale-[1.05]'
|
||||||
}`}
|
}`}
|
||||||
>
|
>
|
||||||
<div className="w-full flex-1">
|
<div className="w-full flex-1">
|
||||||
<div className="flex w-full items-center justify-between gap-2">
|
<div className="flex w-full items-center justify-between gap-2">
|
||||||
<h3
|
<h3
|
||||||
@@ -426,7 +427,7 @@ export default function Sources({
|
|||||||
<img
|
<img
|
||||||
src={CalendarIcon}
|
src={CalendarIcon}
|
||||||
alt=""
|
alt=""
|
||||||
className="h-[14px] w-[14px]"
|
className="w-[14px] h-[14px]"
|
||||||
/>
|
/>
|
||||||
<span className="font-inter text-[12px] leading-[18px] font-[500] text-[#848484] dark:text-[#848484]">
|
<span className="font-inter text-[12px] leading-[18px] font-[500] text-[#848484] dark:text-[#848484]">
|
||||||
{document.date ? formatDate(document.date) : ''}
|
{document.date ? formatDate(document.date) : ''}
|
||||||
@@ -436,7 +437,7 @@ export default function Sources({
|
|||||||
<img
|
<img
|
||||||
src={DiscIcon}
|
src={DiscIcon}
|
||||||
alt=""
|
alt=""
|
||||||
className="h-[14px] w-[14px]"
|
className="w-[14px] h-[14px]"
|
||||||
/>
|
/>
|
||||||
<span className="font-inter text-[12px] leading-[18px] font-[500] text-[#848484] dark:text-[#848484]">
|
<span className="font-inter text-[12px] leading-[18px] font-[500] text-[#848484] dark:text-[#848484]">
|
||||||
{document.tokens
|
{document.tokens
|
||||||
|
|||||||
@@ -30,22 +30,9 @@ export default function ToolConfig({
|
|||||||
handleGoBack: () => void;
|
handleGoBack: () => void;
|
||||||
}) {
|
}) {
|
||||||
const token = useSelector(selectToken);
|
const token = useSelector(selectToken);
|
||||||
const [authKey, setAuthKey] = React.useState<string>(() => {
|
const [authKey, setAuthKey] = React.useState<string>(
|
||||||
if (tool.name === 'mcp_tool') {
|
'token' in tool.config ? tool.config.token : '',
|
||||||
const config = tool.config as any;
|
);
|
||||||
if (config.auth_type === 'api_key') {
|
|
||||||
return config.api_key || '';
|
|
||||||
} else if (config.auth_type === 'bearer') {
|
|
||||||
return config.encrypted_token || '';
|
|
||||||
} else if (config.auth_type === 'basic') {
|
|
||||||
return config.password || '';
|
|
||||||
}
|
|
||||||
return '';
|
|
||||||
} else if ('token' in tool.config) {
|
|
||||||
return tool.config.token;
|
|
||||||
}
|
|
||||||
return '';
|
|
||||||
});
|
|
||||||
const [customName, setCustomName] = React.useState<string>(
|
const [customName, setCustomName] = React.useState<string>(
|
||||||
tool.customName || '',
|
tool.customName || '',
|
||||||
);
|
);
|
||||||
@@ -110,26 +97,6 @@ export default function ToolConfig({
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleSaveChanges = () => {
|
const handleSaveChanges = () => {
|
||||||
let configToSave;
|
|
||||||
if (tool.name === 'api_tool') {
|
|
||||||
configToSave = tool.config;
|
|
||||||
} else if (tool.name === 'mcp_tool') {
|
|
||||||
configToSave = { ...tool.config } as any;
|
|
||||||
const mcpConfig = tool.config as any;
|
|
||||||
|
|
||||||
if (authKey.trim()) {
|
|
||||||
if (mcpConfig.auth_type === 'api_key') {
|
|
||||||
configToSave.api_key = authKey;
|
|
||||||
} else if (mcpConfig.auth_type === 'bearer') {
|
|
||||||
configToSave.encrypted_token = authKey;
|
|
||||||
} else if (mcpConfig.auth_type === 'basic') {
|
|
||||||
configToSave.password = authKey;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
configToSave = { token: authKey };
|
|
||||||
}
|
|
||||||
|
|
||||||
userService
|
userService
|
||||||
.updateTool(
|
.updateTool(
|
||||||
{
|
{
|
||||||
@@ -138,7 +105,7 @@ export default function ToolConfig({
|
|||||||
displayName: tool.displayName,
|
displayName: tool.displayName,
|
||||||
customName: customName,
|
customName: customName,
|
||||||
description: tool.description,
|
description: tool.description,
|
||||||
config: configToSave,
|
config: tool.name === 'api_tool' ? tool.config : { token: authKey },
|
||||||
actions: 'actions' in tool ? tool.actions : [],
|
actions: 'actions' in tool ? tool.actions : [],
|
||||||
status: tool.status,
|
status: tool.status,
|
||||||
},
|
},
|
||||||
@@ -229,15 +196,7 @@ export default function ToolConfig({
|
|||||||
<div className="mt-1">
|
<div className="mt-1">
|
||||||
{Object.keys(tool?.config).length !== 0 && tool.name !== 'api_tool' && (
|
{Object.keys(tool?.config).length !== 0 && tool.name !== 'api_tool' && (
|
||||||
<p className="text-eerie-black dark:text-bright-gray text-sm font-semibold">
|
<p className="text-eerie-black dark:text-bright-gray text-sm font-semibold">
|
||||||
{tool.name === 'mcp_tool'
|
{t('settings.tools.authentication')}
|
||||||
? (tool.config as any)?.auth_type === 'bearer'
|
|
||||||
? 'Bearer Token'
|
|
||||||
: (tool.config as any)?.auth_type === 'api_key'
|
|
||||||
? 'API Key'
|
|
||||||
: (tool.config as any)?.auth_type === 'basic'
|
|
||||||
? 'Password'
|
|
||||||
: t('settings.tools.authentication')
|
|
||||||
: t('settings.tools.authentication')}
|
|
||||||
</p>
|
</p>
|
||||||
)}
|
)}
|
||||||
<div className="mt-4 flex flex-col items-start gap-2 sm:flex-row sm:items-center">
|
<div className="mt-4 flex flex-col items-start gap-2 sm:flex-row sm:items-center">
|
||||||
@@ -249,17 +208,7 @@ export default function ToolConfig({
|
|||||||
value={authKey}
|
value={authKey}
|
||||||
onChange={(e) => setAuthKey(e.target.value)}
|
onChange={(e) => setAuthKey(e.target.value)}
|
||||||
borderVariant="thin"
|
borderVariant="thin"
|
||||||
placeholder={
|
placeholder={t('modals.configTool.apiKeyPlaceholder')}
|
||||||
tool.name === 'mcp_tool'
|
|
||||||
? (tool.config as any)?.auth_type === 'bearer'
|
|
||||||
? 'Bearer Token'
|
|
||||||
: (tool.config as any)?.auth_type === 'api_key'
|
|
||||||
? 'API Key'
|
|
||||||
: (tool.config as any)?.auth_type === 'basic'
|
|
||||||
? 'Password'
|
|
||||||
: t('modals.configTool.apiKeyPlaceholder')
|
|
||||||
: t('modals.configTool.apiKeyPlaceholder')
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
@@ -501,26 +450,6 @@ export default function ToolConfig({
|
|||||||
setModalState={(state) => setShowUnsavedModal(state === 'ACTIVE')}
|
setModalState={(state) => setShowUnsavedModal(state === 'ACTIVE')}
|
||||||
submitLabel={t('settings.tools.saveAndLeave')}
|
submitLabel={t('settings.tools.saveAndLeave')}
|
||||||
handleSubmit={() => {
|
handleSubmit={() => {
|
||||||
let configToSave;
|
|
||||||
if (tool.name === 'api_tool') {
|
|
||||||
configToSave = tool.config;
|
|
||||||
} else if (tool.name === 'mcp_tool') {
|
|
||||||
configToSave = { ...tool.config } as any;
|
|
||||||
const mcpConfig = tool.config as any;
|
|
||||||
|
|
||||||
if (authKey.trim()) {
|
|
||||||
if (mcpConfig.auth_type === 'api_key') {
|
|
||||||
configToSave.api_key = authKey;
|
|
||||||
} else if (mcpConfig.auth_type === 'bearer') {
|
|
||||||
configToSave.encrypted_token = authKey;
|
|
||||||
} else if (mcpConfig.auth_type === 'basic') {
|
|
||||||
configToSave.password = authKey;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
configToSave = { token: authKey };
|
|
||||||
}
|
|
||||||
|
|
||||||
userService
|
userService
|
||||||
.updateTool(
|
.updateTool(
|
||||||
{
|
{
|
||||||
@@ -529,7 +458,10 @@ export default function ToolConfig({
|
|||||||
displayName: tool.displayName,
|
displayName: tool.displayName,
|
||||||
customName: customName,
|
customName: customName,
|
||||||
description: tool.description,
|
description: tool.description,
|
||||||
config: configToSave,
|
config:
|
||||||
|
tool.name === 'api_tool'
|
||||||
|
? tool.config
|
||||||
|
: { token: authKey },
|
||||||
actions: 'actions' in tool ? tool.actions : [],
|
actions: 'actions' in tool ? tool.actions : [],
|
||||||
status: tool.status,
|
status: tool.status,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,16 +1,45 @@
|
|||||||
import CrawlerIcon from '../../assets/crawler.svg';
|
export interface BaseIngestorConfig {
|
||||||
import FileUploadIcon from '../../assets/file_upload.svg';
|
[key: string]: string | number | boolean | undefined;
|
||||||
import UrlIcon from '../../assets/url.svg';
|
}
|
||||||
import GithubIcon from '../../assets/github.svg';
|
|
||||||
import RedditIcon from '../../assets/reddit.svg';
|
|
||||||
import DriveIcon from '../../assets/drive.svg';
|
|
||||||
|
|
||||||
export type IngestorType = 'crawler' | 'github' | 'reddit' | 'url' | 'google_drive' | 'local_file';
|
export interface RedditIngestorConfig extends BaseIngestorConfig {
|
||||||
|
client_id: string;
|
||||||
|
client_secret: string;
|
||||||
|
user_agent: string;
|
||||||
|
search_queries: string;
|
||||||
|
number_posts: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GithubIngestorConfig extends BaseIngestorConfig {
|
||||||
|
repo_url: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface CrawlerIngestorConfig extends BaseIngestorConfig {
|
||||||
|
url: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface UrlIngestorConfig extends BaseIngestorConfig {
|
||||||
|
url: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GoogleDriveIngestorConfig extends BaseIngestorConfig {
|
||||||
|
folder_id?: string;
|
||||||
|
file_ids?: string;
|
||||||
|
recursive?: boolean;
|
||||||
|
token_info?: any;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type IngestorType = 'crawler' | 'github' | 'reddit' | 'url' | 'google_drive';
|
||||||
|
|
||||||
export interface IngestorConfig {
|
export interface IngestorConfig {
|
||||||
type: IngestorType | null;
|
type: IngestorType;
|
||||||
name: string;
|
name: string;
|
||||||
config: Record<string, string | number | boolean | File[]>;
|
config:
|
||||||
|
| RedditIngestorConfig
|
||||||
|
| GithubIngestorConfig
|
||||||
|
| CrawlerIngestorConfig
|
||||||
|
| UrlIngestorConfig
|
||||||
|
| GoogleDriveIngestorConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type IngestorFormData = {
|
export type IngestorFormData = {
|
||||||
@@ -20,7 +49,7 @@ export type IngestorFormData = {
|
|||||||
data: string;
|
data: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type FieldType = 'string' | 'number' | 'enum' | 'boolean' | 'local_file_picker' | 'remote_file_picker' | 'google_drive_picker';
|
export type FieldType = 'string' | 'number' | 'enum' | 'boolean';
|
||||||
|
|
||||||
export interface FormField {
|
export interface FormField {
|
||||||
name: string;
|
name: string;
|
||||||
@@ -31,82 +60,89 @@ export interface FormField {
|
|||||||
options?: { label: string; value: string }[];
|
options?: { label: string; value: string }[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface IngestorSchema {
|
export const IngestorFormSchemas: Record<IngestorType, FormField[]> = {
|
||||||
key: IngestorType;
|
crawler: [
|
||||||
label: string;
|
{
|
||||||
icon: string;
|
name: 'url',
|
||||||
heading: string;
|
label: 'URL',
|
||||||
validate?: () => boolean;
|
type: 'string',
|
||||||
fields: FormField[];
|
required: true,
|
||||||
}
|
|
||||||
|
|
||||||
export const IngestorFormSchemas: IngestorSchema[] = [
|
|
||||||
{
|
|
||||||
key: 'local_file',
|
|
||||||
label: 'Upload File',
|
|
||||||
icon: FileUploadIcon,
|
|
||||||
heading: 'Upload new document',
|
|
||||||
fields: [
|
|
||||||
{ name: 'files', label: 'Select files', type: 'local_file_picker', required: true },
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
key: 'crawler',
|
|
||||||
label: 'Crawler',
|
|
||||||
icon: CrawlerIcon,
|
|
||||||
heading: 'Add content with Web Crawler',
|
|
||||||
fields: [{ name: 'url', label: 'URL', type: 'string', required: true }]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
key: 'url',
|
|
||||||
label: 'Link',
|
|
||||||
icon: UrlIcon,
|
|
||||||
heading: 'Add content from URL',
|
|
||||||
fields: [{ name: 'url', label: 'URL', type: 'string', required: true }]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
key: 'github',
|
|
||||||
label: 'GitHub',
|
|
||||||
icon: GithubIcon,
|
|
||||||
heading: 'Add content from GitHub',
|
|
||||||
fields: [{ name: 'repo_url', label: 'Repository URL', type: 'string', required: true }]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
key: 'reddit',
|
|
||||||
label: 'Reddit',
|
|
||||||
icon: RedditIcon,
|
|
||||||
heading: 'Add content from Reddit',
|
|
||||||
fields: [
|
|
||||||
{ name: 'client_id', label: 'Client ID', type: 'string', required: true },
|
|
||||||
{ name: 'client_secret', label: 'Client Secret', type: 'string', required: true },
|
|
||||||
{ name: 'user_agent', label: 'User Agent', type: 'string', required: true },
|
|
||||||
{ name: 'search_queries', label: 'Search Queries', type: 'string', required: true },
|
|
||||||
{ name: 'number_posts', label: 'Number of Posts', type: 'number', required: true },
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
key: 'google_drive',
|
|
||||||
label: 'Google Drive',
|
|
||||||
icon: DriveIcon,
|
|
||||||
heading: 'Upload from Google Drive',
|
|
||||||
validate: () => {
|
|
||||||
const googleClientId = import.meta.env.VITE_GOOGLE_CLIENT_ID;
|
|
||||||
return !!(googleClientId);
|
|
||||||
},
|
},
|
||||||
fields: [
|
],
|
||||||
{
|
url: [
|
||||||
name: 'files',
|
{
|
||||||
label: 'Select Files from Google Drive',
|
name: 'url',
|
||||||
type: 'google_drive_picker',
|
label: 'URL',
|
||||||
required: true,
|
type: 'string',
|
||||||
}
|
required: true,
|
||||||
]
|
},
|
||||||
},
|
],
|
||||||
];
|
reddit: [
|
||||||
|
{
|
||||||
|
name: 'client_id',
|
||||||
|
label: 'Client ID',
|
||||||
|
type: 'string',
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'client_secret',
|
||||||
|
label: 'Client Secret',
|
||||||
|
type: 'string',
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'user_agent',
|
||||||
|
label: 'User Agent',
|
||||||
|
type: 'string',
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'search_queries',
|
||||||
|
label: 'Search Queries',
|
||||||
|
type: 'string',
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'number_posts',
|
||||||
|
label: 'Number of Posts',
|
||||||
|
type: 'number',
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
github: [
|
||||||
|
{
|
||||||
|
name: 'repo_url',
|
||||||
|
label: 'Repository URL',
|
||||||
|
type: 'string',
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
google_drive: [
|
||||||
|
{
|
||||||
|
name: 'recursive',
|
||||||
|
label: 'Include subfolders',
|
||||||
|
type: 'boolean',
|
||||||
|
required: false,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
export const IngestorDefaultConfigs: Record<IngestorType, Omit<IngestorConfig, 'type'>> = {
|
export const IngestorDefaultConfigs: Record<
|
||||||
crawler: { name: '', config: { url: '' } },
|
IngestorType,
|
||||||
url: { name: '', config: { url: '' } },
|
Omit<IngestorConfig, 'type'>
|
||||||
|
> = {
|
||||||
|
crawler: {
|
||||||
|
name: '',
|
||||||
|
config: {
|
||||||
|
url: '',
|
||||||
|
} as CrawlerIngestorConfig,
|
||||||
|
},
|
||||||
|
url: {
|
||||||
|
name: '',
|
||||||
|
config: {
|
||||||
|
url: '',
|
||||||
|
} as UrlIngestorConfig,
|
||||||
|
},
|
||||||
reddit: {
|
reddit: {
|
||||||
name: '',
|
name: '',
|
||||||
config: {
|
config: {
|
||||||
@@ -114,30 +150,21 @@ export const IngestorDefaultConfigs: Record<IngestorType, Omit<IngestorConfig, '
|
|||||||
client_secret: '',
|
client_secret: '',
|
||||||
user_agent: '',
|
user_agent: '',
|
||||||
search_queries: '',
|
search_queries: '',
|
||||||
number_posts: 10
|
number_posts: 10,
|
||||||
}
|
} as RedditIngestorConfig,
|
||||||
|
},
|
||||||
|
github: {
|
||||||
|
name: '',
|
||||||
|
config: {
|
||||||
|
repo_url: '',
|
||||||
|
} as GithubIngestorConfig,
|
||||||
},
|
},
|
||||||
github: { name: '', config: { repo_url: '' } },
|
|
||||||
google_drive: {
|
google_drive: {
|
||||||
name: '',
|
name: '',
|
||||||
config: {
|
config: {
|
||||||
|
folder_id: '',
|
||||||
file_ids: '',
|
file_ids: '',
|
||||||
folder_ids: '',
|
recursive: true,
|
||||||
recursive: true
|
} as GoogleDriveIngestorConfig,
|
||||||
}
|
|
||||||
},
|
},
|
||||||
local_file: { name: '', config: { files: [] } },
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export interface IngestorOption {
|
|
||||||
label: string;
|
|
||||||
value: IngestorType;
|
|
||||||
icon: string;
|
|
||||||
heading: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export const getIngestorSchema = (key: IngestorType): IngestorSchema | undefined => {
|
|
||||||
return IngestorFormSchemas.find(schema => schema.key === key);
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
* Follows the convention: {provider}_session_token
|
* Follows the convention: {provider}_session_token
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
export const getSessionToken = (provider: string): string | null => {
|
export const getSessionToken = (provider: string): string | null => {
|
||||||
return localStorage.getItem(`${provider}_session_token`);
|
return localStorage.getItem(`${provider}_session_token`);
|
||||||
};
|
};
|
||||||
@@ -13,4 +14,4 @@ export const setSessionToken = (provider: string, token: string): void => {
|
|||||||
|
|
||||||
export const removeSessionToken = (provider: string): void => {
|
export const removeSessionToken = (provider: string): void => {
|
||||||
localStorage.removeItem(`${provider}_session_token`);
|
localStorage.removeItem(`${provider}_session_token`);
|
||||||
};
|
};
|
||||||
@@ -1,231 +0,0 @@
|
|||||||
from unittest.mock import Mock, patch
|
|
||||||
from typing import Any, Dict, Generator
|
|
||||||
|
|
||||||
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
|
||||||
|
|
||||||
|
|
||||||
class TestToolCall:
|
|
||||||
"""Test ToolCall dataclass."""
|
|
||||||
|
|
||||||
def test_tool_call_creation(self):
|
|
||||||
"""Test basic ToolCall creation."""
|
|
||||||
tool_call = ToolCall(
|
|
||||||
id="test_id",
|
|
||||||
name="test_function",
|
|
||||||
arguments={"arg1": "value1"},
|
|
||||||
index=0
|
|
||||||
)
|
|
||||||
assert tool_call.id == "test_id"
|
|
||||||
assert tool_call.name == "test_function"
|
|
||||||
assert tool_call.arguments == {"arg1": "value1"}
|
|
||||||
assert tool_call.index == 0
|
|
||||||
|
|
||||||
def test_tool_call_from_dict(self):
|
|
||||||
"""Test ToolCall creation from dictionary."""
|
|
||||||
data = {
|
|
||||||
"id": "call_123",
|
|
||||||
"name": "get_weather",
|
|
||||||
"arguments": {"location": "New York"},
|
|
||||||
"index": 1
|
|
||||||
}
|
|
||||||
tool_call = ToolCall.from_dict(data)
|
|
||||||
assert tool_call.id == "call_123"
|
|
||||||
assert tool_call.name == "get_weather"
|
|
||||||
assert tool_call.arguments == {"location": "New York"}
|
|
||||||
assert tool_call.index == 1
|
|
||||||
|
|
||||||
def test_tool_call_from_dict_missing_fields(self):
|
|
||||||
"""Test ToolCall creation with missing fields."""
|
|
||||||
data = {"name": "test_func"}
|
|
||||||
tool_call = ToolCall.from_dict(data)
|
|
||||||
assert tool_call.id == ""
|
|
||||||
assert tool_call.name == "test_func"
|
|
||||||
assert tool_call.arguments == {}
|
|
||||||
assert tool_call.index is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestLLMResponse:
|
|
||||||
"""Test LLMResponse dataclass."""
|
|
||||||
|
|
||||||
def test_llm_response_creation(self):
|
|
||||||
"""Test basic LLMResponse creation."""
|
|
||||||
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
|
||||||
response = LLMResponse(
|
|
||||||
content="Hello",
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
finish_reason="tool_calls",
|
|
||||||
raw_response={"test": "data"}
|
|
||||||
)
|
|
||||||
assert response.content == "Hello"
|
|
||||||
assert len(response.tool_calls) == 1
|
|
||||||
assert response.finish_reason == "tool_calls"
|
|
||||||
assert response.raw_response == {"test": "data"}
|
|
||||||
|
|
||||||
def test_requires_tool_call_true(self):
|
|
||||||
"""Test requires_tool_call property when tool calls are needed."""
|
|
||||||
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
|
||||||
response = LLMResponse(
|
|
||||||
content="",
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
finish_reason="tool_calls",
|
|
||||||
raw_response={}
|
|
||||||
)
|
|
||||||
assert response.requires_tool_call is True
|
|
||||||
|
|
||||||
def test_requires_tool_call_false_no_tools(self):
|
|
||||||
"""Test requires_tool_call property when no tool calls."""
|
|
||||||
response = LLMResponse(
|
|
||||||
content="Hello",
|
|
||||||
tool_calls=[],
|
|
||||||
finish_reason="stop",
|
|
||||||
raw_response={}
|
|
||||||
)
|
|
||||||
assert response.requires_tool_call is False
|
|
||||||
|
|
||||||
def test_requires_tool_call_false_wrong_finish_reason(self):
|
|
||||||
"""Test requires_tool_call property with tools but wrong finish reason."""
|
|
||||||
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
|
||||||
response = LLMResponse(
|
|
||||||
content="Hello",
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
finish_reason="stop",
|
|
||||||
raw_response={}
|
|
||||||
)
|
|
||||||
assert response.requires_tool_call is False
|
|
||||||
|
|
||||||
|
|
||||||
class ConcreteHandler(LLMHandler):
|
|
||||||
"""Concrete implementation for testing abstract base class."""
|
|
||||||
|
|
||||||
def parse_response(self, response: Any) -> LLMResponse:
|
|
||||||
return LLMResponse(
|
|
||||||
content=str(response),
|
|
||||||
tool_calls=[],
|
|
||||||
finish_reason="stop",
|
|
||||||
raw_response=response
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
|
||||||
return {
|
|
||||||
"role": "tool",
|
|
||||||
"content": str(result),
|
|
||||||
"tool_call_id": tool_call.id
|
|
||||||
}
|
|
||||||
|
|
||||||
def _iterate_stream(self, response: Any) -> Generator:
|
|
||||||
for chunk in response:
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
|
|
||||||
class TestLLMHandler:
|
|
||||||
"""Test LLMHandler base class."""
|
|
||||||
|
|
||||||
def test_handler_initialization(self):
|
|
||||||
"""Test handler initialization."""
|
|
||||||
handler = ConcreteHandler()
|
|
||||||
assert handler.llm_calls == []
|
|
||||||
assert handler.tool_calls == []
|
|
||||||
|
|
||||||
def test_prepare_messages_no_attachments(self):
|
|
||||||
"""Test prepare_messages with no attachments."""
|
|
||||||
handler = ConcreteHandler()
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
|
||||||
|
|
||||||
mock_agent = Mock()
|
|
||||||
result = handler.prepare_messages(mock_agent, messages, None)
|
|
||||||
assert result == messages
|
|
||||||
|
|
||||||
def test_prepare_messages_with_supported_attachments(self):
|
|
||||||
"""Test prepare_messages with supported attachments."""
|
|
||||||
handler = ConcreteHandler()
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
|
||||||
attachments = [{"mime_type": "image/png", "path": "/test.png"}]
|
|
||||||
|
|
||||||
mock_agent = Mock()
|
|
||||||
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
|
||||||
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
|
|
||||||
|
|
||||||
result = handler.prepare_messages(mock_agent, messages, attachments)
|
|
||||||
mock_agent.llm.prepare_messages_with_attachments.assert_called_once_with(
|
|
||||||
messages, attachments
|
|
||||||
)
|
|
||||||
assert result == messages
|
|
||||||
|
|
||||||
@patch('application.llm.handlers.base.logger')
|
|
||||||
def test_prepare_messages_with_unsupported_attachments(self, mock_logger):
|
|
||||||
"""Test prepare_messages with unsupported attachments."""
|
|
||||||
handler = ConcreteHandler()
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
|
||||||
attachments = [{"mime_type": "text/plain", "path": "/test.txt"}]
|
|
||||||
|
|
||||||
mock_agent = Mock()
|
|
||||||
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
|
||||||
|
|
||||||
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
|
|
||||||
result = handler.prepare_messages(mock_agent, messages, attachments)
|
|
||||||
mock_append.assert_called_once_with(messages, attachments)
|
|
||||||
assert result == messages
|
|
||||||
|
|
||||||
def test_prepare_messages_mixed_attachments(self):
|
|
||||||
"""Test prepare_messages with both supported and unsupported attachments."""
|
|
||||||
handler = ConcreteHandler()
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
|
||||||
attachments = [
|
|
||||||
{"mime_type": "image/png", "path": "/test.png"},
|
|
||||||
{"mime_type": "text/plain", "path": "/test.txt"}
|
|
||||||
]
|
|
||||||
|
|
||||||
mock_agent = Mock()
|
|
||||||
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
|
||||||
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
|
|
||||||
|
|
||||||
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
|
|
||||||
result = handler.prepare_messages(mock_agent, messages, attachments)
|
|
||||||
|
|
||||||
# Should call both methods
|
|
||||||
mock_agent.llm.prepare_messages_with_attachments.assert_called_once()
|
|
||||||
mock_append.assert_called_once()
|
|
||||||
assert result == messages
|
|
||||||
|
|
||||||
def test_process_message_flow_non_streaming(self):
|
|
||||||
"""Test process_message_flow for non-streaming."""
|
|
||||||
handler = ConcreteHandler()
|
|
||||||
mock_agent = Mock()
|
|
||||||
initial_response = "test response"
|
|
||||||
tools_dict = {}
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
|
||||||
|
|
||||||
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
|
|
||||||
with patch.object(handler, 'handle_non_streaming', return_value="final") as mock_handle:
|
|
||||||
result = handler.process_message_flow(
|
|
||||||
mock_agent, initial_response, tools_dict, messages, stream=False
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_prepare.assert_called_once_with(mock_agent, messages, None)
|
|
||||||
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
|
|
||||||
assert result == "final"
|
|
||||||
|
|
||||||
def test_process_message_flow_streaming(self):
|
|
||||||
"""Test process_message_flow for streaming."""
|
|
||||||
handler = ConcreteHandler()
|
|
||||||
mock_agent = Mock()
|
|
||||||
initial_response = "test response"
|
|
||||||
tools_dict = {}
|
|
||||||
messages = [{"role": "user", "content": "Hello"}]
|
|
||||||
|
|
||||||
def mock_generator():
|
|
||||||
yield "chunk1"
|
|
||||||
yield "chunk2"
|
|
||||||
|
|
||||||
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
|
|
||||||
with patch.object(handler, 'handle_streaming', return_value=mock_generator()) as mock_handle:
|
|
||||||
result = handler.process_message_flow(
|
|
||||||
mock_agent, initial_response, tools_dict, messages, stream=True
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_prepare.assert_called_once_with(mock_agent, messages, None)
|
|
||||||
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
|
|
||||||
|
|
||||||
# Verify it's a generator
|
|
||||||
chunks = list(result)
|
|
||||||
assert chunks == ["chunk1", "chunk2"]
|
|
||||||
@@ -1,270 +0,0 @@
|
|||||||
from unittest.mock import Mock, patch
|
|
||||||
from types import SimpleNamespace
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from application.llm.handlers.google import GoogleLLMHandler
|
|
||||||
from application.llm.handlers.base import ToolCall, LLMResponse
|
|
||||||
|
|
||||||
|
|
||||||
class TestGoogleLLMHandler:
|
|
||||||
"""Test GoogleLLMHandler class."""
|
|
||||||
|
|
||||||
def test_handler_initialization(self):
|
|
||||||
"""Test handler initialization."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
assert handler.llm_calls == []
|
|
||||||
assert handler.tool_calls == []
|
|
||||||
|
|
||||||
def test_parse_response_string_input(self):
|
|
||||||
"""Test parsing string response."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
response = "Hello from Google!"
|
|
||||||
|
|
||||||
result = handler.parse_response(response)
|
|
||||||
|
|
||||||
assert isinstance(result, LLMResponse)
|
|
||||||
assert result.content == "Hello from Google!"
|
|
||||||
assert result.tool_calls == []
|
|
||||||
assert result.finish_reason == "stop"
|
|
||||||
assert result.raw_response == "Hello from Google!"
|
|
||||||
|
|
||||||
def test_parse_response_with_candidates_text_only(self):
|
|
||||||
"""Test parsing response with candidates containing only text."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
mock_part = SimpleNamespace(text="Google response text")
|
|
||||||
mock_content = SimpleNamespace(parts=[mock_part])
|
|
||||||
mock_candidate = SimpleNamespace(content=mock_content)
|
|
||||||
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == "Google response text"
|
|
||||||
assert result.tool_calls == []
|
|
||||||
assert result.finish_reason == "stop"
|
|
||||||
assert result.raw_response == mock_response
|
|
||||||
|
|
||||||
def test_parse_response_with_multiple_text_parts(self):
|
|
||||||
"""Test parsing response with multiple text parts."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
mock_part1 = SimpleNamespace(text="First part")
|
|
||||||
mock_part2 = SimpleNamespace(text="Second part")
|
|
||||||
mock_content = SimpleNamespace(parts=[mock_part1, mock_part2])
|
|
||||||
mock_candidate = SimpleNamespace(content=mock_content)
|
|
||||||
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == "First part Second part"
|
|
||||||
assert result.tool_calls == []
|
|
||||||
assert result.finish_reason == "stop"
|
|
||||||
|
|
||||||
@patch('uuid.uuid4')
|
|
||||||
def test_parse_response_with_function_call(self, mock_uuid):
|
|
||||||
"""Test parsing response with function call."""
|
|
||||||
mock_uuid.return_value = Mock(spec=uuid.UUID)
|
|
||||||
mock_uuid.return_value.__str__ = Mock(return_value="test-uuid-123")
|
|
||||||
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
mock_function_call = SimpleNamespace(
|
|
||||||
name="get_weather",
|
|
||||||
args={"location": "San Francisco"}
|
|
||||||
)
|
|
||||||
mock_part = SimpleNamespace(function_call=mock_function_call)
|
|
||||||
mock_content = SimpleNamespace(parts=[mock_part])
|
|
||||||
mock_candidate = SimpleNamespace(content=mock_content)
|
|
||||||
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == ""
|
|
||||||
assert len(result.tool_calls) == 1
|
|
||||||
assert result.tool_calls[0].id == "test-uuid-123"
|
|
||||||
assert result.tool_calls[0].name == "get_weather"
|
|
||||||
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
|
|
||||||
assert result.finish_reason == "tool_calls"
|
|
||||||
|
|
||||||
@patch('uuid.uuid4')
|
|
||||||
def test_parse_response_with_mixed_parts(self, mock_uuid):
|
|
||||||
"""Test parsing response with both text and function call parts."""
|
|
||||||
mock_uuid.return_value = Mock(spec=uuid.UUID)
|
|
||||||
mock_uuid.return_value.__str__ = Mock(return_value="test-uuid-456")
|
|
||||||
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
mock_text_part = SimpleNamespace(text="I'll check the weather for you.")
|
|
||||||
mock_function_call = SimpleNamespace(
|
|
||||||
name="get_weather",
|
|
||||||
args={"location": "NYC"}
|
|
||||||
)
|
|
||||||
mock_function_part = SimpleNamespace(function_call=mock_function_call)
|
|
||||||
|
|
||||||
mock_content = SimpleNamespace(parts=[mock_text_part, mock_function_part])
|
|
||||||
mock_candidate = SimpleNamespace(content=mock_content)
|
|
||||||
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == "I'll check the weather for you."
|
|
||||||
assert len(result.tool_calls) == 1
|
|
||||||
assert result.tool_calls[0].name == "get_weather"
|
|
||||||
assert result.finish_reason == "tool_calls"
|
|
||||||
|
|
||||||
def test_parse_response_empty_candidates(self):
|
|
||||||
"""Test parsing response with empty candidates."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
mock_response = SimpleNamespace(candidates=[])
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == ""
|
|
||||||
assert result.tool_calls == []
|
|
||||||
assert result.finish_reason == "stop"
|
|
||||||
|
|
||||||
def test_parse_response_parts_with_none_text(self):
|
|
||||||
"""Test parsing response with parts that have None text."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
mock_part1 = SimpleNamespace(text=None)
|
|
||||||
mock_part2 = SimpleNamespace(text="Valid text")
|
|
||||||
mock_content = SimpleNamespace(parts=[mock_part1, mock_part2])
|
|
||||||
mock_candidate = SimpleNamespace(content=mock_content)
|
|
||||||
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == "Valid text"
|
|
||||||
|
|
||||||
def test_parse_response_parts_without_text_attribute(self):
|
|
||||||
"""Test parsing response with parts missing text attribute."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
mock_part1 = SimpleNamespace()
|
|
||||||
mock_part2 = SimpleNamespace(text="Valid text")
|
|
||||||
mock_content = SimpleNamespace(parts=[mock_part1, mock_part2])
|
|
||||||
mock_candidate = SimpleNamespace(content=mock_content)
|
|
||||||
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == "Valid text"
|
|
||||||
|
|
||||||
@patch('uuid.uuid4')
|
|
||||||
def test_parse_response_direct_function_call(self, mock_uuid):
|
|
||||||
"""Test parsing response with direct function call (not in candidates)."""
|
|
||||||
mock_uuid.return_value = Mock(spec=uuid.UUID)
|
|
||||||
mock_uuid.return_value.__str__ = Mock(return_value="direct-uuid-789")
|
|
||||||
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
mock_function_call = SimpleNamespace(
|
|
||||||
name="calculate",
|
|
||||||
args={"expression": "2+2"}
|
|
||||||
)
|
|
||||||
mock_response = SimpleNamespace(
|
|
||||||
function_call=mock_function_call,
|
|
||||||
text="The calculation result is:"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == "The calculation result is:"
|
|
||||||
assert len(result.tool_calls) == 1
|
|
||||||
assert result.tool_calls[0].id == "direct-uuid-789"
|
|
||||||
assert result.tool_calls[0].name == "calculate"
|
|
||||||
assert result.tool_calls[0].arguments == {"expression": "2+2"}
|
|
||||||
assert result.finish_reason == "tool_calls"
|
|
||||||
|
|
||||||
def test_parse_response_direct_function_call_no_text(self):
|
|
||||||
"""Test parsing response with direct function call and no text."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
mock_function_call = SimpleNamespace(
|
|
||||||
name="get_data",
|
|
||||||
args={"id": 123}
|
|
||||||
)
|
|
||||||
mock_response = SimpleNamespace(function_call=mock_function_call)
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == ""
|
|
||||||
assert len(result.tool_calls) == 1
|
|
||||||
assert result.tool_calls[0].name == "get_data"
|
|
||||||
assert result.finish_reason == "tool_calls"
|
|
||||||
|
|
||||||
def test_create_tool_message(self):
|
|
||||||
"""Test creating tool message."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
tool_call = ToolCall(
|
|
||||||
id="call_123",
|
|
||||||
name="get_weather",
|
|
||||||
arguments={"location": "Tokyo"},
|
|
||||||
index=0
|
|
||||||
)
|
|
||||||
result = {"temperature": "25C", "condition": "cloudy"}
|
|
||||||
|
|
||||||
message = handler.create_tool_message(tool_call, result)
|
|
||||||
|
|
||||||
expected = {
|
|
||||||
"role": "model",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"function_response": {
|
|
||||||
"name": "get_weather",
|
|
||||||
"response": {"result": result},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
assert message == expected
|
|
||||||
|
|
||||||
def test_create_tool_message_string_result(self):
|
|
||||||
"""Test creating tool message with string result."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
tool_call = ToolCall(id="call_456", name="get_time", arguments={})
|
|
||||||
result = "2023-12-01 15:30:00 JST"
|
|
||||||
|
|
||||||
message = handler.create_tool_message(tool_call, result)
|
|
||||||
|
|
||||||
assert message["role"] == "model"
|
|
||||||
assert message["content"][0]["function_response"]["response"]["result"] == result
|
|
||||||
assert message["content"][0]["function_response"]["name"] == "get_time"
|
|
||||||
|
|
||||||
def test_iterate_stream(self):
|
|
||||||
"""Test stream iteration."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
mock_chunks = ["chunk1", "chunk2", "chunk3"]
|
|
||||||
|
|
||||||
result = list(handler._iterate_stream(mock_chunks))
|
|
||||||
|
|
||||||
assert result == mock_chunks
|
|
||||||
|
|
||||||
def test_iterate_stream_empty(self):
|
|
||||||
"""Test stream iteration with empty response."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
result = list(handler._iterate_stream([]))
|
|
||||||
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
def test_parse_response_parts_without_function_call_attribute(self):
|
|
||||||
"""Test parsing response with parts missing function_call attribute."""
|
|
||||||
handler = GoogleLLMHandler()
|
|
||||||
|
|
||||||
mock_part = SimpleNamespace(text="Normal text")
|
|
||||||
mock_content = SimpleNamespace(parts=[mock_part])
|
|
||||||
mock_candidate = SimpleNamespace(content=mock_content)
|
|
||||||
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == "Normal text"
|
|
||||||
assert result.tool_calls == []
|
|
||||||
assert result.finish_reason == "stop"
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
|
|
||||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
|
||||||
from application.llm.handlers.base import LLMHandler
|
|
||||||
from application.llm.handlers.openai import OpenAILLMHandler
|
|
||||||
from application.llm.handlers.google import GoogleLLMHandler
|
|
||||||
|
|
||||||
|
|
||||||
class TestLLMHandlerCreator:
|
|
||||||
"""Test LLMHandlerCreator class."""
|
|
||||||
|
|
||||||
def test_create_openai_handler(self):
|
|
||||||
"""Test creating OpenAI handler."""
|
|
||||||
handler = LLMHandlerCreator.create_handler("openai")
|
|
||||||
|
|
||||||
assert isinstance(handler, OpenAILLMHandler)
|
|
||||||
assert isinstance(handler, LLMHandler)
|
|
||||||
|
|
||||||
def test_create_openai_handler_case_insensitive(self):
|
|
||||||
"""Test creating OpenAI handler with different cases."""
|
|
||||||
handler_upper = LLMHandlerCreator.create_handler("OPENAI")
|
|
||||||
handler_mixed = LLMHandlerCreator.create_handler("OpenAI")
|
|
||||||
|
|
||||||
assert isinstance(handler_upper, OpenAILLMHandler)
|
|
||||||
assert isinstance(handler_mixed, OpenAILLMHandler)
|
|
||||||
|
|
||||||
def test_create_google_handler(self):
|
|
||||||
"""Test creating Google handler."""
|
|
||||||
handler = LLMHandlerCreator.create_handler("google")
|
|
||||||
|
|
||||||
assert isinstance(handler, GoogleLLMHandler)
|
|
||||||
assert isinstance(handler, LLMHandler)
|
|
||||||
|
|
||||||
def test_create_google_handler_case_insensitive(self):
|
|
||||||
"""Test creating Google handler with different cases."""
|
|
||||||
handler_upper = LLMHandlerCreator.create_handler("GOOGLE")
|
|
||||||
handler_mixed = LLMHandlerCreator.create_handler("Google")
|
|
||||||
|
|
||||||
assert isinstance(handler_upper, GoogleLLMHandler)
|
|
||||||
assert isinstance(handler_mixed, GoogleLLMHandler)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_default_handler(self):
|
|
||||||
"""Test creating default handler."""
|
|
||||||
handler = LLMHandlerCreator.create_handler("default")
|
|
||||||
|
|
||||||
assert isinstance(handler, OpenAILLMHandler)
|
|
||||||
|
|
||||||
def test_create_unknown_handler_fallback(self):
|
|
||||||
"""Test creating handler for unknown type falls back to OpenAI."""
|
|
||||||
handler = LLMHandlerCreator.create_handler("unknown_provider")
|
|
||||||
|
|
||||||
assert isinstance(handler, OpenAILLMHandler)
|
|
||||||
|
|
||||||
def test_create_anthropic_handler_fallback(self):
|
|
||||||
"""Test creating Anthropic handler falls back to OpenAI (not supported in handlers)."""
|
|
||||||
handler = LLMHandlerCreator.create_handler("anthropic")
|
|
||||||
|
|
||||||
assert isinstance(handler, OpenAILLMHandler)
|
|
||||||
|
|
||||||
def test_create_empty_string_handler_fallback(self):
|
|
||||||
"""Test creating handler with empty string falls back to OpenAI."""
|
|
||||||
handler = LLMHandlerCreator.create_handler("")
|
|
||||||
|
|
||||||
assert isinstance(handler, OpenAILLMHandler)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_handlers_registry(self):
|
|
||||||
"""Test the handlers registry contains expected mappings."""
|
|
||||||
expected_handlers = {
|
|
||||||
"openai": OpenAILLMHandler,
|
|
||||||
"google": GoogleLLMHandler,
|
|
||||||
"default": OpenAILLMHandler,
|
|
||||||
}
|
|
||||||
|
|
||||||
assert LLMHandlerCreator.handlers == expected_handlers
|
|
||||||
|
|
||||||
def test_create_handler_with_args(self):
|
|
||||||
"""Test creating handler with additional arguments."""
|
|
||||||
handler = LLMHandlerCreator.create_handler("openai")
|
|
||||||
|
|
||||||
assert isinstance(handler, OpenAILLMHandler)
|
|
||||||
assert handler.llm_calls == []
|
|
||||||
assert handler.tool_calls == []
|
|
||||||
|
|
||||||
def test_create_handler_with_kwargs(self):
|
|
||||||
"""Test creating handler with keyword arguments."""
|
|
||||||
handler = LLMHandlerCreator.create_handler("google")
|
|
||||||
|
|
||||||
assert isinstance(handler, GoogleLLMHandler)
|
|
||||||
assert handler.llm_calls == []
|
|
||||||
assert handler.tool_calls == []
|
|
||||||
|
|
||||||
def test_all_registered_handlers_are_valid(self):
|
|
||||||
"""Test that all registered handlers can be instantiated."""
|
|
||||||
for handler_type in LLMHandlerCreator.handlers.keys():
|
|
||||||
handler = LLMHandlerCreator.create_handler(handler_type)
|
|
||||||
assert isinstance(handler, LLMHandler)
|
|
||||||
assert hasattr(handler, 'parse_response')
|
|
||||||
assert hasattr(handler, 'create_tool_message')
|
|
||||||
assert hasattr(handler, '_iterate_stream')
|
|
||||||
|
|
||||||
def test_handler_inheritance(self):
|
|
||||||
"""Test that all created handlers inherit from LLMHandler."""
|
|
||||||
test_types = ["openai", "google", "default", "unknown"]
|
|
||||||
|
|
||||||
for handler_type in test_types:
|
|
||||||
handler = LLMHandlerCreator.create_handler(handler_type)
|
|
||||||
assert isinstance(handler, LLMHandler)
|
|
||||||
|
|
||||||
assert callable(getattr(handler, 'parse_response'))
|
|
||||||
assert callable(getattr(handler, 'create_tool_message'))
|
|
||||||
assert callable(getattr(handler, '_iterate_stream'))
|
|
||||||
|
|
||||||
def test_create_handler_preserves_handler_state(self):
|
|
||||||
"""Test that each created handler has independent state."""
|
|
||||||
handler1 = LLMHandlerCreator.create_handler("openai")
|
|
||||||
handler2 = LLMHandlerCreator.create_handler("openai")
|
|
||||||
|
|
||||||
handler1.llm_calls.append("test_call")
|
|
||||||
|
|
||||||
assert len(handler1.llm_calls) == 1
|
|
||||||
assert len(handler2.llm_calls) == 0
|
|
||||||
assert handler1 is not handler2
|
|
||||||
@@ -1,208 +0,0 @@
|
|||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
from application.llm.handlers.openai import OpenAILLMHandler
|
|
||||||
from application.llm.handlers.base import ToolCall, LLMResponse
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAILLMHandler:
|
|
||||||
"""Test OpenAILLMHandler class."""
|
|
||||||
|
|
||||||
def test_handler_initialization(self):
|
|
||||||
"""Test handler initialization."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
assert handler.llm_calls == []
|
|
||||||
assert handler.tool_calls == []
|
|
||||||
|
|
||||||
def test_parse_response_string_input(self):
|
|
||||||
"""Test parsing string response."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
response = "Hello, world!"
|
|
||||||
|
|
||||||
result = handler.parse_response(response)
|
|
||||||
|
|
||||||
assert isinstance(result, LLMResponse)
|
|
||||||
assert result.content == "Hello, world!"
|
|
||||||
assert result.tool_calls == []
|
|
||||||
assert result.finish_reason == "stop"
|
|
||||||
assert result.raw_response == "Hello, world!"
|
|
||||||
|
|
||||||
def test_parse_response_with_message_content(self):
|
|
||||||
"""Test parsing response with message content."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
|
|
||||||
# Mock OpenAI response structure
|
|
||||||
mock_message = SimpleNamespace(content="Test content", tool_calls=None)
|
|
||||||
mock_response = SimpleNamespace(message=mock_message, finish_reason="stop")
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == "Test content"
|
|
||||||
assert result.tool_calls == []
|
|
||||||
assert result.finish_reason == "stop"
|
|
||||||
assert result.raw_response == mock_response
|
|
||||||
|
|
||||||
def test_parse_response_with_delta_content(self):
|
|
||||||
"""Test parsing response with delta content (streaming)."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
|
|
||||||
# Mock streaming response structure
|
|
||||||
mock_delta = SimpleNamespace(content="Stream chunk", tool_calls=None)
|
|
||||||
mock_response = SimpleNamespace(delta=mock_delta, finish_reason="")
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == "Stream chunk"
|
|
||||||
assert result.tool_calls == []
|
|
||||||
assert result.finish_reason == ""
|
|
||||||
assert result.raw_response == mock_response
|
|
||||||
|
|
||||||
def test_parse_response_with_tool_calls(self):
|
|
||||||
"""Test parsing response with tool calls."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
|
|
||||||
# Mock tool call structure
|
|
||||||
mock_function = SimpleNamespace(name="get_weather", arguments='{"location": "NYC"}')
|
|
||||||
mock_tool_call = SimpleNamespace(
|
|
||||||
id="call_123",
|
|
||||||
function=mock_function,
|
|
||||||
index=0
|
|
||||||
)
|
|
||||||
mock_message = SimpleNamespace(content="", tool_calls=[mock_tool_call])
|
|
||||||
mock_response = SimpleNamespace(message=mock_message, finish_reason="tool_calls")
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == ""
|
|
||||||
assert len(result.tool_calls) == 1
|
|
||||||
assert result.tool_calls[0].id == "call_123"
|
|
||||||
assert result.tool_calls[0].name == "get_weather"
|
|
||||||
assert result.tool_calls[0].arguments == '{"location": "NYC"}'
|
|
||||||
assert result.tool_calls[0].index == 0
|
|
||||||
assert result.finish_reason == "tool_calls"
|
|
||||||
|
|
||||||
def test_parse_response_with_multiple_tool_calls(self):
|
|
||||||
"""Test parsing response with multiple tool calls."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
|
|
||||||
# Mock multiple tool calls
|
|
||||||
mock_function1 = SimpleNamespace(name="get_weather", arguments='{"location": "NYC"}')
|
|
||||||
mock_function2 = SimpleNamespace(name="get_time", arguments='{"timezone": "UTC"}')
|
|
||||||
|
|
||||||
mock_tool_call1 = SimpleNamespace(id="call_1", function=mock_function1, index=0)
|
|
||||||
mock_tool_call2 = SimpleNamespace(id="call_2", function=mock_function2, index=1)
|
|
||||||
|
|
||||||
mock_message = SimpleNamespace(content="", tool_calls=[mock_tool_call1, mock_tool_call2])
|
|
||||||
mock_response = SimpleNamespace(message=mock_message, finish_reason="tool_calls")
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert len(result.tool_calls) == 2
|
|
||||||
assert result.tool_calls[0].name == "get_weather"
|
|
||||||
assert result.tool_calls[1].name == "get_time"
|
|
||||||
|
|
||||||
def test_parse_response_empty_tool_calls(self):
|
|
||||||
"""Test parsing response with empty tool_calls."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
|
|
||||||
mock_message = SimpleNamespace(content="No tools needed", tool_calls=None)
|
|
||||||
mock_response = SimpleNamespace(message=mock_message, finish_reason="stop")
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == "No tools needed"
|
|
||||||
assert result.tool_calls == []
|
|
||||||
assert result.finish_reason == "stop"
|
|
||||||
|
|
||||||
def test_parse_response_missing_attributes(self):
|
|
||||||
"""Test parsing response with missing attributes."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
|
|
||||||
# Mock response with missing attributes
|
|
||||||
mock_message = SimpleNamespace() # No content or tool_calls
|
|
||||||
mock_response = SimpleNamespace(message=mock_message) # No finish_reason
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert result.content == ""
|
|
||||||
assert result.tool_calls == []
|
|
||||||
assert result.finish_reason == ""
|
|
||||||
|
|
||||||
def test_create_tool_message(self):
|
|
||||||
"""Test creating tool message."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
|
|
||||||
tool_call = ToolCall(
|
|
||||||
id="call_123",
|
|
||||||
name="get_weather",
|
|
||||||
arguments={"location": "NYC"},
|
|
||||||
index=0
|
|
||||||
)
|
|
||||||
result = {"temperature": "72F", "condition": "sunny"}
|
|
||||||
|
|
||||||
message = handler.create_tool_message(tool_call, result)
|
|
||||||
|
|
||||||
expected = {
|
|
||||||
"role": "tool",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"function_response": {
|
|
||||||
"name": "get_weather",
|
|
||||||
"response": {"result": result},
|
|
||||||
"call_id": "call_123",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
assert message == expected
|
|
||||||
|
|
||||||
def test_create_tool_message_string_result(self):
|
|
||||||
"""Test creating tool message with string result."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
|
|
||||||
tool_call = ToolCall(id="call_456", name="get_time", arguments={})
|
|
||||||
result = "2023-12-01 10:30:00"
|
|
||||||
|
|
||||||
message = handler.create_tool_message(tool_call, result)
|
|
||||||
|
|
||||||
assert message["role"] == "tool"
|
|
||||||
assert message["content"][0]["function_response"]["response"]["result"] == result
|
|
||||||
assert message["content"][0]["function_response"]["call_id"] == "call_456"
|
|
||||||
|
|
||||||
def test_iterate_stream(self):
|
|
||||||
"""Test stream iteration."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
|
|
||||||
# Mock streaming response
|
|
||||||
mock_chunks = ["chunk1", "chunk2", "chunk3"]
|
|
||||||
|
|
||||||
result = list(handler._iterate_stream(mock_chunks))
|
|
||||||
|
|
||||||
assert result == mock_chunks
|
|
||||||
|
|
||||||
def test_iterate_stream_empty(self):
|
|
||||||
"""Test stream iteration with empty response."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
|
|
||||||
result = list(handler._iterate_stream([]))
|
|
||||||
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
def test_parse_response_tool_call_missing_attributes(self):
|
|
||||||
"""Test parsing tool calls with missing attributes."""
|
|
||||||
handler = OpenAILLMHandler()
|
|
||||||
|
|
||||||
# Mock tool call with missing attributes
|
|
||||||
mock_function = SimpleNamespace() # No name or arguments
|
|
||||||
mock_tool_call = SimpleNamespace(function=mock_function) # No id or index
|
|
||||||
|
|
||||||
mock_message = SimpleNamespace(content="", tool_calls=[mock_tool_call])
|
|
||||||
mock_response = SimpleNamespace(message=mock_message, finish_reason="tool_calls")
|
|
||||||
|
|
||||||
result = handler.parse_response(mock_response)
|
|
||||||
|
|
||||||
assert len(result.tool_calls) == 1
|
|
||||||
assert result.tool_calls[0].id == ""
|
|
||||||
assert result.tool_calls[0].name == ""
|
|
||||||
assert result.tool_calls[0].arguments == ""
|
|
||||||
assert result.tool_calls[0].index is None
|
|
||||||
68
tests/llm/test_anthropic.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import unittest
|
||||||
|
from unittest.mock import patch, Mock
|
||||||
|
from application.llm.anthropic import AnthropicLLM
|
||||||
|
|
||||||
|
class TestAnthropicLLM(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.api_key = "TEST_API_KEY"
|
||||||
|
self.llm = AnthropicLLM(api_key=self.api_key)
|
||||||
|
|
||||||
|
@patch("application.llm.anthropic.settings")
|
||||||
|
def test_init_default_api_key(self, mock_settings):
|
||||||
|
mock_settings.ANTHROPIC_API_KEY = "DEFAULT_API_KEY"
|
||||||
|
llm = AnthropicLLM()
|
||||||
|
self.assertEqual(llm.api_key, "DEFAULT_API_KEY")
|
||||||
|
|
||||||
|
def test_gen(self):
|
||||||
|
messages = [
|
||||||
|
{"content": "context"},
|
||||||
|
{"content": "question"}
|
||||||
|
]
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.completion = "test completion"
|
||||||
|
|
||||||
|
with patch("application.cache.get_redis_instance") as mock_make_redis:
|
||||||
|
mock_redis_instance = mock_make_redis.return_value
|
||||||
|
mock_redis_instance.get.return_value = None
|
||||||
|
mock_redis_instance.set = Mock()
|
||||||
|
|
||||||
|
with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
|
||||||
|
response = self.llm.gen("test_model", messages)
|
||||||
|
self.assertEqual(response, "test completion")
|
||||||
|
|
||||||
|
prompt_expected = "### Context \n context \n ### Question \n question"
|
||||||
|
mock_create.assert_called_with(
|
||||||
|
model="test_model",
|
||||||
|
max_tokens_to_sample=300,
|
||||||
|
stream=False,
|
||||||
|
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
|
||||||
|
)
|
||||||
|
mock_redis_instance.set.assert_called_once()
|
||||||
|
|
||||||
|
def test_gen_stream(self):
|
||||||
|
messages = [
|
||||||
|
{"content": "context"},
|
||||||
|
{"content": "question"}
|
||||||
|
]
|
||||||
|
mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")]
|
||||||
|
mock_tools = Mock()
|
||||||
|
|
||||||
|
with patch("application.cache.get_redis_instance") as mock_make_redis:
|
||||||
|
mock_redis_instance = mock_make_redis.return_value
|
||||||
|
mock_redis_instance.get.return_value = None
|
||||||
|
mock_redis_instance.set = Mock()
|
||||||
|
|
||||||
|
with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
|
||||||
|
responses = list(self.llm.gen_stream("test_model", messages, tools=mock_tools))
|
||||||
|
self.assertListEqual(responses, ["response_1", "response_2"])
|
||||||
|
|
||||||
|
prompt_expected = "### Context \n context \n ### Question \n question"
|
||||||
|
mock_create.assert_called_with(
|
||||||
|
model="test_model",
|
||||||
|
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
|
||||||
|
max_tokens_to_sample=300,
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
import sys
|
|
||||||
import types
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
class _FakeCompletion:
|
|
||||||
def __init__(self, text):
|
|
||||||
self.completion = text
|
|
||||||
|
|
||||||
class _FakeCompletions:
|
|
||||||
def __init__(self):
|
|
||||||
self.last_kwargs = None
|
|
||||||
self._stream = [_FakeCompletion("s1"), _FakeCompletion("s2")]
|
|
||||||
|
|
||||||
def create(self, **kwargs):
|
|
||||||
self.last_kwargs = kwargs
|
|
||||||
if kwargs.get("stream"):
|
|
||||||
return self._stream
|
|
||||||
return _FakeCompletion("final")
|
|
||||||
|
|
||||||
class _FakeAnthropic:
|
|
||||||
def __init__(self, api_key=None):
|
|
||||||
self.api_key = api_key
|
|
||||||
self.completions = _FakeCompletions()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def patch_anthropic(monkeypatch):
|
|
||||||
fake = types.ModuleType("anthropic")
|
|
||||||
fake.Anthropic = _FakeAnthropic
|
|
||||||
fake.HUMAN_PROMPT = "<HUMAN>"
|
|
||||||
fake.AI_PROMPT = "<AI>"
|
|
||||||
sys.modules["anthropic"] = fake
|
|
||||||
yield
|
|
||||||
sys.modules.pop("anthropic", None)
|
|
||||||
|
|
||||||
|
|
||||||
def test_anthropic_raw_gen_builds_prompt_and_returns_completion():
|
|
||||||
from application.llm.anthropic import AnthropicLLM
|
|
||||||
|
|
||||||
llm = AnthropicLLM(api_key="k")
|
|
||||||
msgs = [
|
|
||||||
{"content": "ctx"},
|
|
||||||
{"content": "q"},
|
|
||||||
]
|
|
||||||
out = llm._raw_gen(llm, model="claude-2", messages=msgs, stream=False, max_tokens=55)
|
|
||||||
assert out == "final"
|
|
||||||
last = llm.anthropic.completions.last_kwargs
|
|
||||||
assert last["model"] == "claude-2"
|
|
||||||
assert last["max_tokens_to_sample"] == 55
|
|
||||||
assert last["prompt"].startswith("<HUMAN>") and last["prompt"].endswith("<AI>")
|
|
||||||
assert "### Context" in last["prompt"] and "### Question" in last["prompt"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_anthropic_raw_gen_stream_yields_chunks():
|
|
||||||
from application.llm.anthropic import AnthropicLLM
|
|
||||||
|
|
||||||
llm = AnthropicLLM(api_key="k")
|
|
||||||
msgs = [
|
|
||||||
{"content": "ctx"},
|
|
||||||
{"content": "q"},
|
|
||||||
]
|
|
||||||
gen = llm._raw_gen_stream(llm, model="claude", messages=msgs, stream=True, max_tokens=10)
|
|
||||||
chunks = list(gen)
|
|
||||||
assert chunks == ["s1", "s2"]
|
|
||||||
|
|
||||||
@@ -1,151 +0,0 @@
|
|||||||
import types
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from application.llm.google_ai import GoogleLLM
|
|
||||||
|
|
||||||
class _FakePart:
|
|
||||||
def __init__(self, text=None, function_call=None, file_data=None):
|
|
||||||
self.text = text
|
|
||||||
self.function_call = function_call
|
|
||||||
self.file_data = file_data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_text(text):
|
|
||||||
return _FakePart(text=text)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_function_call(name, args):
|
|
||||||
return _FakePart(function_call=types.SimpleNamespace(name=name, args=args))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_function_response(name, response):
|
|
||||||
# not used in assertions but present for completeness
|
|
||||||
return _FakePart(function_call=None, text=str(response))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_uri(file_uri, mime_type):
|
|
||||||
# mimic presence of file data for streaming detection
|
|
||||||
return _FakePart(file_data=types.SimpleNamespace(file_uri=file_uri, mime_type=mime_type))
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeContent:
|
|
||||||
def __init__(self, role, parts):
|
|
||||||
self.role = role
|
|
||||||
self.parts = parts
|
|
||||||
|
|
||||||
|
|
||||||
class FakeTypesModule:
|
|
||||||
Part = _FakePart
|
|
||||||
Content = _FakeContent
|
|
||||||
|
|
||||||
class GenerateContentConfig:
|
|
||||||
def __init__(self):
|
|
||||||
self.system_instruction = None
|
|
||||||
self.tools = None
|
|
||||||
self.response_schema = None
|
|
||||||
self.response_mime_type = None
|
|
||||||
|
|
||||||
|
|
||||||
class FakeModels:
|
|
||||||
def __init__(self):
|
|
||||||
self.last_args = None
|
|
||||||
self.last_kwargs = None
|
|
||||||
|
|
||||||
class _Resp:
|
|
||||||
def __init__(self, text=None, candidates=None):
|
|
||||||
self.text = text
|
|
||||||
self.candidates = candidates or []
|
|
||||||
|
|
||||||
def generate_content(self, *args, **kwargs):
|
|
||||||
self.last_args, self.last_kwargs = args, kwargs
|
|
||||||
return FakeModels._Resp(text="ok")
|
|
||||||
|
|
||||||
def generate_content_stream(self, *args, **kwargs):
|
|
||||||
self.last_args, self.last_kwargs = args, kwargs
|
|
||||||
# Simulate stream of text parts
|
|
||||||
part1 = types.SimpleNamespace(text="a", candidates=None)
|
|
||||||
part2 = types.SimpleNamespace(text="b", candidates=None)
|
|
||||||
return [part1, part2]
|
|
||||||
|
|
||||||
|
|
||||||
class FakeClient:
|
|
||||||
def __init__(self, *_, **__):
|
|
||||||
self.models = FakeModels()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def patch_google_modules(monkeypatch):
|
|
||||||
# Patch the types module used by GoogleLLM
|
|
||||||
import application.llm.google_ai as gmod
|
|
||||||
monkeypatch.setattr(gmod, "types", FakeTypesModule)
|
|
||||||
monkeypatch.setattr(gmod.genai, "Client", FakeClient)
|
|
||||||
|
|
||||||
|
|
||||||
def test_clean_messages_google_basic():
|
|
||||||
llm = GoogleLLM(api_key="key")
|
|
||||||
msgs = [
|
|
||||||
{"role": "assistant", "content": "hi"},
|
|
||||||
{"role": "user", "content": [
|
|
||||||
{"text": "hello"},
|
|
||||||
{"files": [{"file_uri": "gs://x", "mime_type": "image/png"}]},
|
|
||||||
{"function_call": {"name": "fn", "args": {"a": 1}}},
|
|
||||||
]},
|
|
||||||
]
|
|
||||||
cleaned = llm._clean_messages_google(msgs)
|
|
||||||
|
|
||||||
assert all(hasattr(c, "role") and hasattr(c, "parts") for c in cleaned)
|
|
||||||
assert any(c.role == "model" for c in cleaned)
|
|
||||||
assert any(hasattr(p, "text") for c in cleaned for p in c.parts)
|
|
||||||
|
|
||||||
|
|
||||||
def test_raw_gen_calls_google_client_and_returns_text():
|
|
||||||
llm = GoogleLLM(api_key="key")
|
|
||||||
msgs = [{"role": "user", "content": "hello"}]
|
|
||||||
out = llm._raw_gen(llm, model="gemini-2.0", messages=msgs, stream=False)
|
|
||||||
assert out == "ok"
|
|
||||||
|
|
||||||
|
|
||||||
def test_raw_gen_stream_yields_chunks():
|
|
||||||
llm = GoogleLLM(api_key="key")
|
|
||||||
msgs = [{"role": "user", "content": "hello"}]
|
|
||||||
gen = llm._raw_gen_stream(llm, model="gemini", messages=msgs, stream=True)
|
|
||||||
assert list(gen) == ["a", "b"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_structured_output_format_type_mapping():
|
|
||||||
llm = GoogleLLM(api_key="key")
|
|
||||||
schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"a": {"type": "string"},
|
|
||||||
"b": {"type": "array", "items": {"type": "integer"}},
|
|
||||||
},
|
|
||||||
"required": ["a"],
|
|
||||||
}
|
|
||||||
out = llm.prepare_structured_output_format(schema)
|
|
||||||
assert out["type"] == "OBJECT"
|
|
||||||
assert out["properties"]["a"]["type"] == "STRING"
|
|
||||||
assert out["properties"]["b"]["type"] == "ARRAY"
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_messages_with_attachments_appends_files(monkeypatch):
|
|
||||||
llm = GoogleLLM(api_key="key")
|
|
||||||
llm.storage = types.SimpleNamespace(
|
|
||||||
file_exists=lambda path: True,
|
|
||||||
process_file=lambda path, processor_func, **kwargs: "gs://file_uri"
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(llm, "_upload_file_to_google", lambda att: "gs://file_uri")
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "Hi"}]
|
|
||||||
attachments = [
|
|
||||||
{"path": "/tmp/img.png", "mime_type": "image/png"},
|
|
||||||
{"path": "/tmp/doc.pdf", "mime_type": "application/pdf"},
|
|
||||||
]
|
|
||||||
|
|
||||||
out = llm.prepare_messages_with_attachments(messages, attachments)
|
|
||||||
user_msg = next(m for m in out if m["role"] == "user")
|
|
||||||
assert isinstance(user_msg["content"], list)
|
|
||||||
files_entry = next((p for p in user_msg["content"] if isinstance(p, dict) and "files" in p), None)
|
|
||||||
assert files_entry is not None
|
|
||||||
assert isinstance(files_entry["files"], list) and len(files_entry["files"]) == 2
|
|
||||||
|
|
||||||
11
tests/llm/test_openai.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import unittest
|
||||||
|
from application.llm.openai import OpenAILLM
|
||||||
|
|
||||||
|
class TestOpenAILLM(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.api_key = "test_api_key"
|
||||||
|
self.llm = OpenAILLM(self.api_key)
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
self.assertEqual(self.llm.api_key, self.api_key)
|
||||||
@@ -1,157 +0,0 @@
|
|||||||
import types
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from application.llm.openai import OpenAILLM
|
|
||||||
|
|
||||||
|
|
||||||
class FakeChatCompletions:
|
|
||||||
def __init__(self):
|
|
||||||
self.last_kwargs = None
|
|
||||||
|
|
||||||
class _Msg:
|
|
||||||
def __init__(self, content=None, tool_calls=None):
|
|
||||||
self.content = content
|
|
||||||
self.tool_calls = tool_calls
|
|
||||||
|
|
||||||
class _Delta:
|
|
||||||
def __init__(self, content=None):
|
|
||||||
self.content = content
|
|
||||||
|
|
||||||
class _Choice:
|
|
||||||
def __init__(self, content=None, delta=None, finish_reason="stop"):
|
|
||||||
self.message = FakeChatCompletions._Msg(content=content)
|
|
||||||
self.delta = FakeChatCompletions._Delta(content=delta)
|
|
||||||
self.finish_reason = finish_reason
|
|
||||||
|
|
||||||
class _StreamLine:
|
|
||||||
def __init__(self, deltas):
|
|
||||||
self.choices = [FakeChatCompletions._Choice(delta=d) for d in deltas]
|
|
||||||
|
|
||||||
class _Response:
|
|
||||||
def __init__(self, choices=None, lines=None):
|
|
||||||
self._choices = choices or []
|
|
||||||
self._lines = lines or []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def choices(self):
|
|
||||||
return self._choices
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
for line in self._lines:
|
|
||||||
yield line
|
|
||||||
|
|
||||||
def create(self, **kwargs):
|
|
||||||
self.last_kwargs = kwargs
|
|
||||||
# default non-streaming: return content
|
|
||||||
if not kwargs.get("stream"):
|
|
||||||
return FakeChatCompletions._Response(choices=[
|
|
||||||
FakeChatCompletions._Choice(content="hello world")
|
|
||||||
])
|
|
||||||
# streaming: yield line objects each with choices[0].delta.content
|
|
||||||
return FakeChatCompletions._Response(lines=[
|
|
||||||
FakeChatCompletions._StreamLine(["part1"]),
|
|
||||||
FakeChatCompletions._StreamLine(["part2"]),
|
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
class FakeClient:
|
|
||||||
def __init__(self):
|
|
||||||
self.chat = types.SimpleNamespace(completions=FakeChatCompletions())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def openai_llm(monkeypatch):
|
|
||||||
llm = OpenAILLM(api_key="sk-test", user_api_key=None)
|
|
||||||
llm.storage = types.SimpleNamespace(
|
|
||||||
get_file=lambda path: types.SimpleNamespace(read=lambda: b"img"),
|
|
||||||
file_exists=lambda path: True,
|
|
||||||
process_file=lambda path, processor_func, **kwargs: "file_id_123",
|
|
||||||
)
|
|
||||||
llm.client = FakeClient()
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
def test_clean_messages_openai_variants(openai_llm):
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": "sys"},
|
|
||||||
{"role": "model", "content": "asst"},
|
|
||||||
{"role": "user", "content": [
|
|
||||||
{"text": "hello"},
|
|
||||||
{"function_call": {"call_id": "c1", "name": "fn", "args": {"a": 1}}},
|
|
||||||
{"function_response": {"call_id": "c1", "name": "fn", "response": {"result": 42}}},
|
|
||||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAA"}},
|
|
||||||
]},
|
|
||||||
]
|
|
||||||
|
|
||||||
cleaned = openai_llm._clean_messages_openai(messages)
|
|
||||||
|
|
||||||
roles = [m["role"] for m in cleaned]
|
|
||||||
assert roles.count("assistant") >= 1
|
|
||||||
assert any(m["role"] == "tool" for m in cleaned)
|
|
||||||
|
|
||||||
assert any(isinstance(m["content"], list) and any(
|
|
||||||
part.get("type") == "image_url" for part in m["content"] if isinstance(part, dict)
|
|
||||||
) for m in cleaned if m["role"] == "user")
|
|
||||||
|
|
||||||
|
|
||||||
def test_raw_gen_calls_openai_client_and_returns_content(openai_llm):
|
|
||||||
msgs = [
|
|
||||||
{"role": "system", "content": "sys"},
|
|
||||||
{"role": "user", "content": "hello"},
|
|
||||||
]
|
|
||||||
content = openai_llm._raw_gen(openai_llm, model="gpt-4o", messages=msgs, stream=False)
|
|
||||||
assert content == "hello world"
|
|
||||||
|
|
||||||
passed = openai_llm.client.chat.completions.last_kwargs
|
|
||||||
assert passed["model"] == "gpt-4o"
|
|
||||||
assert isinstance(passed["messages"], list)
|
|
||||||
assert passed["stream"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_raw_gen_stream_yields_chunks(openai_llm):
|
|
||||||
msgs = [
|
|
||||||
{"role": "user", "content": "hi"},
|
|
||||||
]
|
|
||||||
gen = openai_llm._raw_gen_stream(openai_llm, model="gpt", messages=msgs, stream=True)
|
|
||||||
chunks = list(gen)
|
|
||||||
assert "part1" in "".join(chunks)
|
|
||||||
assert "part2" in "".join(chunks)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_structured_output_format_enforces_required_and_strict(openai_llm):
|
|
||||||
schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"a": {"type": "string"},
|
|
||||||
"b": {"type": "number"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
result = openai_llm.prepare_structured_output_format(schema)
|
|
||||||
assert result["type"] == "json_schema"
|
|
||||||
js = result["json_schema"]
|
|
||||||
assert js["strict"] is True
|
|
||||||
assert set(js["schema"]["required"]) == {"a", "b"}
|
|
||||||
assert js["schema"]["additionalProperties"] is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_prepare_messages_with_attachments_image_and_pdf(openai_llm, monkeypatch):
|
|
||||||
|
|
||||||
monkeypatch.setattr(openai_llm, "_get_base64_image", lambda att: "AAA=")
|
|
||||||
monkeypatch.setattr(openai_llm, "_upload_file_to_openai", lambda att: "file_xyz")
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": "Hi"}]
|
|
||||||
attachments = [
|
|
||||||
{"path": "/tmp/img.png", "mime_type": "image/png"},
|
|
||||||
{"path": "/tmp/doc.pdf", "mime_type": "application/pdf"},
|
|
||||||
]
|
|
||||||
out = openai_llm.prepare_messages_with_attachments(messages, attachments)
|
|
||||||
|
|
||||||
# last user message should have list content with text and two attachments
|
|
||||||
user_msg = next(m for m in out if m["role"] == "user")
|
|
||||||
assert isinstance(user_msg["content"], list)
|
|
||||||
types_in_content = [p.get("type") for p in user_msg["content"] if isinstance(p, dict)]
|
|
||||||
assert "image_url" in types_in_content or any(
|
|
||||||
isinstance(p, dict) and p.get("image_url") for p in user_msg["content"]
|
|
||||||
)
|
|
||||||
assert any(isinstance(p, dict) and p.get("file", {}).get("file_id") == "file_xyz" for p in user_msg["content"])
|
|
||||||
|
|
||||||
@@ -1,117 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
from application.parser.file.docs_parser import PDFParser, DocxParser
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def pdf_parser():
|
|
||||||
return PDFParser()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def docx_parser():
|
|
||||||
return DocxParser()
|
|
||||||
|
|
||||||
|
|
||||||
def test_pdf_init_parser():
|
|
||||||
parser = PDFParser()
|
|
||||||
assert isinstance(parser._init_parser(), dict)
|
|
||||||
assert not parser.parser_config_set
|
|
||||||
parser.init_parser()
|
|
||||||
assert parser.parser_config_set
|
|
||||||
|
|
||||||
|
|
||||||
def test_docx_init_parser():
    """_init_parser yields a dict config and init_parser marks it as set."""
    p = DocxParser()
    cfg = p._init_parser()
    assert isinstance(cfg, dict)
    assert not p.parser_config_set
    p.init_parser()
    assert p.parser_config_set
|
|
||||||
|
|
||||||
|
|
||||||
@patch("application.parser.file.docs_parser.settings")
def test_parse_pdf_with_pypdf(mock_settings, pdf_parser):
    """With image parsing off, parse_file joins every page's text with newlines."""
    mock_settings.PARSE_PDF_AS_IMAGE = False

    # Build mock pages whose extract_text returns fixed page content.
    pages = []
    for page_text in ("Test PDF content page 1", "Test PDF content page 2"):
        page = MagicMock()
        page.extract_text.return_value = page_text
        pages.append(page)

    reader = MagicMock()
    reader.pages = pages

    saved_parse_file = pdf_parser.parse_file

    def fake_parse_file(*args, **kwargs):
        _ = args, kwargs
        # Mirror the real flow: iterate pages, extract text, join with newlines.
        return "\n".join(pg.extract_text() for pg in reader.pages)

    pdf_parser.parse_file = fake_parse_file
    try:
        result = pdf_parser.parse_file(Path("test.pdf"))
        assert result == "Test PDF content page 1\nTest PDF content page 2"
    finally:
        pdf_parser.parse_file = saved_parse_file
|
|
||||||
|
|
||||||
|
|
||||||
@patch("application.parser.file.docs_parser.settings")
def test_parse_pdf_pypdf_import_error(mock_settings, pdf_parser):
    """A missing pypdf dependency surfaces as a ValueError from parse_file."""
    mock_settings.PARSE_PDF_AS_IMAGE = False

    saved_parse_file = pdf_parser.parse_file

    def fake_parse_file(*args, **kwargs):
        _ = args, kwargs
        raise ValueError("pypdf is required to read PDF files.")

    pdf_parser.parse_file = fake_parse_file
    try:
        with pytest.raises(ValueError, match="pypdf is required to read PDF files"):
            pdf_parser.parse_file(Path("test.pdf"))
    finally:
        pdf_parser.parse_file = saved_parse_file
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_docx(docx_parser):
    """parse_file returns the extracted DOCX text."""
    saved_parse_file = docx_parser.parse_file

    def fake_parse_file(*args, **kwargs):
        _ = args, kwargs
        return "Test DOCX content"

    docx_parser.parse_file = fake_parse_file
    try:
        assert docx_parser.parse_file(Path("test.docx")) == "Test DOCX content"
    finally:
        docx_parser.parse_file = saved_parse_file
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_docx_import_error(docx_parser):
    """A missing docx2txt dependency surfaces as a ValueError from parse_file."""
    saved_parse_file = docx_parser.parse_file

    def fake_parse_file(*args, **kwargs):
        _ = args, kwargs
        raise ValueError("docx2txt is required to read Microsoft Word files.")

    docx_parser.parse_file = fake_parse_file
    try:
        with pytest.raises(ValueError, match="docx2txt is required to read Microsoft Word files"):
            docx_parser.parse_file(Path("test.docx"))
    finally:
        docx_parser.parse_file = saved_parse_file
|
|
||||||
@@ -1,152 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
|
|
||||||
from application.parser.file.epub_parser import EpubParser
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
def epub_parser():
    """Provide a fresh EpubParser instance for each test."""
    return EpubParser()
|
|
||||||
|
|
||||||
|
|
||||||
def test_epub_init_parser():
    """_init_parser yields a dict config and init_parser marks it as set."""
    p = EpubParser()
    cfg = p._init_parser()
    assert isinstance(cfg, dict)
    assert not p.parser_config_set
    p.init_parser()
    assert p.parser_config_set
|
|
||||||
|
|
||||||
|
|
||||||
def test_epub_parser_ebooklib_import_error(epub_parser):
    """A missing ebooklib dependency surfaces as a ValueError from parse_file."""
    # Masking the module in sys.modules makes its import fail inside the parser.
    with patch.dict(sys.modules, {"ebooklib": None}):
        with pytest.raises(ValueError, match="`EbookLib` is required to read Epub files"):
            epub_parser.parse_file(Path("test.epub"))
|
|
||||||
|
|
||||||
|
|
||||||
def test_epub_parser_html2text_import_error(epub_parser):
    """With ebooklib present but html2text missing, parse_file raises ValueError."""
    stub_ebooklib = types.ModuleType("ebooklib")
    stub_epub = types.ModuleType("ebooklib.epub")
    stub_ebooklib.epub = stub_epub

    available = {"ebooklib": stub_ebooklib, "ebooklib.epub": stub_epub}
    with patch.dict(sys.modules, available):
        # html2text is masked so its import fails inside the parser.
        with patch.dict(sys.modules, {"html2text": None}):
            with pytest.raises(ValueError, match="`html2text` is required to parse Epub files"):
                epub_parser.parse_file(Path("test.epub"))
|
|
||||||
|
|
||||||
|
|
||||||
def test_epub_parser_successful_parsing(epub_parser):
    """Document items are converted with html2text and concatenated in order."""
    stub_ebooklib = types.ModuleType("ebooklib")
    stub_epub = types.ModuleType("ebooklib.epub")
    stub_html2text = types.ModuleType("html2text")

    # Wire the fake package together the way the parser expects to import it.
    stub_ebooklib.ITEM_DOCUMENT = "document"
    stub_ebooklib.epub = stub_epub

    def make_item(item_type, payload):
        item = MagicMock()
        item.get_type.return_value = item_type
        item.get_content.return_value = payload
        return item

    book_items = [
        make_item("document", b"<h1>Chapter 1</h1><p>Content 1</p>"),
        make_item("document", b"<h1>Chapter 2</h1><p>Content 2</p>"),
        make_item("other", b"<p>Other content</p>"),  # non-document: must be skipped
    ]

    book = MagicMock()
    book.get_items.return_value = book_items
    stub_epub.read_epub = MagicMock(return_value=book)

    def fake_html2text(html_content):
        if "Chapter 1" in html_content:
            return "# Chapter 1\n\nContent 1\n"
        elif "Chapter 2" in html_content:
            return "# Chapter 2\n\nContent 2\n"
        return "Other content\n"

    stub_html2text.html2text = fake_html2text

    fake_modules = {
        "ebooklib": stub_ebooklib,
        "ebooklib.epub": stub_epub,
        "html2text": stub_html2text,
    }
    with patch.dict(sys.modules, fake_modules):
        result = epub_parser.parse_file(Path("test.epub"))

        assert result == "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"

        # The parser must open the epub with ncx parsing disabled.
        stub_epub.read_epub.assert_called_once_with(Path("test.epub"), options={"ignore_ncx": True})
|
|
||||||
|
|
||||||
|
|
||||||
def test_epub_parser_empty_book(epub_parser):
    """A book without document items yields an empty string and no conversion."""
    stub_ebooklib = types.ModuleType("ebooklib")
    stub_epub = types.ModuleType("ebooklib.epub")
    stub_html2text = types.ModuleType("html2text")

    stub_ebooklib.ITEM_DOCUMENT = "document"
    stub_ebooklib.epub = stub_epub

    # A book that exposes no items at all.
    empty_book = MagicMock()
    empty_book.get_items.return_value = []

    stub_epub.read_epub = MagicMock(return_value=empty_book)
    stub_html2text.html2text = MagicMock()

    fake_modules = {
        "ebooklib": stub_ebooklib,
        "ebooklib.epub": stub_epub,
        "html2text": stub_html2text,
    }
    with patch.dict(sys.modules, fake_modules):
        result = epub_parser.parse_file(Path("empty.epub"))
        assert result == ""

        # No document items means html2text is never invoked.
        stub_html2text.html2text.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
def test_epub_parser_non_document_items_ignored(epub_parser):
    """Only items whose type matches ITEM_DOCUMENT are converted to text."""
    stub_ebooklib = types.ModuleType("ebooklib")
    stub_epub = types.ModuleType("ebooklib.epub")
    stub_html2text = types.ModuleType("html2text")

    stub_ebooklib.ITEM_DOCUMENT = "document"
    stub_ebooklib.epub = stub_epub

    doc_item = MagicMock()
    doc_item.get_type.return_value = "document"
    doc_item.get_content.return_value = b"<p>Document content</p>"

    image_item = MagicMock()
    image_item.get_type.return_value = "image"  # not a document; must be skipped

    book = MagicMock()
    book.get_items.return_value = [image_item, doc_item]

    stub_epub.read_epub = MagicMock(return_value=book)
    stub_html2text.html2text = MagicMock(return_value="Document content\n")

    fake_modules = {
        "ebooklib": stub_ebooklib,
        "ebooklib.epub": stub_epub,
        "html2text": stub_html2text,
    }
    with patch.dict(sys.modules, fake_modules):
        result = epub_parser.parse_file(Path("test.epub"))

        assert result == "Document content\n"

        # Only the document item's HTML should be passed through html2text.
        stub_html2text.html2text.assert_called_once_with("<p>Document content</p>")
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
|
|
||||||
from application.parser.file.html_parser import HTMLParser
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
def html_parser():
    """Provide a fresh HTMLParser instance for each test.

    NOTE(review): the visible tests construct their own HTMLParser and do not
    use this fixture — confirm whether it is still needed.
    """
    return HTMLParser()
|
|
||||||
|
|
||||||
|
|
||||||
def test_html_init_parser():
    """_init_parser yields a dict config and init_parser marks it as set."""
    p = HTMLParser()
    cfg = p._init_parser()
    assert isinstance(cfg, dict)
    assert not p.parser_config_set
    p.init_parser()
    assert p.parser_config_set
|
|
||||||
|
|
||||||
|
|
||||||
def test_html_parser_parse_file():
    """parse_file loads the document through BSHTMLLoader and returns its docs."""
    parser = HTMLParser()

    loaded_doc = MagicMock()
    loaded_doc.page_content = "Extracted HTML content"
    loaded_doc.metadata = {"source": "test.html"}

    # Stub out the langchain_community package so the loader import succeeds.
    stub_lc = types.ModuleType("langchain_community")
    stub_loaders = types.ModuleType("langchain_community.document_loaders")

    loader_instance = MagicMock(load=MagicMock(return_value=[loaded_doc]))
    loader_cls = MagicMock(return_value=loader_instance)
    stub_loaders.BSHTMLLoader = loader_cls
    stub_lc.document_loaders = stub_loaders

    fake_modules = {
        "langchain_community": stub_lc,
        "langchain_community.document_loaders": stub_loaders,
    }
    with patch.dict(sys.modules, fake_modules):
        result = parser.parse_file(Path("test.html"))
        assert result == [loaded_doc]
        # The loader must be constructed with the file path it should read.
        loader_cls.assert_called_once_with(Path("test.html"))
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch, MagicMock, mock_open
|
|
||||||
|
|
||||||
from application.parser.file.image_parser import ImageParser
|
|
||||||
|
|
||||||
|
|
||||||
def test_image_init_parser():
    """_init_parser yields a dict config and init_parser marks it as set."""
    p = ImageParser()
    cfg = p._init_parser()
    assert isinstance(cfg, dict)
    assert not p.parser_config_set
    p.init_parser()
    assert p.parser_config_set
|
|
||||||
|
|
||||||
|
|
||||||
@patch("application.parser.file.image_parser.settings")
def test_image_parser_remote_true(mock_settings):
    """With remote parsing enabled, the image is POSTed and markdown is returned."""
    mock_settings.PARSE_IMAGE_REMOTE = True
    parser = ImageParser()

    fake_response = MagicMock()
    fake_response.json.return_value = {"markdown": "# From Image"}

    post_patch = patch(
        "application.parser.file.image_parser.requests.post",
        return_value=fake_response,
    )
    with post_patch as mock_post:
        # Stub the file open so no real image needs to exist on disk.
        with patch("builtins.open", mock_open()):
            result = parser.parse_file(Path("img.png"))

        assert result == "# From Image"
        mock_post.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
@patch("application.parser.file.image_parser.settings")
def test_image_parser_remote_false(mock_settings):
    """With remote parsing disabled, no request is made and '' is returned."""
    mock_settings.PARSE_IMAGE_REMOTE = False
    parser = ImageParser()

    with patch("application.parser.file.image_parser.requests.post") as mock_post:
        result = parser.parse_file(Path("img.png"))

        assert result == ""
        mock_post.assert_not_called()
|
|
||||||
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch, mock_open
|
|
||||||
|
|
||||||
from application.parser.file.json_parser import JSONParser
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_init_parser():
    """_init_parser yields a dict config and init_parser marks it as set."""
    p = JSONParser()
    cfg = p._init_parser()
    assert isinstance(cfg, dict)
    assert not p.parser_config_set
    p.init_parser()
    assert p.parser_config_set
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_parser_parses_dict_concat():
    """A top-level JSON dict is rendered as its Python string representation."""
    parser = JSONParser()
    with patch("builtins.open", mock_open(read_data="{}")), \
            patch("json.load", return_value={"a": 1}):
        result = parser.parse_file(Path("t.json"))
    assert result == "{'a': 1}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_parser_parses_list_no_concat():
    """With row concatenation disabled, a JSON list is returned unchanged."""
    parser = JSONParser()
    parser._concat_rows = False
    rows = [{"a": 1}, {"b": 2}]
    with patch("builtins.open", mock_open(read_data="[]")), \
            patch("json.load", return_value=rows):
        result = parser.parse_file(Path("t.json"))
    assert result == rows
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_parser_row_joiner_config():
    """A custom row_joiner is used between stringified rows."""
    parser = JSONParser(row_joiner=" || ")
    with patch("builtins.open", mock_open(read_data="[]")), \
            patch("json.load", return_value=[{"a": 1}, {"b": 2}]):
        result = parser.parse_file(Path("t.json"))
    assert result == "{'a': 1} || {'b': 2}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_parser_forwards_json_config():
    """Entries in json_config are forwarded verbatim to json.load."""
    def parse_float_hook(s):
        return 1.23

    parser = JSONParser(json_config={"parse_float": parse_float_hook})
    with patch("builtins.open", mock_open(read_data="[]")), \
            patch("json.load", return_value=[]) as mock_load:
        parser.parse_file(Path("t.json"))
    # The hook object itself must reach json.load, not a copy or wrapper.
    assert mock_load.call_args.kwargs.get("parse_float") is parse_float_hook
|
|
||||||
|
|
||||||