Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
ec4b7da528 build(deps): bump the npm_and_yarn group across 2 directories with 5 updates
Bumps the npm_and_yarn group with 3 updates in the /docs directory: [cross-spawn](https://github.com/moxystudio/node-cross-spawn), [nextra](https://github.com/shuding/nextra) and [nextra-theme-docs](https://github.com/shuding/nextra).
Bumps the npm_and_yarn group with 1 update in the /extensions/react-widget directory: [dompurify](https://github.com/cure53/DOMPurify).


Updates `cross-spawn` from 7.0.3 to 7.0.6
- [Changelog](https://github.com/moxystudio/node-cross-spawn/blob/master/CHANGELOG.md)
- [Commits](https://github.com/moxystudio/node-cross-spawn/compare/v7.0.3...v7.0.6)

Updates `nextra` from 2.13.2 to 4.2.17
- [Release notes](https://github.com/shuding/nextra/releases)
- [Commits](https://github.com/shuding/nextra/compare/nextra@2.13.2...nextra@4.2.17)

Updates `nextra-theme-docs` from 2.13.2 to 4.2.17
- [Release notes](https://github.com/shuding/nextra/releases)
- [Commits](https://github.com/shuding/nextra/compare/nextra-theme-docs@2.13.2...nextra-theme-docs@4.2.17)

Updates `micromatch` from 4.0.5 to 4.0.8
- [Release notes](https://github.com/micromatch/micromatch/releases)
- [Changelog](https://github.com/micromatch/micromatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/micromatch/compare/4.0.5...4.0.8)

Updates `dompurify` from 3.2.4 to 3.2.6
- [Release notes](https://github.com/cure53/DOMPurify/releases)
- [Commits](https://github.com/cure53/DOMPurify/compare/3.2.4...3.2.6)

---
updated-dependencies:
- dependency-name: cross-spawn
  dependency-version: 7.0.6
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: nextra
  dependency-version: 4.2.17
  dependency-type: direct:production
  dependency-group: npm_and_yarn
- dependency-name: nextra-theme-docs
  dependency-version: 4.2.17
  dependency-type: direct:production
  dependency-group: npm_and_yarn
- dependency-name: micromatch
  dependency-version: 4.0.8
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: dompurify
  dependency-version: 3.2.6
  dependency-type: direct:production
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-06-12 00:16:46 +00:00
210 changed files with 8969 additions and 21843 deletions

2
.gitattributes vendored
View File

@@ -1,2 +0,0 @@
# Auto detect text files and perform LF normalization
* text=auto

View File

@@ -3,11 +3,11 @@
</h1>
<p align="center">
<strong>Private AI for agents, assistants and enterprise search</strong>
<strong>Open-Source RAG Assistant</strong>
</p>
<p align="left">
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source genAI tool that helps users get reliable answers from any knowledge source, while avoiding hallucinations. It enables quick and reliable information retrieval, with tooling and agentic system capability built in.
</p>
<div align="center">
@@ -19,10 +19,10 @@
<a href="https://discord.gg/n5BX8dh8rU">![link to discord](https://img.shields.io/discord/1070046503302877216)</a>
<a href="https://twitter.com/docsgptai">![X (formerly Twitter) URL](https://img.shields.io/twitter/follow/docsgptai)</a>
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a><a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a><a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
<br>
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a><a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a><a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
<br>
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a><a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a><a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
<br>
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a><a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a><a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
<br>
</div>
<div align="center">
@@ -52,14 +52,8 @@
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
- [x] New input box in the conversation menu (April 2025)
- [x] Add triggerable actions / tools (webhook) (April 2025)
- [x] Agent optimisations (May 2025)
- [x] Filesystem sources update (July 2025)
- [x] Json Responses (August 2025)
- [x] MCP support (August 2025)
- [x] Google Drive integration (September 2025)
- [ ] Add OAuth 2.0 authentication for MCP (September 2025)
- [ ] Sharepoint integration (October 2025)
- [ ] Deep Agents (October 2025)
- [ ] Anthropic Tool compatibility (May 2025)
- [ ] Add OAuth 2.0 authentication for tools and sources
- [ ] Agent scheduling
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
@@ -74,10 +68,11 @@ We're eager to provide personalized assistance when deploying your DocsGPT to a
## Join the Lighthouse Program 🌟
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
[Learn More & Apply →](https://docs.google.com/forms/d/1KAADiJinUJ8EMQyfTXUIGyFbqINNClNR3jBNWq7DgTE)
## QuickStart
> [!Note]
@@ -108,7 +103,7 @@ A more detailed [Quickstart](https://docs.docsgpt.cloud/quickstart) is available
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
```
Either script will guide you through setting up DocsGPT. Four options available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
Either script will guide you through setting up DocsGPT. Four options available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
**Navigate to http://localhost:5173/**
@@ -117,7 +112,6 @@ To stop DocsGPT, open a terminal in the `DocsGPT` directory and run:
```bash
docker compose -f deployment/docker-compose.yaml down
```
(or use the specific `docker compose down` command shown after running the setup script).
> [!Note]
@@ -145,6 +139,7 @@ Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file for information abou
We as members, contributors, and leaders, pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. Please refer to the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file for more information about contributing.
## Many Thanks To Our Contributors⚡
<a href="https://github.com/arc53/DocsGPT/graphs/contributors" alt="View Contributors">

View File

@@ -1,4 +1,3 @@
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional
@@ -7,15 +6,15 @@ from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
from application.retriever.base import BaseRetriever
logger = logging.getLogger(__name__)
class BaseAgent(ABC):
def __init__(
@@ -29,7 +28,6 @@ class BaseAgent(ABC):
chat_history: Optional[List[Dict]] = None,
decoded_token: Optional[Dict] = None,
attachments: Optional[List[Dict]] = None,
json_schema: Optional[Dict] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
@@ -53,7 +51,6 @@ class BaseAgent(ABC):
llm_name if llm_name else "default"
)
self.attachments = attachments or []
self.json_schema = json_schema
@log_activity()
def gen(
@@ -94,8 +91,8 @@ class BaseAgent(ABC):
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
tools_by_id = {str(tool["_id"]): tool for tool in user_tools}
return tools_by_id
def _build_tool_parameters(self, action):
params = {"type": "object", "properties": {}, "required": []}
@@ -140,40 +137,6 @@ class BaseAgent(ABC):
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
# Check if parsing failed
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
# Check if tool_id exists in available tools
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
# Return error result
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
@@ -225,7 +188,6 @@ class BaseAgent(ABC):
if tool_data["name"] == "api_tool"
else tool_data["config"]
),
user_id=self.user, # Pass user ID for MCP tools credential decryption
)
if tool_data["name"] == "api_tool":
print(
@@ -264,15 +226,7 @@ class BaseAgent(ABC):
query: str,
retrieved_data: List[Dict],
) -> List[Dict]:
docs_with_filenames = []
for doc in retrieved_data:
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if filename:
chunk_header = str(filename)
docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
else:
docs_with_filenames.append(doc["text"])
docs_together = "\n\n".join(docs_with_filenames)
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
@@ -329,21 +283,6 @@ class BaseAgent(ABC):
and self.tools
):
gen_kwargs["tools"] = self.tools
if (
self.json_schema
and hasattr(self.llm, "_supports_structured_output")
and self.llm._supports_structured_output()
):
structured_format = self.llm.prepare_structured_output_format(
self.json_schema
)
if structured_format:
if self.llm_name == "openai":
gen_kwargs["response_format"] = structured_format
elif self.llm_name == "google":
gen_kwargs["response_schema"] = structured_format
resp = self.llm.gen_stream(**gen_kwargs)
if log_context:
@@ -368,42 +307,21 @@ class BaseAgent(ABC):
return resp
def _handle_response(self, response, tools_dict, messages, log_context):
is_structured_output = (
self.json_schema is not None
and hasattr(self.llm, "_supports_structured_output")
and self.llm._supports_structured_output()
)
if isinstance(response, str):
answer_data = {"answer": response}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
yield {"answer": response}
return
if hasattr(response, "message") and getattr(response.message, "content", None):
answer_data = {"answer": response.message.content}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
yield {"answer": response.message.content}
return
processed_response_gen = self._llm_handler(
response, tools_dict, messages, log_context, self.attachments
)
for event in processed_response_gen:
if isinstance(event, str):
answer_data = {"answer": event}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
yield {"answer": event}
elif hasattr(event, "message") and getattr(event.message, "content", None):
answer_data = {"answer": event.message.content}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
yield {"answer": event.message.content}
elif isinstance(event, dict) and "type" in event:
yield event

View File

@@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
class ClassicAgent(BaseAgent):
"""A simplified agent with clear execution flow.
"""A simplified classic agent with clear execution flow.
Usage:
1. Processes a query through retrieval

View File

@@ -25,35 +25,27 @@ class BraveSearchTool(Tool):
else:
raise ValueError(f"Unknown action: {action_name}")
def _web_search(
self,
query,
country="ALL",
search_lang="en",
count=10,
offset=0,
safesearch="off",
freshness=None,
result_filter=None,
extra_snippets=False,
summary=False,
):
def _web_search(self, query, country="ALL", search_lang="en", count=10,
offset=0, safesearch="off", freshness=None,
result_filter=None, extra_snippets=False, summary=False):
"""
Performs a web search using the Brave Search API.
"""
print(f"Performing Brave web search for: {query}")
url = f"{self.base_url}/web/search"
# Build query parameters
params = {
"q": query,
"country": country,
"search_lang": search_lang,
"count": min(count, 20),
"offset": min(offset, 9),
"safesearch": safesearch,
"safesearch": safesearch
}
# Add optional parameters only if they have values
if freshness:
params["freshness"] = freshness
if result_filter:
@@ -62,69 +54,68 @@ class BraveSearchTool(Tool):
params["extra_snippets"] = 1
if summary:
params["summary"] = 1
# Set up headers
headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": self.token,
"X-Subscription-Token": self.token
}
# Make the request
response = requests.get(url, params=params, headers=headers)
if response.status_code == 200:
return {
"status_code": response.status_code,
"results": response.json(),
"message": "Search completed successfully.",
"message": "Search completed successfully."
}
else:
return {
"status_code": response.status_code,
"message": f"Search failed with status code: {response.status_code}.",
"message": f"Search failed with status code: {response.status_code}."
}
def _image_search(
self,
query,
country="ALL",
search_lang="en",
count=5,
safesearch="off",
spellcheck=False,
):
def _image_search(self, query, country="ALL", search_lang="en", count=5,
safesearch="off", spellcheck=False):
"""
Performs an image search using the Brave Search API.
"""
print(f"Performing Brave image search for: {query}")
url = f"{self.base_url}/images/search"
# Build query parameters
params = {
"q": query,
"country": country,
"search_lang": search_lang,
"count": min(count, 100), # API max is 100
"safesearch": safesearch,
"spellcheck": 1 if spellcheck else 0,
"spellcheck": 1 if spellcheck else 0
}
# Set up headers
headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": self.token,
"X-Subscription-Token": self.token
}
# Make the request
response = requests.get(url, params=params, headers=headers)
if response.status_code == 200:
return {
"status_code": response.status_code,
"results": response.json(),
"message": "Image search completed successfully.",
"message": "Image search completed successfully."
}
else:
return {
"status_code": response.status_code,
"message": f"Image search failed with status code: {response.status_code}.",
"message": f"Image search failed with status code: {response.status_code}."
}
def get_actions_metadata(self):
@@ -139,14 +130,42 @@ class BraveSearchTool(Tool):
"type": "string",
"description": "The search query (max 400 characters, 50 words)",
},
# "country": {
# "type": "string",
# "description": "The 2-character country code (default: US)",
# },
"search_lang": {
"type": "string",
"description": "The search language preference (default: en)",
},
# "count": {
# "type": "integer",
# "description": "Number of results to return (max 20, default: 10)",
# },
# "offset": {
# "type": "integer",
# "description": "Pagination offset (max 9, default: 0)",
# },
# "safesearch": {
# "type": "string",
# "description": "Filter level for adult content (off, moderate, strict)",
# },
"freshness": {
"type": "string",
"description": "Time filter for results (pd: last 24h, pw: last week, pm: last month, py: last year)",
},
# "result_filter": {
# "type": "string",
# "description": "Comma-delimited list of result types to include",
# },
# "extra_snippets": {
# "type": "boolean",
# "description": "Get additional excerpts from result pages",
# },
# "summary": {
# "type": "boolean",
# "description": "Enable summary generation in search results",
# }
},
"required": ["query"],
"additionalProperties": False,
@@ -162,21 +181,37 @@ class BraveSearchTool(Tool):
"type": "string",
"description": "The search query (max 400 characters, 50 words)",
},
# "country": {
# "type": "string",
# "description": "The 2-character country code (default: US)",
# },
# "search_lang": {
# "type": "string",
# "description": "The search language preference (default: en)",
# },
"count": {
"type": "integer",
"description": "Number of results to return (max 100, default: 5)",
},
# "safesearch": {
# "type": "string",
# "description": "Filter level for adult content (off, strict). Default: strict",
# },
# "spellcheck": {
# "type": "boolean",
# "description": "Whether to spellcheck provided query (default: true)",
# }
},
"required": ["query"],
"additionalProperties": False,
},
},
}
]
def get_config_requirements(self):
return {
"token": {
"type": "string",
"description": "Brave Search API key for authentication",
"type": "string",
"description": "Brave Search API key for authentication"
},
}
}

View File

@@ -1,114 +0,0 @@
from application.agents.tools.base import Tool
from duckduckgo_search import DDGS
class DuckDuckGoSearchTool(Tool):
"""
DuckDuckGo Search
A tool for performing web and image searches using DuckDuckGo.
"""
def __init__(self, config):
self.config = config
def execute_action(self, action_name, **kwargs):
actions = {
"ddg_web_search": self._web_search,
"ddg_image_search": self._image_search,
}
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
def _web_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo web search for: {query}")
try:
results = DDGS().text(
query,
max_results=max_results,
)
return {
"status_code": 200,
"results": results,
"message": "Web search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Web search failed: {str(e)}",
}
def _image_search(
self,
query,
max_results=5,
):
print(f"Performing DuckDuckGo image search for: {query}")
try:
results = DDGS().images(
keywords=query,
max_results=max_results,
)
return {
"status_code": 200,
"results": results,
"message": "Image search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Image search failed: {str(e)}",
}
def get_actions_metadata(self):
return [
{
"name": "ddg_web_search",
"description": "Perform a web search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query",
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5)",
},
},
"required": ["query"],
},
},
{
"name": "ddg_image_search",
"description": "Perform an image search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query",
},
"max_results": {
"type": "integer",
"description": "Number of results to return (default: 5, max: 50)",
},
},
"required": ["query"],
},
},
]
def get_config_requirements(self):
return {}

View File

@@ -1,861 +0,0 @@
import asyncio
import base64
import json
import logging
import time
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from fastmcp import Client
from fastmcp.client.auth import BearerAuth
from fastmcp.client.transports import (
SSETransport,
StdioTransport,
StreamableHttpTransport,
)
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
_mcp_clients_cache = {}
class MCPTool(Tool):
"""
MCP Tool
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources. Supports various authentication methods and provides secure access to external services through the MCP protocol.
"""
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
"""
Initialize the MCP Tool with configuration.
Args:
config: Dictionary containing MCP server configuration:
- server_url: URL of the remote MCP server
- transport_type: Transport type (auto, sse, http, stdio)
- auth_type: Type of authentication (bearer, oauth, api_key, basic, none)
- encrypted_credentials: Encrypted credentials (if available)
- timeout: Request timeout in seconds (default: 30)
- headers: Custom headers for requests
- command: Command for STDIO transport
- args: Arguments for STDIO transport
- oauth_scopes: OAuth scopes for oauth auth type
- oauth_client_name: OAuth client name for oauth auth type
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
self.user_id = user_id
self.server_url = config.get("server_url", "")
self.transport_type = config.get("transport_type", "auto")
self.auth_type = config.get("auth_type", "none")
self.timeout = config.get("timeout", 30)
self.custom_headers = config.get("headers", {})
self.auth_credentials = {}
if config.get("encrypted_credentials") and user_id:
self.auth_credentials = decrypt_credentials(
config["encrypted_credentials"], user_id
)
else:
self.auth_credentials = config.get("auth_credentials", {})
self.oauth_scopes = config.get("oauth_scopes", [])
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
self.available_tools = []
self._cache_key = self._generate_cache_key()
self._client = None
# Only validate and setup if server_url is provided and not OAuth
if self.server_url and self.auth_type != "oauth":
self._setup_client()
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "oauth":
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(
"bearer_token", ""
) or self.auth_credentials.get("access_token", "")
auth_key = f"bearer:{token[:10]}..." if token else "bearer:none"
elif self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
auth_key = f"apikey:{api_key[:10]}..." if api_key else "apikey:none"
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
auth_key = f"basic:{username}"
else:
auth_key = "none"
return f"{self.server_url}#{self.transport_type}#{auth_key}"
def _setup_client(self):
"""Setup FastMCP client with proper transport and authentication."""
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
cached_data = _mcp_clients_cache[self._cache_key]
if time.time() - cached_data["created_at"] < 1800:
self._client = cached_data["client"]
return
else:
del _mcp_clients_cache[self._cache_key]
transport = self._create_transport()
auth = None
if self.auth_type == "oauth":
redis_client = get_redis_instance()
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
"bearer_token", ""
) or self.auth_credentials.get("access_token", "")
if token:
auth = BearerAuth(token)
self._client = Client(transport, auth=auth)
_mcp_clients_cache[self._cache_key] = {
"client": self._client,
"created_at": time.time(),
}
def _create_transport(self):
"""Create appropriate transport based on configuration."""
headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"}
headers.update(self.custom_headers)
if self.auth_type == "api_key":
api_key = self.auth_credentials.get("api_key", "")
header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
if api_key:
headers[header_name] = api_key
elif self.auth_type == "basic":
username = self.auth_credentials.get("username", "")
password = self.auth_credentials.get("password", "")
if username and password:
credentials = base64.b64encode(
f"{username}:{password}".encode()
).decode()
headers["Authorization"] = f"Basic {credentials}"
if self.transport_type == "auto":
if "sse" in self.server_url.lower() or self.server_url.endswith("/sse"):
transport_type = "sse"
else:
transport_type = "http"
else:
transport_type = self.transport_type
if transport_type == "sse":
headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
return SSETransport(url=self.server_url, headers=headers)
elif transport_type == "http":
return StreamableHttpTransport(url=self.server_url, headers=headers)
elif transport_type == "stdio":
command = self.config.get("command", "python")
args = self.config.get("args", [])
env = self.auth_credentials if self.auth_credentials else None
return StdioTransport(command=command, args=args, env=env)
else:
return StreamableHttpTransport(url=self.server_url, headers=headers)
def _format_tools(self, tools_response) -> List[Dict]:
"""Format tools response to match expected format."""
if hasattr(tools_response, "tools"):
tools = tools_response.tools
elif isinstance(tools_response, list):
tools = tools_response
else:
tools = []
tools_dict = []
for tool in tools:
if hasattr(tool, "name"):
tool_dict = {
"name": tool.name,
"description": tool.description,
}
if hasattr(tool, "inputSchema"):
tool_dict["inputSchema"] = tool.inputSchema
tools_dict.append(tool_dict)
elif isinstance(tool, dict):
tools_dict.append(tool)
else:
if hasattr(tool, "model_dump"):
tools_dict.append(tool.model_dump())
else:
tools_dict.append({"name": str(tool), "description": ""})
return tools_dict
async def _execute_with_client(self, operation: str, *args, **kwargs):
"""Execute operation with FastMCP client."""
if not self._client:
raise Exception("FastMCP client not initialized")
async with self._client:
if operation == "ping":
return await self._client.ping()
elif operation == "list_tools":
tools_response = await self._client.list_tools()
self.available_tools = self._format_tools(tools_response)
return self.available_tools
elif operation == "call_tool":
tool_name = args[0]
tool_args = kwargs
return await self._client.call_tool(tool_name, tool_args)
elif operation == "list_resources":
return await self._client.list_resources()
elif operation == "list_prompts":
return await self._client.list_prompts()
else:
raise Exception(f"Unknown operation: {operation}")
def _run_async_operation(self, operation: str, *args, **kwargs):
"""Run async operation in sync context."""
try:
try:
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result(timeout=self.timeout)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
except Exception as e:
print(f"Error occurred while running async operation: {e}")
raise
def discover_tools(self) -> List[Dict]:
"""
Discover available tools from the MCP server using FastMCP.
Returns:
List of tool definitions from the server
"""
if not self.server_url:
return []
if not self._client:
self._setup_client()
try:
tools = self._run_async_operation("list_tools")
self.available_tools = tools
return self.available_tools
except Exception as e:
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
def execute_action(self, action_name: str, **kwargs) -> Any:
"""
Execute an action on the remote MCP server using FastMCP.
Args:
action_name: Name of the action to execute
**kwargs: Parameters for the action
Returns:
Result from the MCP server
"""
if not self.server_url:
raise Exception("No MCP server configured")
if not self._client:
self._setup_client()
cleaned_kwargs = {}
for key, value in kwargs.items():
if value == "" or value is None:
continue
cleaned_kwargs[key] = value
try:
result = self._run_async_operation(
"call_tool", action_name, **cleaned_kwargs
)
return self._format_result(result)
except Exception as e:
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
def _format_result(self, result) -> Dict:
"""Format FastMCP result to match expected format."""
if hasattr(result, "content"):
content_list = []
for content_item in result.content:
if hasattr(content_item, "text"):
content_list.append({"type": "text", "text": content_item.text})
elif hasattr(content_item, "data"):
content_list.append({"type": "data", "data": content_item.data})
else:
content_list.append(
{"type": "unknown", "content": str(content_item)}
)
return {
"content": content_list,
"isError": getattr(result, "isError", False),
}
else:
return result
def test_connection(self) -> Dict:
"""
Test the connection to the MCP server and validate functionality.
Returns:
Dictionary with connection test results including tool count
"""
if not self.server_url:
return {
"success": False,
"message": "No MCP server URL configured",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": "ConfigurationError",
}
if not self._client:
self._setup_client()
try:
if self.auth_type == "oauth":
return self._test_oauth_connection()
else:
return self._test_regular_connection()
except Exception as e:
return {
"success": False,
"message": f"Connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def _test_regular_connection(self) -> Dict:
"""Test connection for non-OAuth auth types."""
try:
self._run_async_operation("ping")
ping_success = True
except Exception:
ping_success = False
tools = self.discover_tools()
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if not ping_success:
message += " (Ping not supported, but tool discovery worked)"
return {
"success": True,
"message": message,
"tools_count": len(tools),
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools],
}
def _test_oauth_connection(self) -> Dict:
"""Test connection for OAuth auth type with proper async handling."""
try:
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
if not task:
raise Exception("Failed to start OAuth authentication")
return {
"success": True,
"requires_oauth": True,
"task_id": task.id,
"status": "pending",
"message": "OAuth flow started",
}
except Exception as e:
return {
"success": False,
"message": f"OAuth connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def get_actions_metadata(self) -> List[Dict]:
"""
Get metadata for all available actions.
Returns:
List of action metadata dictionaries
"""
actions = []
for tool in self.available_tools:
input_schema = (
tool.get("inputSchema")
or tool.get("input_schema")
or tool.get("schema")
or tool.get("parameters")
)
parameters_schema = {
"type": "object",
"properties": {},
"required": [],
}
if input_schema:
if isinstance(input_schema, dict):
if "properties" in input_schema:
parameters_schema = {
"type": input_schema.get("type", "object"),
"properties": input_schema.get("properties", {}),
"required": input_schema.get("required", []),
}
for key in ["additionalProperties", "description"]:
if key in input_schema:
parameters_schema[key] = input_schema[key]
else:
parameters_schema["properties"] = input_schema
action = {
"name": tool.get("name", ""),
"description": tool.get("description", ""),
"parameters": parameters_schema,
}
actions.append(action)
return actions
def get_config_requirements(self) -> Dict:
"""Get configuration requirements for the MCP tool."""
return {
"server_url": {
"type": "string",
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
"required": True,
},
"transport_type": {
"type": "string",
"description": "Transport type for connection",
"enum": ["auto", "sse", "http", "stdio"],
"default": "auto",
"required": False,
"help": {
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
"stdio": "Standard I/O (for local servers)",
},
},
"auth_type": {
"type": "string",
"description": "Authentication type",
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
"default": "none",
"required": True,
"help": {
"none": "No authentication",
"bearer": "Bearer token authentication",
"oauth": "OAuth 2.1 authentication (with frontend integration)",
"api_key": "API key authentication",
"basic": "Basic authentication",
},
},
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"required": False,
"properties": {
"bearer_token": {
"type": "string",
"description": "Bearer token for bearer auth",
},
"access_token": {
"type": "string",
"description": "Access token for OAuth (if pre-obtained)",
},
"api_key": {
"type": "string",
"description": "API key for api_key auth",
},
"api_key_header": {
"type": "string",
"description": "Header name for API key (default: X-API-Key)",
},
"username": {
"type": "string",
"description": "Username for basic auth",
},
"password": {
"type": "string",
"description": "Password for basic auth",
},
},
},
"oauth_scopes": {
"type": "array",
"description": "OAuth scopes to request (for oauth auth_type)",
"items": {"type": "string"},
"required": False,
"default": [],
},
"oauth_client_name": {
"type": "string",
"description": "Client name for OAuth registration (for oauth auth_type)",
"default": "DocsGPT-MCP",
"required": False,
},
"headers": {
"type": "object",
"description": "Custom headers to send with requests",
"required": False,
},
"timeout": {
"type": "integer",
"description": "Request timeout in seconds",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
"command": {
"type": "string",
"description": "Command to run for STDIO transport (e.g., 'python')",
"required": False,
},
"args": {
"type": "array",
"description": "Arguments for STDIO command",
"items": {"type": "string"},
"required": False,
},
}
class DocsGPTOAuth(OAuthClientProvider):
"""
Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.
"""
def __init__(
self,
mcp_url: str,
redirect_uri: str,
redis_client: Redis | None = None,
redis_prefix: str = "mcp_oauth:",
task_id: str = None,
scopes: str | list[str] | None = None,
client_name: str = "DocsGPT-MCP",
user_id=None,
db=None,
additional_client_metadata: dict[str, Any] | None = None,
):
"""
Initialize custom OAuth client provider for DocsGPT.
Args:
mcp_url: Full URL to the MCP endpoint
redirect_uri: Custom redirect URI for DocsGPT frontend
redis_client: Redis client for storing auth state
redis_prefix: Prefix for Redis keys
task_id: Task ID for tracking auth status
scopes: OAuth scopes to request
client_name: Name for this client during registration
user_id: User ID for token storage
db: Database instance for token storage
additional_client_metadata: Extra fields for OAuthClientMetadata
"""
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
self.task_id = task_id
self.user_id = user_id
self.db = db
parsed_url = urlparse(mcp_url)
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
if isinstance(scopes, list):
scopes = " ".join(scopes)
client_metadata = OAuthClientMetadata(
client_name=client_name,
redirect_uris=[AnyHttpUrl(redirect_uri)],
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
scope=scopes,
**(additional_client_metadata or {}),
)
storage = DBTokenStorage(
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
)
super().__init__(
server_url=self.server_base_url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=self.redirect_handler,
callback_handler=self.callback_handler,
)
self.auth_url = None
self.extracted_state = None
def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
"""Process authorization URL to extract state"""
try:
parsed_url = urlparse(authorization_url)
query_params = parse_qs(parsed_url.query)
state_params = query_params.get("state", [])
if state_params:
state = state_params[0]
else:
raise ValueError("No state in auth URL")
return authorization_url, state
except Exception as e:
raise Exception(f"Failed to process auth URL: {e}")
async def redirect_handler(self, authorization_url: str) -> None:
"""Store auth URL and state in Redis for frontend to use."""
auth_url, state = self._process_auth_url(authorization_url)
logging.info(
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
)
self.auth_url = auth_url
self.extracted_state = state
if self.redis_client and self.extracted_state:
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
self.redis_client.setex(key, 600, auth_url)
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "OAuth authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
async def callback_handler(self) -> tuple[str, str | None]:
"""Wait for auth code from Redis using the state value."""
if not self.redis_client or not self.extracted_state:
raise Exception("Redis client or state not configured for OAuth")
poll_interval = 1
max_wait_time = 300
code_key = f"{self.redis_prefix}code:{self.extracted_state}"
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for OAuth callback...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
start_time = time.time()
while time.time() - start_time < max_wait_time:
code_data = self.redis_client.get(code_key)
if code_data:
code = code_data.decode()
returned_state = self.extracted_state
self.redis_client.delete(code_key)
self.redis_client.delete(
f"{self.redis_prefix}auth_url:{self.extracted_state}"
)
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
if self.task_id:
status_data = {
"status": "callback_received",
"message": "OAuth callback received, completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
return code, returned_state
error_key = f"{self.redis_prefix}error:{self.extracted_state}"
error_data = self.redis_client.get(error_key)
if error_data:
error_msg = error_data.decode()
self.redis_client.delete(error_key)
self.redis_client.delete(
f"{self.redis_prefix}auth_url:{self.extracted_state}"
)
self.redis_client.delete(
f"{self.redis_prefix}state:{self.extracted_state}"
)
raise Exception(f"OAuth error: {error_msg}")
await asyncio.sleep(poll_interval)
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
raise Exception("OAuth callback timeout: no code received within 5 minutes")
class DBTokenStorage(TokenStorage):
def __init__(self, server_url: str, user_id: str, db_client):
self.server_url = server_url
self.user_id = user_id
self.db_client = db_client
self.collection = db_client["connector_sessions"]
@staticmethod
def get_base_url(url: str) -> str:
parsed = urlparse(url)
return f"{parsed.scheme}://{parsed.netloc}"
def get_db_key(self) -> dict:
return {
"server_url": self.get_base_url(self.server_url),
"user_id": self.user_id,
}
async def get_tokens(self) -> OAuthToken | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "tokens" not in doc:
return None
try:
tokens = OAuthToken.model_validate(doc["tokens"])
return tokens
except ValidationError as e:
logging.error(f"Could not load tokens: {e}")
return None
async def set_tokens(self, tokens: OAuthToken) -> None:
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"tokens": tokens.model_dump()}},
True,
)
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
async def get_client_info(self) -> OAuthClientInformationFull | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "client_info" not in doc:
return None
try:
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
tokens = await self.get_tokens()
if tokens is None:
logging.debug(
"No tokens found, clearing client info to force fresh registration."
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": ""}},
)
return None
return client_info
except ValidationError as e:
logging.error(f"Could not load client info: {e}")
return None
def _serialize_client_info(self, info: dict) -> dict:
if "redirect_uris" in info and isinstance(info["redirect_uris"], list):
info["redirect_uris"] = [str(u) for u in info["redirect_uris"]]
return info
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
serialized_info = self._serialize_client_info(client_info.model_dump())
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$set": {"client_info": serialized_info}},
True,
)
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
async def clear(self) -> None:
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
@classmethod
async def clear_all(cls, db_client) -> None:
collection = db_client["connector_sessions"]
await asyncio.to_thread(collection.delete_many, {})
logging.info("Cleared all OAuth client cache data.")
class MCPOAuthManager:
"""Manager for handling MCP OAuth callbacks."""
def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
self.redis_client = redis_client
self.redis_prefix = redis_prefix
def handle_oauth_callback(
self, state: str, code: str, error: Optional[str] = None
) -> bool:
"""
Handle OAuth callback from provider.
Args:
state: The state parameter from OAuth callback
code: The authorization code from OAuth callback
error: Error message if OAuth failed
Returns:
True if successful, False otherwise
"""
try:
if not self.redis_client or not state:
raise Exception("Redis client or state not provided")
if error:
error_key = f"{self.redis_prefix}error:{state}"
self.redis_client.setex(error_key, 300, error)
raise Exception(f"OAuth error received: {error}")
code_key = f"{self.redis_prefix}code:{state}"
self.redis_client.setex(code_key, 300, code)
state_key = f"{self.redis_prefix}state:{state}"
self.redis_client.setex(state_key, 300, "completed")
return True
except Exception as e:
logging.error(f"Error handling OAuth callback: {e}")
return False
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Get current status of OAuth flow using provided task_id."""
if not task_id:
return {"status": "not_started", "message": "OAuth flow not started"}
return mcp_oauth_status_task(task_id)

View File

@@ -19,20 +19,8 @@ class ToolActionParser:
def _parse_openai_llm(self, call):
try:
call_args = json.loads(call.arguments)
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
tool_id = call.name.split("_")[-1]
action_name = call.name.rsplit("_", 1)[0]
except (AttributeError, TypeError) as e:
logger.error(f"Error parsing OpenAI LLM call: {e}")
return None, None, None
@@ -41,20 +29,8 @@ class ToolActionParser:
def _parse_google_llm(self, call):
try:
call_args = call.arguments
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
tool_id = call.name.split("_")[-1]
action_name = call.name.rsplit("_", 1)[0]
except (AttributeError, TypeError) as e:
logger.error(f"Error parsing Google LLM call: {e}")
return None, None, None

View File

@@ -23,23 +23,16 @@ class ToolManager:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
def load_tool(self, tool_name, tool_config, user_id=None):
def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
if tool_name == "mcp_tool" and user_id:
return obj(tool_config, user_id)
else:
return obj(tool_config)
return obj(tool_config)
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
def execute_action(self, tool_name, action_name, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name == "mcp_tool" and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):

View File

@@ -1,7 +0,0 @@
from flask_restx import Api
api = Api(
version="1.0",
title="DocsGPT API",
description="API for DocsGPT",
)

View File

@@ -1,19 +0,0 @@
from flask import Blueprint
from application.api import api
from application.api.answer.routes.answer import AnswerResource
from application.api.answer.routes.base import answer_ns
from application.api.answer.routes.stream import StreamResource
answer = Blueprint("answer", __name__)
api.add_namespace(answer_ns)
def init_answer_routes():
api.add_resource(StreamResource, "/stream")
api.add_resource(AnswerResource, "/api/answer")
init_answer_routes()

View File

@@ -0,0 +1,914 @@
import asyncio
import datetime
import json
import logging
import os
import traceback
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, make_response, request, Response
from flask_restx import fields, Namespace, Resource
from application.agents.agent_creator import AgentCreator
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.error import bad_request
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import check_required_fields, limit_chat_history
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
agents_collection = db["agents"]
user_logs_collection = db["user_logs"]
attachments_collection = db["attachments"]
answer = Blueprint("answer", __name__)
answer_ns = Namespace("answer", description="Answer related operations", path="/")
api.add_namespace(answer_ns)
gpt_model = ""
# to have some kind of default behaviour
if settings.LLM_PROVIDER == "openai":
gpt_model = "gpt-4o-mini"
elif settings.LLM_PROVIDER == "anthropic":
gpt_model = "claude-2"
elif settings.LLM_PROVIDER == "groq":
gpt_model = "llama3-8b-8192"
elif settings.LLM_PROVIDER == "novita":
gpt_model = "deepseek/deepseek-r1"
if settings.LLM_NAME: # in case there is particular model name configured
gpt_model = settings.LLM_NAME
# load the prompts
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
chat_combine_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
chat_reduce_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
chat_combine_creative = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
chat_combine_strict = f.read()
api_key_set = settings.API_KEY is not None
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
async def async_generate(chain, question, chat_history):
result = await chain.arun({"question": question, "chat_history": chat_history})
return result
def run_async_chain(chain, question, chat_history):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = {}
try:
answer = loop.run_until_complete(async_generate(chain, question, chat_history))
finally:
loop.close()
result["answer"] = answer
return result
def get_agent_key(agent_id, user_id):
if not agent_id:
return None, False, None
try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found", 404)
is_owner = agent.get("user") == user_id
if is_owner:
agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}},
)
return str(agent["key"]), False, None
is_shared_with_user = agent.get(
"shared_publicly", False
) or user_id in agent.get("shared_with", [])
if is_shared_with_user:
return str(agent["key"]), True, agent.get("shared_token")
raise Exception("Unauthorized access to the agent", 403)
except Exception as e:
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
raise
def get_data_from_api_key(api_key):
data = agents_collection.find_one({"key": api_key})
if not data:
raise Exception("Invalid API Key, please generate a new key", 401)
source = data.get("source")
if isinstance(source, DBRef):
source_doc = db.dereference(source)
data["source"] = str(source_doc["_id"])
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
else:
data["source"] = {}
return data
def get_retriever(source_id: str):
doc = sources_collection.find_one({"_id": ObjectId(source_id)})
if doc is None:
raise Exception("Source document does not exist", 404)
retriever_name = None if "retriever" not in doc else doc["retriever"]
return retriever_name
def is_azure_configured():
return (
settings.OPENAI_API_BASE
and settings.OPENAI_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
)
def save_conversation(
conversation_id,
question,
response,
thought,
source_log_docs,
tool_calls,
llm,
decoded_token,
index=None,
api_key=None,
agent_id=None,
is_shared_usage=False,
shared_token=None,
attachment_ids=None,
):
current_time = datetime.datetime.now(datetime.timezone.utc)
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.thought": thought,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
}
},
)
##remove following queries from the array
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
elif conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": question,
"response": response,
"thought": thought,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
}
}
},
)
else:
# create new conversation
# generate summary
messages_summary = [
{
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the system",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
conversation_data = {
"user": decoded_token.get("sub"),
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"thought": thought,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
}
],
}
if api_key:
if agent_id:
conversation_data["agent_id"] = agent_id
if is_shared_usage:
conversation_data["is_shared_usage"] = is_shared_usage
conversation_data["shared_token"] = shared_token
api_key_doc = agents_collection.find_one({"key": api_key})
if api_key_doc:
conversation_data["api_key"] = api_key_doc["key"]
conversation_id = conversations_collection.insert_one(
conversation_data
).inserted_id
return conversation_id
def get_prompt(prompt_id):
if prompt_id == "default":
prompt = chat_combine_template
elif prompt_id == "creative":
prompt = chat_combine_creative
elif prompt_id == "strict":
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
return prompt
def complete_stream(
question,
agent,
retriever,
conversation_id,
user_api_key,
decoded_token,
isNoneDoc=False,
index=None,
should_save_conversation=True,
attachment_ids=None,
agent_id=None,
is_shared_usage=False,
shared_token=None,
):
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
answer = agent.gen(query=question, retriever=retriever)
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
truncated_sources = []
source_log_docs = line["sources"]
for source in line["sources"]:
truncated_source = source.copy()
if "text" in truncated_source:
truncated_source["text"] = (
truncated_source["text"][:100].strip() + "..."
)
truncated_sources.append(truncated_source)
if len(truncated_sources) > 0:
data = json.dumps({"type": "source", "source": truncated_sources})
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
elif "thought" in line:
thought += line["thought"]
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
if should_save_conversation:
conversation_id = save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
decoded_token,
index,
api_key=user_api_key,
attachment_ids=attachment_ids,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
)
else:
conversation_id = None
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except Exception as e:
logger.error(f"Error in stream: {str(e)}", exc_info=True)
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
}
)
yield f"data: {data}\n\n"
return
@answer_ns.route("/stream")
class Stream(Resource):
stream_model = api.model(
"StreamModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String, required=False, description="Chat history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index": fields.Integer(
required=False, description="Index of the query to update"
),
"save_conversation": fields.Boolean(
required=False,
default=True,
description="Whether to save the conversation",
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
},
)
@api.expect(stream_model)
@api.doc(description="Stream a response based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
if "index" in data:
required_fields = ["question", "conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
save_conv = data.get("save_conversation", True)
try:
question = data["question"]
history = limit_chat_history(
json.loads(data.get("history", "[]")), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
attachment_ids = data.get("attachments", [])
index = data.get("index", None)
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
agent_id = data.get("agent_id", None)
agent_type = settings.AGENT_NAME
decoded_token = getattr(request, "decoded_token", None)
user_sub = decoded_token.get("sub") if decoded_token else None
agent_key, is_shared_usage, shared_token = get_agent_key(agent_id, user_sub)
if agent_key:
data.update({"api_key": agent_key})
else:
agent_id = None
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
agent_type = data_key.get("agent_type", agent_type)
if is_shared_usage:
decoded_token = request.decoded_token
else:
decoded_token = {"sub": data_key.get("user")}
is_shared_usage = False
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
if not decoded_token:
return make_response({"error": "Unauthorized"}, 401)
attachments = get_attachments_content(
attachment_ids, decoded_token.get("sub")
)
logger.info(
f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
prompt = get_prompt(prompt_id)
if "isNoneDoc" in data and data["isNoneDoc"] is True:
chunks = 0
agent = AgentCreator.create_agent(
agent_type,
endpoint="stream",
llm_name=settings.LLM_PROVIDER,
gpt_model=gpt_model,
api_key=settings.API_KEY,
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
decoded_token=decoded_token,
attachments=attachments,
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
return Response(
complete_stream(
question=question,
agent=agent,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
decoded_token=decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=index,
should_save_conversation=save_conv,
attachment_ids=attachment_ids,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
),
mimetype="text/event-stream",
)
except ValueError:
message = "Malformed request body"
logger.error(f"/stream - error: {message}")
return Response(
error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
logger.error(
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
status_code = 400
return Response(
error_stream_generate("Unknown error occurred"),
status=status_code,
mimetype="text/event-stream",
)
def error_stream_generate(err_response):
data = json.dumps({"type": "error", "error": err_response})
yield f"data: {data}\n\n"
@answer_ns.route("/api/answer")
class Answer(Resource):
answer_model = api.model(
"AnswerModel",
{
"question": fields.String(
required=True, description="The question to answer"
),
"history": fields.List(
fields.String, required=False, description="Conversation history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(answer_model)
@api.doc(description="Provide an answer based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = limit_chat_history(
json.loads(data.get("history", [])), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
agent_type = settings.AGENT_NAME
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
agent_type = data_key.get("agent_type", agent_type)
decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
if not decoded_token:
return make_response({"error": "Unauthorized"}, 401)
prompt = get_prompt(prompt_id)
logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
agent = AgentCreator.create_agent(
agent_type,
endpoint="api/answer",
llm_name=settings.LLM_PROVIDER,
gpt_model=gpt_model,
api_key=settings.API_KEY,
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
decoded_token=decoded_token,
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
response_full = ""
source_log_docs = []
tool_calls = []
stream_ended = False
thought = ""
for line in complete_stream(
question=question,
agent=agent,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
decoded_token=decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=False,
):
try:
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
if event["type"] == "answer":
response_full += event["answer"]
elif event["type"] == "source":
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return bad_request(500, event["error"])
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Error parsing stream event: {e}, line: {line}")
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return bad_request(500, "Stream ended unexpectedly.")
if data.get("isNoneDoc"):
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
decoded_token,
api_key=user_api_key,
)
)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(result, 200)
@answer_ns.route("/api/search")
class Search(Resource):
search_model = api.model(
"SearchModel",
{
"question": fields.String(
required=True, description="The question to search"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"api_key": fields.String(
required=False, description="API key for authentication"
),
"active_docs": fields.String(
required=False, description="Active documents for retrieval"
),
"retriever": fields.String(required=False, description="Retriever type"),
"token_limit": fields.Integer(
required=False, description="Limit for tokens"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(search_model)
@api.doc(
description="Search for relevant documents based on the question and retriever"
)
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
source = {"active_docs": data_key.get("source")}
user_api_key = data["api_key"]
decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
if not decoded_token:
return make_response({"error": "Unauthorized"}, 401)
logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
source=source,
chat_history=[],
prompt="default",
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
docs = retriever.search(question)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_search",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"sources": docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
if data.get("isNoneDoc"):
for doc in docs:
doc["source"] = "None"
except Exception as e:
logger.error(
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(docs, 200)
def get_attachments_content(attachment_ids, user):
"""
Retrieve content from attachment documents based on their IDs.
Args:
attachment_ids (list): List of attachment document IDs
user (str): User identifier to verify ownership
Returns:
list: List of dictionaries containing attachment content and metadata
"""
if not attachment_ids:
return []
attachments = []
for attachment_id in attachment_ids:
try:
attachment_doc = attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
return attachments

View File

@@ -1,122 +0,0 @@
import logging
import traceback
from flask import make_response, request
from flask_restx import fields, Resource
from application.api import api
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
@answer_ns.route("/api/answer")
class AnswerResource(Resource, BaseAnswerResource):
def __init__(self, *args, **kwargs):
Resource.__init__(self, *args, **kwargs)
BaseAnswerResource.__init__(self)
answer_model = answer_ns.model(
"AnswerModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String,
required=False,
description="Conversation history (only for new conversations)",
),
"conversation_id": fields.String(
required=False,
description="Existing conversation ID (loads history)",
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"save_conversation": fields.Boolean(
required=False,
default=True,
description="Whether to save the conversation",
),
},
)
@api.expect(answer_model)
@api.doc(description="Provide a response based on the question and retriever")
def post(self):
data = request.get_json()
if error := self.validate_request(data):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
agent = processor.create_agent()
retriever = processor.create_retriever()
stream = self.complete_stream(
question=data["question"],
agent=agent,
retriever=retriever,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
)
stream_result = self.process_response_stream(stream)
if len(stream_result) == 7:
(
conversation_id,
response,
sources,
tool_calls,
thought,
error,
structured_info,
) = stream_result
else:
conversation_id, response, sources, tool_calls, thought, error = (
stream_result
)
structured_info = None
if error:
return make_response({"error": error}, 400)
result = {
"conversation_id": conversation_id,
"answer": response,
"sources": sources,
"tool_calls": tool_calls,
"thought": thought,
}
if structured_info:
result.update(structured_info)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return make_response({"error": str(e)}, 500)
return make_response(result, 200)

View File

@@ -1,265 +0,0 @@
import datetime
import json
import logging
from typing import Any, Dict, Generator, List, Optional
from flask import Response
from flask_restx import Namespace
from application.api.answer.services.conversation_service import ConversationService
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.utils import check_required_fields, get_gpt_model
logger = logging.getLogger(__name__)
answer_ns = Namespace("answer", description="Answer related operations", path="/")
class BaseAnswerResource:
"""Shared base class for answer endpoints"""
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.user_logs_collection = db["user_logs"]
self.gpt_model = get_gpt_model()
self.conversation_service = ConversationService()
def validate_request(
self, data: Dict[str, Any], require_conversation_id: bool = False
) -> Optional[Response]:
"""Common request validation"""
required_fields = ["question"]
if require_conversation_id:
required_fields.append("conversation_id")
if missing_fields := check_required_fields(data, required_fields):
return missing_fields
return None
def complete_stream(
self,
question: str,
agent: Any,
retriever: Any,
conversation_id: Optional[str],
user_api_key: Optional[str],
decoded_token: Dict[str, Any],
isNoneDoc: bool = False,
index: Optional[int] = None,
should_save_conversation: bool = True,
attachment_ids: Optional[List[str]] = None,
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
Args:
question: The user's question
agent: The agent instance
retriever: The retriever instance
conversation_id: Existing conversation ID
user_api_key: User's API key if any
decoded_token: Decoded JWT token
isNoneDoc: Flag for document-less responses
index: Index of message to update
should_save_conversation: Whether to persist the conversation
attachment_ids: List of attachment IDs
agent_id: ID of agent used
is_shared_usage: Flag for shared agent usage
shared_token: Token for shared agent
Yields:
Server-sent event strings
"""
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
is_structured = False
schema_info = None
structured_chunks = []
for line in agent.gen(query=question, retriever=retriever):
if "answer" in line:
response_full += str(line["answer"])
if line.get("structured"):
is_structured = True
schema_info = line.get("schema")
structured_chunks.append(line["answer"])
else:
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
truncated_sources = []
source_log_docs = line["sources"]
for source in line["sources"]:
truncated_source = source.copy()
if "text" in truncated_source:
truncated_source["text"] = (
truncated_source["text"][:100].strip() + "..."
)
truncated_sources.append(truncated_source)
if truncated_sources:
data = json.dumps(
{"type": "source", "source": truncated_sources}
)
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
elif "thought" in line:
thought += line["thought"]
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
if should_save_conversation:
conversation_id = self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
self.gpt_model,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
)
else:
conversation_id = None
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
retriever_params = retriever.get_params()
log_data = {
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
if is_structured:
log_data["structured_output"] = True
if schema_info:
log_data["schema"] = schema_info
# clean up text fields to be no longer than 10000 characters
for key, value in log_data.items():
if isinstance(value, str) and len(value) > 10000:
log_data[key] = value[:10000]
self.user_logs_collection.insert_one(log_data)
# End of stream
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except Exception as e:
logger.error(f"Error in stream: {str(e)}", exc_info=True)
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
}
)
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream):
"""Process the stream response for non-streaming endpoint"""
conversation_id = ""
response_full = ""
source_log_docs = []
tool_calls = []
thought = ""
stream_ended = False
is_structured = False
schema_info = None
for line in stream:
try:
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
if event["type"] == "id":
conversation_id = event["id"]
elif event["type"] == "answer":
response_full += event["answer"]
elif event["type"] == "structured_answer":
response_full = event["answer"]
is_structured = True
schema_info = event.get("schema")
elif event["type"] == "source":
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return None, None, None, None, event["error"]
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Error parsing stream event: {e}, line: {line}")
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return None, None, None, None, "Stream ended unexpectedly"
result = (
conversation_id,
response_full,
source_log_docs,
tool_calls,
thought,
None,
)
if is_structured:
result = result + ({"structured": True, "schema": schema_info},)
return result
def error_stream_generate(self, err_response):
data = json.dumps({"type": "error", "error": err_response})
yield f"data: {data}\n\n"

View File

@@ -1,117 +0,0 @@
import logging
import traceback
from flask import request, Response
from flask_restx import fields, Resource
from application.api import api
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
@answer_ns.route("/stream")
class StreamResource(Resource, BaseAnswerResource):
def __init__(self, *args, **kwargs):
Resource.__init__(self, *args, **kwargs)
BaseAnswerResource.__init__(self)
stream_model = answer_ns.model(
"StreamModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String,
required=False,
description="Conversation history (only for new conversations)",
),
"conversation_id": fields.String(
required=False,
description="Existing conversation ID (loads history)",
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index": fields.Integer(
required=False, description="Index of the query to update"
),
"save_conversation": fields.Boolean(
required=False,
default=True,
description="Whether to save the conversation",
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
},
)
@api.expect(stream_model)
@api.doc(description="Stream a response based on the question and retriever")
def post(self):
data = request.get_json()
if error := self.validate_request(data, "index" in data):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
processor.initialize()
agent = processor.create_agent()
retriever = processor.create_retriever()
return Response(
self.complete_stream(
question=data["question"],
agent=agent,
retriever=retriever,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=data.get("index"),
should_save_conversation=data.get("save_conversation", True),
attachment_ids=data.get("attachments", []),
agent_id=data.get("agent_id"),
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
),
mimetype="text/event-stream",
)
except ValueError as e:
message = "Malformed request body"
logger.error(
f"/stream - error: {message} - specific error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return Response(
self.error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
logger.error(
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return Response(
self.error_stream_generate("Unknown error occurred"),
status=400,
mimetype="text/event-stream",
)

View File

@@ -1,180 +0,0 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from bson import ObjectId
logger = logging.getLogger(__name__)
class ConversationService:
def __init__(self):
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
self.conversations_collection = db["conversations"]
self.agents_collection = db["agents"]
def get_conversation(
self, conversation_id: str, user_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a conversation with proper access control"""
if not conversation_id or not user_id:
return None
try:
conversation = self.conversations_collection.find_one(
{
"_id": ObjectId(conversation_id),
"$or": [{"user": user_id}, {"shared_with": user_id}],
}
)
if not conversation:
logger.warning(
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
)
return None
conversation["_id"] = str(conversation["_id"])
return conversation
except Exception as e:
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
return None
def save_conversation(
self,
conversation_id: Optional[str],
question: str,
response: str,
thought: str,
sources: List[Dict[str, Any]],
tool_calls: List[Dict[str, Any]],
llm: Any,
gpt_model: str,
decoded_token: Dict[str, Any],
index: Optional[int] = None,
api_key: Optional[str] = None,
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
attachment_ids: Optional[List[str]] = None,
) -> str:
"""Save or update a conversation in the database"""
user_id = decoded_token.get("sub")
if not user_id:
raise ValueError("User ID not found in token")
current_time = datetime.now(timezone.utc)
# clean up in sources array such that we save max 1k characters for text part
for source in sources:
if "text" in source and isinstance(source["text"], str):
source["text"] = source["text"][:1000]
if conversation_id is not None and index is not None:
# Update existing conversation with new query
result = self.conversations_collection.update_one(
{
"_id": ObjectId(conversation_id),
"user": user_id,
f"queries.{index}": {"$exists": True},
},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.thought": thought,
f"queries.{index}.sources": sources,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
}
},
)
if result.matched_count == 0:
raise ValueError("Conversation not found or unauthorized")
self.conversations_collection.update_one(
{
"_id": ObjectId(conversation_id),
"user": user_id,
f"queries.{index}": {"$exists": True},
},
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
return conversation_id
elif conversation_id:
# Append new message to existing conversation
result = self.conversations_collection.update_one(
{"_id": ObjectId(conversation_id), "user": user_id},
{
"$push": {
"queries": {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
}
}
},
)
if result.matched_count == 0:
raise ValueError("Conversation not found or unauthorized")
return conversation_id
else:
# Create new conversation
messages_summary = [
{
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the user query",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
completion = llm.gen(
model=gpt_model, messages=messages_summary, max_tokens=30
)
conversation_data = {
"user": user_id,
"date": current_time,
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
}
],
}
if api_key:
if agent_id:
conversation_data["agent_id"] = agent_id
if is_shared_usage:
conversation_data["is_shared_usage"] = is_shared_usage
conversation_data["shared_token"] = shared_token
agent = self.agents_collection.find_one({"key": api_key})
if agent:
conversation_data["api_key"] = agent["key"]
result = self.conversations_collection.insert_one(conversation_data)
return str(result.inserted_id)

View File

@@ -1,353 +0,0 @@
import datetime
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional
from bson.dbref import DBRef
from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.conversation_service import ConversationService
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import get_gpt_model, limit_chat_history
logger = logging.getLogger(__name__)
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
"""
Get a prompt by preset name or MongoDB ID
"""
current_dir = Path(__file__).resolve().parents[3]
prompts_dir = current_dir / "prompts"
preset_mapping = {
"default": "chat_combine_default.txt",
"creative": "chat_combine_creative.txt",
"strict": "chat_combine_strict.txt",
"reduce": "chat_reduce_prompt.txt",
}
if prompt_id in preset_mapping:
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
try:
with open(file_path, "r") as f:
return f.read()
except FileNotFoundError:
raise FileNotFoundError(f"Prompt file not found: {file_path}")
try:
if prompts_collection is None:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
prompts_collection = db["prompts"]
prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
if not prompt_doc:
raise ValueError(f"Prompt with ID {prompt_id} not found")
return prompt_doc["content"]
except Exception as e:
raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
class StreamProcessor:
def __init__(
self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
):
mongo = MongoDB.get_client()
self.db = mongo[settings.MONGO_DB_NAME]
self.agents_collection = self.db["agents"]
self.attachments_collection = self.db["attachments"]
self.prompts_collection = self.db["prompts"]
self.data = request_data
self.decoded_token = decoded_token
self.initial_user_id = (
self.decoded_token.get("sub") if self.decoded_token is not None else None
)
self.conversation_id = self.data.get("conversation_id")
self.source = {}
self.all_sources = []
self.attachments = []
self.history = []
self.agent_config = {}
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
self.gpt_model = get_gpt_model()
self.conversation_service = ConversationService()
def initialize(self):
"""Initialize all required components for processing"""
self._configure_agent()
self._configure_source()
self._configure_retriever()
self._configure_agent()
self._load_conversation_history()
self._process_attachments()
def _load_conversation_history(self):
"""Load conversation history either from DB or request"""
if self.conversation_id and self.initial_user_id:
conversation = self.conversation_service.get_conversation(
self.conversation_id, self.initial_user_id
)
if not conversation:
raise ValueError("Conversation not found or unauthorized")
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
else:
self.history = limit_chat_history(
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
)
def _process_attachments(self):
"""Process any attachments in the request"""
attachment_ids = self.data.get("attachments", [])
self.attachments = self._get_attachments_content(
attachment_ids, self.initial_user_id
)
def _get_attachments_content(self, attachment_ids, user_id):
"""
Retrieve content from attachment documents based on their IDs.
"""
if not attachment_ids:
return []
attachments = []
for attachment_id in attachment_ids:
try:
attachment_doc = self.attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user_id}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
return attachments
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
"""Get API key for agent with access control"""
if not agent_id:
return None, False, None
try:
agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found")
is_owner = agent.get("user") == user_id
is_shared_with_user = agent.get(
"shared_publicly", False
) or user_id in agent.get("shared_with", [])
if not (is_owner or is_shared_with_user):
raise Exception("Unauthorized access to the agent")
if is_owner:
self.agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{
"$set": {
"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
}
},
)
return str(agent["key"]), not is_owner, agent.get("shared_token")
except Exception as e:
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
raise
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
data = self.agents_collection.find_one({"key": api_key})
if not data:
raise Exception("Invalid API Key, please generate a new key", 401)
source = data.get("source")
if isinstance(source, DBRef):
source_doc = self.db.dereference(source)
if source_doc:
data["source"] = str(source_doc["_id"])
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
else:
data["source"] = None
elif source == "default":
data["source"] = "default"
else:
data["source"] = None
# Handle multiple sources
sources = data.get("sources", [])
if sources and isinstance(sources, list):
sources_list = []
for i, source_ref in enumerate(sources):
if source_ref == "default":
processed_source = {
"id": "default",
"retriever": "classic",
"chunks": data.get("chunks", "2"),
}
sources_list.append(processed_source)
elif isinstance(source_ref, DBRef):
source_doc = self.db.dereference(source_ref)
if source_doc:
processed_source = {
"id": str(source_doc["_id"]),
"retriever": source_doc.get("retriever", "classic"),
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
}
sources_list.append(processed_source)
data["sources"] = sources_list
else:
data["sources"] = []
return data
def _configure_source(self):
"""Configure the source based on agent data"""
api_key = self.data.get("api_key") or self.agent_key
if api_key:
agent_data = self._get_data_from_api_key(api_key)
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
source_ids = [
source["id"] for source in agent_data["sources"] if source.get("id")
]
if source_ids:
self.source = {"active_docs": source_ids}
else:
self.source = {}
self.all_sources = agent_data["sources"]
elif agent_data.get("source"):
self.source = {"active_docs": agent_data["source"]}
self.all_sources = [
{
"id": agent_data["source"],
"retriever": agent_data.get("retriever", "classic"),
}
]
else:
self.source = {}
self.all_sources = []
return
if "active_docs" in self.data:
self.source = {"active_docs": self.data["active_docs"]}
return
self.source = {}
self.all_sources = []
def _configure_agent(self):
"""Configure the agent based on request data"""
agent_id = self.data.get("agent_id")
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": api_key,
"json_schema": data_key.get("json_schema"),
}
)
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": self.agent_key,
"json_schema": data_key.get("json_schema"),
}
)
self.decoded_token = (
self.decoded_token
if self.is_shared_usage
else {"sub": data_key.get("user")}
)
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
else:
self.agent_config.update(
{
"prompt_id": self.data.get("prompt_id", "default"),
"agent_type": settings.AGENT_NAME,
"user_api_key": None,
"json_schema": None,
}
)
def _configure_retriever(self):
"""Configure the retriever based on request data"""
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
}
api_key = self.data.get("api_key") or self.agent_key
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
def create_agent(self):
"""Create and return the configured agent"""
return AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=settings.LLM_PROVIDER,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.agent_config["user_api_key"],
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
chat_history=self.history,
decoded_token=self.decoded_token,
attachments=self.attachments,
json_schema=self.agent_config.get("json_schema"),
)
def create_retriever(self):
"""Create and return the configured retriever"""
return RetrieverCreator.create_retriever(
self.retriever_config["retriever_name"],
source=self.source,
chat_history=self.history,
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
chunks=self.retriever_config["chunks"],
token_limit=self.retriever_config["token_limit"],
gpt_model=self.gpt_model,
user_api_key=self.agent_config["user_api_key"],
decoded_token=self.decoded_token,
)

View File

@@ -1,695 +0,0 @@
import base64
import datetime
import json
import uuid
from bson.objectid import ObjectId
from flask import (
Blueprint,
current_app,
jsonify,
make_response,
request
)
from flask_restx import fields, Namespace, Resource
from application.api.user.tasks import (
ingest_connector_task,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.api import api
from application.utils import (
check_required_fields
)
from application.parser.connectors.connector_creator import ConnectorCreator
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
sessions_collection = db["connector_sessions"]
connector = Blueprint("connector", __name__)
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
api.add_namespace(connectors_ns)
@connectors_ns.route("/api/connectors/upload")
class UploadConnector(Resource):
@api.expect(
api.model(
"ConnectorUploadModel",
{
"user": fields.String(required=True, description="User ID"),
"source": fields.String(
required=True, description="Source type (google_drive, github, etc.)"
),
"name": fields.String(required=True, description="Job name"),
"data": fields.String(required=True, description="Configuration data"),
"repo_url": fields.String(description="GitHub repository URL"),
},
)
)
@api.doc(
description="Uploads connector source for vectorization",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.form
required_fields = ["user", "source", "name", "data"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
config = json.loads(data["data"])
source_data = None
sync_frequency = config.get("sync_frequency", "never")
if data["source"] == "github":
source_data = config.get("repo_url")
elif data["source"] in ["crawler", "url"]:
source_data = config.get("url")
elif data["source"] == "reddit":
source_data = config
elif data["source"] in ConnectorCreator.get_supported_connectors():
session_token = config.get("session_token")
if not session_token:
return make_response(jsonify({
"success": False,
"error": f"Missing session_token in {data['source']} configuration"
}), 400)
file_ids = config.get("file_ids", [])
if isinstance(file_ids, str):
file_ids = [id.strip() for id in file_ids.split(',') if id.strip()]
elif not isinstance(file_ids, list):
file_ids = []
folder_ids = config.get("folder_ids", [])
if isinstance(folder_ids, str):
folder_ids = [id.strip() for id in folder_ids.split(',') if id.strip()]
elif not isinstance(folder_ids, list):
folder_ids = []
config["file_ids"] = file_ids
config["folder_ids"] = folder_ids
task = ingest_connector_task.delay(
job_name=data["name"],
user=decoded_token.get("sub"),
source_type=data["source"],
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=config.get("recursive", False),
retriever=config.get("retriever", "classic"),
sync_frequency=sync_frequency
)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
task = ingest_connector_task.delay(
source_data=source_data,
job_name=data["name"],
user=decoded_token.get("sub"),
loader=data["source"],
sync_frequency=sync_frequency
)
except Exception as err:
current_app.logger.error(
f"Error uploading connector source: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
@connectors_ns.route("/api/connectors/task_status")
class ConnectorTaskStatus(Resource):
task_status_model = api.model(
"ConnectorTaskStatusModel",
{"task_id": fields.String(required=True, description="Task ID")},
)
@api.expect(task_status_model)
@api.doc(description="Get connector task status")
def get(self):
task_id = request.args.get("task_id")
if not task_id:
return make_response(
jsonify({"success": False, "message": "Task ID is required"}), 400
)
try:
from application.celery_init import celery
task = celery.AsyncResult(task_id)
task_meta = task.info
print(f"Task status: {task.status}")
if not isinstance(
task_meta, (dict, list, str, int, float, bool, type(None))
):
task_meta = str(task_meta)
except Exception as err:
current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
@connectors_ns.route("/api/connectors/sources")
class ConnectorSources(Resource):
@api.doc(description="Get connector sources")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
sources = sources_collection.find({"user": user, "type": "connector:file"}).sort("date", -1)
connector_sources = []
for source in sources:
connector_sources.append({
"id": str(source["_id"]),
"name": source.get("name"),
"date": source.get("date"),
"type": source.get("type"),
"source": source.get("source"),
"tokens": source.get("tokens", ""),
"retriever": source.get("retriever", "classic"),
"syncFrequency": source.get("sync_frequency", ""),
})
except Exception as err:
current_app.logger.error(f"Error retrieving connector sources: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(connector_sources), 200)
@connectors_ns.route("/api/connectors/delete")
class DeleteConnectorSource(Resource):
@api.doc(
description="Delete a connector source",
params={"source_id": "The source ID to delete"},
)
def delete(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
source_id = request.args.get("source_id")
if not source_id:
return make_response(
jsonify({"success": False, "message": "source_id is required"}), 400
)
try:
result = sources_collection.delete_one(
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
)
if result.deleted_count == 0:
return make_response(
jsonify({"success": False, "message": "Source not found"}), 404
)
except Exception as err:
current_app.logger.error(
f"Error deleting connector source: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@connectors_ns.route("/api/connectors/auth")
class ConnectorAuth(Resource):
@api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"})
def get(self):
try:
provider = request.args.get('provider') or request.args.get('source')
if not provider:
return make_response(jsonify({"success": False, "error": "Missing provider"}), 400)
if not ConnectorCreator.is_supported(provider):
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user_id = decoded_token.get('sub')
now = datetime.datetime.now(datetime.timezone.utc)
result = sessions_collection.insert_one({
"provider": provider,
"user": user_id,
"status": "pending",
"created_at": now
})
state_dict = {
"provider": provider,
"object_id": str(result.inserted_id)
}
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
auth = ConnectorCreator.create_auth(provider)
authorization_url = auth.get_authorization_url(state=state)
return make_response(jsonify({
"success": True,
"authorization_url": authorization_url,
"state": state
}), 200)
except Exception as e:
current_app.logger.error(f"Error generating connector auth URL: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/callback")
class ConnectorsCallback(Resource):
@api.doc(description="Handle OAuth callback for external connectors")
def get(self):
"""Handle OAuth callback for external connectors"""
try:
from application.parser.connectors.connector_creator import ConnectorCreator
from flask import request, redirect
authorization_code = request.args.get('code')
state = request.args.get('state')
error = request.args.get('error')
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
provider = state_dict["provider"]
state_object_id = state_dict["object_id"]
if error:
if error == "access_denied":
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
else:
current_app.logger.warning(f"OAuth error in callback: {error}")
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
if not authorization_code:
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
try:
auth = ConnectorCreator.create_auth(provider)
token_info = auth.exchange_code_for_tokens(authorization_code)
session_token = str(uuid.uuid4())
try:
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
except Exception as e:
current_app.logger.warning(f"Could not get user info: {e}")
user_email = 'Connected User'
sanitized_token_info = {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry")
}
sessions_collection.find_one_and_update(
{"_id": ObjectId(state_object_id), "provider": provider},
{
"$set": {
"session_token": session_token,
"token_info": sanitized_token_info,
"user_email": user_email,
"status": "authorized"
}
}
)
# Redirect to success page with session token and user email
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
except Exception as e:
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
except Exception as e:
current_app.logger.error(f"Error handling connector callback: {e}")
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
@connectors_ns.route("/api/connectors/refresh")
class ConnectorRefresh(Resource):
@api.expect(api.model("ConnectorRefreshModel", {"provider": fields.String(required=True), "refresh_token": fields.String(required=True)}))
@api.doc(description="Refresh connector access token")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
refresh_token = data.get('refresh_token')
if not provider or not refresh_token:
return make_response(jsonify({"success": False, "error": "provider and refresh_token are required"}), 400)
auth = ConnectorCreator.create_auth(provider)
token_info = auth.refresh_access_token(refresh_token)
return make_response(jsonify({"success": True, "token_info": token_info}), 200)
except Exception as e:
current_app.logger.error(f"Error refreshing token for connector: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/files")
class ConnectorFiles(Resource):
@api.expect(api.model("ConnectorFilesModel", {
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"page_token": fields.String(required=False),
"search_query": fields.String(required=False)
}))
@api.doc(description="List files from a connector provider (supports pagination and search)")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
folder_id = data.get('folder_id')
limit = data.get('limit', 10)
page_token = data.get('page_token')
search_query = data.get('search_query')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
session = sessions_collection.find_one({"session_token": session_token, "user": user})
if not session:
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
loader = ConnectorCreator.create_connector(provider, session_token)
input_config = {
'limit': limit,
'list_only': True,
'session_token': session_token,
'folder_id': folder_id,
'page_token': page_token
}
if search_query:
input_config['search_query'] = search_query
documents = loader.load_data(input_config)
files = []
for doc in documents[:limit]:
metadata = doc.extra_info
modified_time = metadata.get('modified_time')
if modified_time:
date_part = modified_time.split('T')[0]
time_part = modified_time.split('T')[1].split('.')[0].split('Z')[0]
formatted_time = f"{date_part} {time_part}"
else:
formatted_time = None
files.append({
'id': doc.doc_id,
'name': metadata.get('file_name', 'Unknown File'),
'type': metadata.get('mime_type', 'unknown'),
'size': metadata.get('size', None),
'modifiedTime': formatted_time,
'isFolder': metadata.get('is_folder', False)
})
next_token = getattr(loader, 'next_page_token', None)
has_more = bool(next_token)
return make_response(jsonify({
"success": True,
"files": files,
"total": len(files),
"next_page_token": next_token,
"has_more": has_more
}), 200)
except Exception as e:
current_app.logger.error(f"Error loading connector files: {e}")
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
@connectors_ns.route("/api/connectors/validate-session")
class ConnectorValidateSession(Resource):
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
@api.doc(description="Validate connector session token and return user info and access token")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
session = sessions_collection.find_one({"session_token": session_token, "user": user})
if not session or "token_info" not in session:
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
token_info = session["token_info"]
auth = ConnectorCreator.create_auth(provider)
is_expired = auth.is_token_expired(token_info)
if is_expired and token_info.get('refresh_token'):
try:
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
sanitized_token_info = {
"access_token": refreshed_token_info.get("access_token"),
"refresh_token": refreshed_token_info.get("refresh_token"),
"token_uri": refreshed_token_info.get("token_uri"),
"expiry": refreshed_token_info.get("expiry")
}
sessions_collection.update_one(
{"session_token": session_token},
{"$set": {"token_info": sanitized_token_info}}
)
token_info = sanitized_token_info
is_expired = False
except Exception as refresh_error:
current_app.logger.error(f"Failed to refresh token: {refresh_error}")
if is_expired:
return make_response(jsonify({
"success": False,
"expired": True,
"error": "Session token has expired. Please reconnect."
}), 401)
return make_response(jsonify({
"success": True,
"expired": False,
"user_email": session.get('user_email', 'Connected User'),
"access_token": token_info.get('access_token')
}), 200)
except Exception as e:
current_app.logger.error(f"Error validating connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/disconnect")
class ConnectorDisconnect(Resource):
@api.expect(api.model("ConnectorDisconnectModel", {"provider": fields.String(required=True), "session_token": fields.String(required=False)}))
@api.doc(description="Disconnect a connector session")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
if not provider:
return make_response(jsonify({"success": False, "error": "provider is required"}), 400)
if session_token:
sessions_collection.delete_one({"session_token": session_token})
return make_response(jsonify({"success": True}), 200)
except Exception as e:
current_app.logger.error(f"Error disconnecting connector session: {e}")
return make_response(jsonify({"success": False, "error": str(e)}), 500)
@connectors_ns.route("/api/connectors/sync")
class ConnectorSync(Resource):
@api.expect(
api.model(
"ConnectorSyncModel",
{
"source_id": fields.String(required=True, description="Source ID to sync"),
"session_token": fields.String(required=True, description="Authentication token")
},
)
)
@api.doc(description="Sync connector source to check for modifications")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
data = request.get_json()
source_id = data.get('source_id')
session_token = data.get('session_token')
if not all([source_id, session_token]):
return make_response(
jsonify({
"success": False,
"error": "source_id and session_token are required"
}),
400
)
source = sources_collection.find_one({"_id": ObjectId(source_id)})
if not source:
return make_response(
jsonify({
"success": False,
"error": "Source not found"
}),
404
)
if source.get('user') != decoded_token.get('sub'):
return make_response(
jsonify({
"success": False,
"error": "Unauthorized access to source"
}),
403
)
remote_data = {}
try:
if source.get('remote_data'):
remote_data = json.loads(source.get('remote_data'))
except json.JSONDecodeError:
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
remote_data = {}
source_type = remote_data.get('provider')
if not source_type:
return make_response(
jsonify({
"success": False,
"error": "Source provider not found in remote_data"
}),
400
)
# Extract configuration from remote_data
file_ids = remote_data.get('file_ids', [])
folder_ids = remote_data.get('folder_ids', [])
recursive = remote_data.get('recursive', True)
# Start the sync task
task = ingest_connector_task.delay(
job_name=source.get('name'),
user=decoded_token.get('sub'),
source_type=source_type,
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=recursive,
retriever=source.get('retriever', 'classic'),
operation_mode="sync",
doc_id=source_id,
sync_frequency=source.get('sync_frequency', 'never')
)
return make_response(
jsonify({
"success": True,
"task_id": task.id
}),
200
)
except Exception as err:
current_app.logger.error(
f"Error syncing connector source: {err}",
exc_info=True
)
return make_response(
jsonify({
"success": False,
"error": str(err)
}),
400
)
@connectors_ns.route("/api/connectors/callback-status")
class ConnectorCallbackStatus(Resource):
@api.doc(description="Return HTML page with connector authentication status")
def get(self):
"""Return HTML page with connector authentication status"""
try:
status = request.args.get('status', 'error')
message = request.args.get('message', '')
provider = request.args.get('provider', 'connector')
session_token = request.args.get('session_token', '')
user_email = request.args.get('user_email', '')
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>{provider.replace('_', ' ').title()} Authentication</title>
<style>
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
.container {{ max-width: 600px; margin: 0 auto; }}
.success {{ color: #4CAF50; }}
.error {{ color: #F44336; }}
.cancelled {{ color: #FF9800; }}
</style>
<script>
window.onload = function() {{
const status = "{status}";
const sessionToken = "{session_token}";
const userEmail = "{user_email}";
if (status === "success" && window.opener) {{
window.opener.postMessage({{
type: '{provider}_auth_success',
session_token: sessionToken,
user_email: userEmail
}}, '*');
setTimeout(() => window.close(), 3000);
}} else if (status === "cancelled" || status === "error") {{
setTimeout(() => window.close(), 3000);
}}
}};
</script>
</head>
<body>
<div class="container">
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
<div class="{status}">
<p>{message}</p>
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
</div>
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
</div>
</body>
</html>
"""
return make_response(html_content, 200, {'Content-Type': 'text/html'})
except Exception as e:
current_app.logger.error(f"Error rendering callback status page: {e}")
return make_response("Authentication error occurred", 500, {'Content-Type': 'text/html'})

View File

@@ -1,6 +1,5 @@
import os
import datetime
import json
from flask import Blueprint, request, send_from_directory
from werkzeug.utils import secure_filename
from bson.objectid import ObjectId
@@ -38,28 +37,16 @@ def upload_index_files():
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
if "user" not in request.form:
return {"status": "no user"}
user = request.form["user"]
user = secure_filename(request.form["user"])
if "name" not in request.form:
return {"status": "no name"}
job_name = request.form["name"]
tokens = request.form["tokens"]
retriever = request.form["retriever"]
id = request.form["id"]
type = request.form["type"]
job_name = secure_filename(request.form["name"])
tokens = secure_filename(request.form["tokens"])
retriever = secure_filename(request.form["retriever"])
id = secure_filename(request.form["id"])
type = secure_filename(request.form["type"])
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
file_path = request.form.get("file_path")
directory_structure = request.form.get("directory_structure")
if directory_structure:
try:
directory_structure = json.loads(directory_structure)
except Exception:
logger.error("Error parsing directory_structure")
directory_structure = {}
else:
directory_structure = {}
sync_frequency = secure_filename(request.form["sync_frequency"]) if "sync_frequency" in request.form else None
storage = StorageCreator.get_storage()
index_base_path = f"indexes/{id}"
@@ -77,13 +64,10 @@ def upload_index_files():
file_pkl = request.files["file_pkl"]
if file_pkl.filename == "":
return {"status": "no file name"}
# Save index files to storage
faiss_storage_path = f"{index_base_path}/index.faiss"
pkl_storage_path = f"{index_base_path}/index.pkl"
storage.save_file(file_faiss, faiss_storage_path)
storage.save_file(file_pkl, pkl_storage_path)
storage.save_file(file_faiss, f"{index_base_path}/index.faiss")
storage.save_file(file_pkl, f"{index_base_path}/index.pkl")
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
if existing_entry:
@@ -101,8 +85,6 @@ def upload_index_files():
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
},
)
@@ -120,8 +102,6 @@ def upload_index_files():
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
)
return {"status": "ok"}

File diff suppressed because it is too large Load Diff

View File

@@ -5,16 +5,14 @@ from application.worker import (
agent_webhook_worker,
attachment_worker,
ingest_worker,
mcp_oauth,
mcp_oauth_status,
remote_worker,
sync_worker,
)
@celery.task(bind=True)
def ingest(self, directory, formats, job_name, user, file_path, filename):
resp = ingest_worker(self, directory, formats, job_name, file_path, filename, user)
def ingest(self, directory, formats, name_job, filename, user):
resp = ingest_worker(self, directory, formats, name_job, filename, user)
return resp
@@ -24,14 +22,6 @@ def ingest_remote(self, source_data, job_name, user, loader):
return resp
@celery.task(bind=True)
def reingest_source_task(self, source_id, user):
from application.worker import reingest_source_worker
resp = reingest_source_worker(self, source_id, user)
return resp
@celery.task(bind=True)
def schedule_syncs(self, frequency):
resp = sync_worker(self, frequency)
@@ -50,40 +40,6 @@ def process_agent_webhook(self, agent_id, payload):
return resp
@celery.task(bind=True)
def ingest_connector_task(
self,
job_name,
user,
source_type,
session_token=None,
file_ids=None,
folder_ids=None,
recursive=True,
retriever="classic",
operation_mode="upload",
doc_id=None,
sync_frequency="never",
):
from application.worker import ingest_connector
resp = ingest_connector(
self,
job_name,
user,
source_type,
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=recursive,
retriever=retriever,
operation_mode=operation_mode,
doc_id=doc_id,
sync_frequency=sync_frequency,
)
return resp
@celery.on_after_configure.connect
def setup_periodic_tasks(sender, **kwargs):
sender.add_periodic_task(
@@ -98,15 +54,3 @@ def setup_periodic_tasks(sender, **kwargs):
timedelta(days=30),
schedule_syncs.s("monthly"),
)
@celery.task(bind=True)
def mcp_oauth_task(self, config, user):
resp = mcp_oauth(self, config, user)
return resp
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
resp = mcp_oauth_status(self, task_id)
return resp

View File

@@ -12,26 +12,25 @@ from application.core.logging_config import setup_logging
setup_logging()
from application.api import api # noqa: E402
from application.api.answer import answer # noqa: E402
from application.api.answer.routes import answer # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.extensions import api # noqa: E402
if platform.system() == "Windows":
import pathlib
pathlib.PosixPath = pathlib.WindowsPath
dotenv.load_dotenv()
app = Flask(__name__)
app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(internal)
app.register_blueprint(connector)
app.config.update(
UPLOAD_FOLDER="inputs",
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
@@ -53,6 +52,7 @@ if settings.AUTH_TYPE in ("simple_jwt", "session_jwt") and not settings.JWT_SECR
settings.JWT_SECRET_KEY = new_key
except Exception as e:
raise RuntimeError(f"Failed to setup JWT_SECRET_KEY: {e}")
SIMPLE_JWT_TOKEN = None
if settings.AUTH_TYPE == "simple_jwt":
payload = {"sub": "local"}
@@ -92,6 +92,7 @@ def generate_token():
def authenticate_request():
if request.method == "OPTIONS":
return "", 200
decoded_token = handle_auth(request)
if not decoded_token:
request.decoded_token = None

View File

@@ -10,7 +10,7 @@ current_dir = os.path.dirname(
class Settings(BaseSettings):
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
AUTH_TYPE: Optional[str] = None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
@@ -26,26 +26,19 @@ class Settings(BaseSettings):
"gpt-4o-mini": 128000,
"gpt-3.5-turbo": 4096,
"claude-2": 1e5,
"gemini-2.5-flash": 1e6,
"gemini-2.0-flash-exp": 1e6,
}
UPLOAD_FOLDER: str = "inputs"
PARSE_PDF_AS_IMAGE: bool = False
PARSE_IMAGE_REMOTE: bool = False
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb"
)
RETRIEVERS_ENABLED: list = ["classic_rag"]
RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search
AGENT_NAME: str = "classic"
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
FALLBACK_LLM_NAME: Optional[str] = None # model name for fallback llm
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
# Google Drive integration
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
GOOGLE_CLIENT_SECRET: Optional[str] = None# Replace with your actual Google OAuth client secret
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = "http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
# LLM Cache
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
@@ -96,8 +89,6 @@ class Settings(BaseSettings):
QDRANT_PATH: Optional[str] = None
QDRANT_DISTANCE_FUNC: str = "Cosine"
# PGVector vectorstore config
PGVECTOR_CONNECTION_STRING: Optional[str] = None
# Milvus vectorstore config
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default
@@ -108,16 +99,13 @@ class Settings(BaseSettings):
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
BRAVE_SEARCH_API_KEY: Optional[str] = None
FLASK_DEBUG_MODE: bool = False
STORAGE_TYPE: str = "local" # local or s3
URL_STRATEGY: str = "backend" # backend or s3
JWT_SECRET_KEY: str = ""
# Encryption settings
ENCRYPTION_SECRET_KEY: str = "default-docsgpt-encryption-key"
path = Path(__file__).parent.parent.absolute()
settings = Settings(_env_file=path.joinpath(".env"), _env_file_encoding="utf-8")

View File

@@ -0,0 +1,7 @@
from flask_restx import Api
api = Api(
version="1.0",
title="DocsGPT API",
description="API for DocsGPT",
)

View File

@@ -120,20 +120,6 @@ class BaseLLM(ABC):
def _supports_tools(self):
raise NotImplementedError("Subclass must implement _supports_tools method")
def supports_structured_output(self):
"""Check if the LLM supports structured output/JSON schema enforcement"""
return hasattr(self, "_supports_structured_output") and callable(
getattr(self, "_supports_structured_output")
)
def _supports_structured_output(self):
return False
def prepare_structured_output_format(self, json_schema):
"""Prepare structured output format specific to the LLM provider"""
_ = json_schema
return None
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by this LLM for file uploads.
@@ -141,4 +127,4 @@ class BaseLLM(ABC):
Returns:
list: List of supported MIME types
"""
return []
return [] # Default: no attachments supported

View File

@@ -1,13 +1,11 @@
import json
import logging
from google import genai
from google.genai import types
from application.core.settings import settings
import logging
import json
from application.llm.base import BaseLLM
from application.storage.storage_creator import StorageCreator
from application.core.settings import settings
class GoogleLLM(BaseLLM):
@@ -26,12 +24,12 @@ class GoogleLLM(BaseLLM):
list: List of supported MIME types
"""
return [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
'application/pdf',
'image/png',
'image/jpeg',
'image/jpg',
'image/webp',
'image/gif'
]
def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -72,30 +70,26 @@ class GoogleLLM(BaseLLM):
files = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
mime_type = attachment.get('mime_type')
if mime_type in self.get_supported_attachment_types():
try:
file_uri = self._upload_file_to_google(attachment)
logging.info(
f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}"
)
logging.info(f"GoogleLLM: Successfully uploaded file, got URI: {file_uri}")
files.append({"file_uri": file_uri, "mime_type": mime_type})
except Exception as e:
logging.error(
f"GoogleLLM: Error uploading file: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]",
}
)
logging.error(f"GoogleLLM: Error uploading file: {e}", exc_info=True)
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"[File could not be processed: {attachment.get('path', 'unknown')}]"
})
if files:
logging.info(f"GoogleLLM: Adding {len(files)} files to message")
prepared_messages[user_message_index]["content"].append({"files": files})
prepared_messages[user_message_index]["content"].append({
"files": files
})
return prepared_messages
@@ -109,10 +103,10 @@ class GoogleLLM(BaseLLM):
Returns:
str: Google AI file URI for the uploaded file.
"""
if "google_file_uri" in attachment:
return attachment["google_file_uri"]
if 'google_file_uri' in attachment:
return attachment['google_file_uri']
file_path = attachment.get("path")
file_path = attachment.get('path')
if not file_path:
raise ValueError("No file path provided in attachment")
@@ -122,19 +116,17 @@ class GoogleLLM(BaseLLM):
try:
file_uri = self.storage.process_file(
file_path,
lambda local_path, **kwargs: self.client.files.upload(
file=local_path
).uri,
lambda local_path, **kwargs: self.client.files.upload(file=local_path).uri
)
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
attachments_collection = db["attachments"]
if "_id" in attachment:
if '_id' in attachment:
attachments_collection.update_one(
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
{"_id": attachment['_id']},
{"$set": {"google_file_uri": file_uri}}
)
return file_uri
@@ -143,7 +135,6 @@ class GoogleLLM(BaseLLM):
raise
def _clean_messages_google(self, messages):
"""Convert OpenAI format messages to Google AI format."""
cleaned_messages = []
for message in messages:
role = message.get("role")
@@ -151,8 +142,6 @@ class GoogleLLM(BaseLLM):
if role == "assistant":
role = "model"
elif role == "tool":
role = "model"
parts = []
if role and content is not None:
@@ -177,13 +166,13 @@ class GoogleLLM(BaseLLM):
)
)
elif "files" in item:
for file_data in item["files"]:
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"],
for file_data in item["files"]:
parts.append(
types.Part.from_uri(
file_uri=file_data["file_uri"],
mime_type=file_data["mime_type"]
)
)
)
else:
raise ValueError(
f"Unexpected content dictionary format:{item}"
@@ -191,63 +180,11 @@ class GoogleLLM(BaseLLM):
else:
raise ValueError(f"Unexpected content type: {type(content)}")
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
cleaned_messages.append(types.Content(role=role, parts=parts))
return cleaned_messages
def _clean_schema(self, schema_obj):
"""
Recursively remove unsupported fields from schema objects
and validate required properties.
"""
if not isinstance(schema_obj, dict):
return schema_obj
allowed_fields = {
"type",
"description",
"items",
"properties",
"required",
"enum",
"pattern",
"minimum",
"maximum",
"nullable",
"default",
}
cleaned = {}
for key, value in schema_obj.items():
if key not in allowed_fields:
continue
elif key == "type" and isinstance(value, str):
cleaned[key] = value.upper()
elif isinstance(value, dict):
cleaned[key] = self._clean_schema(value)
elif isinstance(value, list):
cleaned[key] = [self._clean_schema(item) for item in value]
else:
cleaned[key] = value
# Validate that required properties actually exist in properties
if "required" in cleaned and "properties" in cleaned:
valid_required = []
properties_keys = set(cleaned["properties"].keys())
for required_prop in cleaned["required"]:
if required_prop in properties_keys:
valid_required.append(required_prop)
if valid_required:
cleaned["required"] = valid_required
else:
cleaned.pop("required", None)
elif "required" in cleaned and "properties" not in cleaned:
cleaned.pop("required", None)
return cleaned
def _clean_tools_format(self, tools_list):
"""Convert OpenAI format tools to Google AI format."""
genai_tools = []
for tool_data in tools_list:
if tool_data["type"] == "function":
@@ -256,16 +193,18 @@ class GoogleLLM(BaseLLM):
properties = parameters.get("properties", {})
if properties:
cleaned_properties = {}
for k, v in properties.items():
cleaned_properties[k] = self._clean_schema(v)
genai_function = dict(
name=function["name"],
description=function["description"],
parameters={
"type": "OBJECT",
"properties": cleaned_properties,
"properties": {
k: {
**v,
"type": v["type"].upper() if v["type"] else None,
}
for k, v in properties.items()
},
"required": (
parameters["required"]
if "required" in parameters
@@ -292,10 +231,8 @@ class GoogleLLM(BaseLLM):
stream=False,
tools=None,
formatting="openai",
response_schema=None,
**kwargs,
):
"""Generate content using Google AI API without streaming."""
client = genai.Client(api_key=self.api_key)
if formatting == "openai":
messages = self._clean_messages_google(messages)
@@ -307,21 +244,16 @@ class GoogleLLM(BaseLLM):
if tools:
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
response = client.models.generate_content(
model=model,
contents=messages,
config=config,
)
if tools:
response = client.models.generate_content(
model=model,
contents=messages,
config=config,
)
return response
else:
response = client.models.generate_content(
model=model, contents=messages, config=config
)
return response.text
def _raw_gen_stream(
@@ -332,10 +264,8 @@ class GoogleLLM(BaseLLM):
stream=True,
tools=None,
formatting="openai",
response_schema=None,
**kwargs,
):
"""Generate content using Google AI API with streaming."""
client = genai.Client(api_key=self.api_key)
if formatting == "openai":
messages = self._clean_messages_google(messages)
@@ -348,24 +278,17 @@ class GoogleLLM(BaseLLM):
cleaned_tools = self._clean_tools_format(tools)
config.tools = cleaned_tools
# Add response schema for structured output if provided
if response_schema:
config.response_schema = response_schema
config.response_mime_type = "application/json"
# Check if we have both tools and file attachments
has_attachments = False
for message in messages:
for part in message.parts:
if hasattr(part, "file_data") and part.file_data is not None:
if hasattr(part, 'file_data') and part.file_data is not None:
has_attachments = True
break
if has_attachments:
break
logging.info(
f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}"
)
logging.info(f"GoogleLLM: Starting stream generation. Model: {model}, Messages: {json.dumps(messages, default=str)}, Has attachments: {has_attachments}")
response = client.models.generate_content_stream(
model=model,
@@ -373,6 +296,7 @@ class GoogleLLM(BaseLLM):
config=config,
)
for chunk in response:
if hasattr(chunk, "candidates") and chunk.candidates:
for candidate in chunk.candidates:
@@ -386,79 +310,4 @@ class GoogleLLM(BaseLLM):
yield chunk.text
def _supports_tools(self):
"""Return whether this LLM supports function calling."""
return True
def _supports_structured_output(self):
"""Return whether this LLM supports structured JSON output."""
return True
def prepare_structured_output_format(self, json_schema):
"""Convert JSON schema to Google AI structured output format."""
if not json_schema:
return None
type_map = {
"object": "OBJECT",
"array": "ARRAY",
"string": "STRING",
"integer": "INTEGER",
"number": "NUMBER",
"boolean": "BOOLEAN",
}
def convert(schema):
if not isinstance(schema, dict):
return schema
result = {}
schema_type = schema.get("type")
if schema_type:
result["type"] = type_map.get(schema_type.lower(), schema_type.upper())
for key in [
"description",
"nullable",
"enum",
"minItems",
"maxItems",
"required",
"propertyOrdering",
]:
if key in schema:
result[key] = schema[key]
if "format" in schema:
format_value = schema["format"]
if schema_type == "string":
if format_value == "date":
result["format"] = "date-time"
elif format_value in ["enum", "date-time"]:
result["format"] = format_value
else:
result["format"] = format_value
if "properties" in schema:
result["properties"] = {
k: convert(v) for k, v in schema["properties"].items()
}
if "propertyOrdering" not in result and result.get("type") == "OBJECT":
result["propertyOrdering"] = list(result["properties"].keys())
if "items" in schema:
result["items"] = convert(schema["items"])
for field in ["anyOf", "oneOf", "allOf"]:
if field in schema:
result[field] = [convert(s) for s in schema[field]]
return result
try:
return convert(json_schema)
except Exception as e:
logging.error(
f"Error preparing structured output format for Google: {e}",
exc_info=True,
)
return None

View File

@@ -205,6 +205,7 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
updated_messages.append(
{
"role": "assistant",
@@ -221,36 +222,17 @@ class LLMHandler(ABC):
)
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
logger.error(f"Error executing tool: {str(e)}", exc_info=True)
error_call = ToolCall(
id=call.id, name=call.name, arguments=call.arguments
updated_messages.append(
{
"role": "tool",
"content": f"Error executing tool: {str(e)}",
"tool_call_id": call.id,
}
)
error_response = f"Error executing tool: {str(e)}"
error_message = self.create_tool_message(error_call, error_response)
updated_messages.append(error_message)
call_parts = call.name.split("_")
if len(call_parts) >= 2:
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
action_name = "_".join(call_parts[:-1])
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
full_action_name = f"{action_name}_{tool_id}"
else:
tool_name = "unknown_tool"
action_name = call.name
full_action_name = call.name
yield {
"type": "tool_call",
"data": {
"tool_name": tool_name,
"call_id": call.id,
"action_name": full_action_name,
"arguments": call.arguments,
"error": error_response,
"status": "error",
},
}
return updated_messages
def handle_non_streaming(
@@ -281,11 +263,13 @@ class LLMHandler(ABC):
except StopIteration as e:
messages = e.value
break
response = agent.llm.gen(
model=agent.gpt_model, messages=messages, tools=agent.tools
)
parsed = self.parse_response(response)
self.llm_calls.append(build_stack_data(agent.llm))
return parsed.content
def handle_streaming(

View File

@@ -17,6 +17,7 @@ class GoogleLLMHandler(LLMHandler):
finish_reason="stop",
raw_response=response,
)
if hasattr(response, "candidates"):
parts = response.candidates[0].content.parts if response.candidates else []
tool_calls = [
@@ -40,6 +41,7 @@ class GoogleLLMHandler(LLMHandler):
finish_reason="tool_calls" if tool_calls else "stop",
raw_response=response,
)
else:
tool_calls = []
if hasattr(response, "function_call"):
@@ -59,16 +61,14 @@ class GoogleLLMHandler(LLMHandler):
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create Google-style tool message."""
from google.genai import types
return {
"role": "model",
"role": "tool",
"content": [
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
}
}
types.Part.from_function_response(
name=tool_call.name, response={"result": result}
).to_json_dict()
],
}

View File

@@ -14,5 +14,5 @@ class LLMHandlerCreator:
def create_handler(cls, llm_type: str, *args, **kwargs) -> LLMHandler:
handler_class = cls.handlers.get(llm_type.lower())
if not handler_class:
handler_class = OpenAILLMHandler
raise ValueError(f"No LLM handler class found for type {llm_type}")
return handler_class(*args, **kwargs)

View File

@@ -1,5 +1,5 @@
import base64
import json
import base64
import logging
from application.core.settings import settings
@@ -13,10 +13,7 @@ class OpenAILLM(BaseLLM):
from openai import OpenAI
super().__init__(*args, **kwargs)
if (
isinstance(settings.OPENAI_BASE_URL, str)
and settings.OPENAI_BASE_URL.strip()
):
if isinstance(settings.OPENAI_BASE_URL, str) and settings.OPENAI_BASE_URL.strip():
self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
else:
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
@@ -76,30 +73,14 @@ class OpenAILLM(BaseLLM):
elif isinstance(item, dict):
content_parts = []
if "text" in item:
content_parts.append(
{"type": "text", "text": item["text"]}
)
elif (
"type" in item
and item["type"] == "text"
and "text" in item
):
content_parts.append({"type": "text", "text": item["text"]})
elif "type" in item and item["type"] == "text" and "text" in item:
content_parts.append(item)
elif (
"type" in item
and item["type"] == "file"
and "file" in item
):
elif "type" in item and item["type"] == "file" and "file" in item:
content_parts.append(item)
elif (
"type" in item
and item["type"] == "image_url"
and "image_url" in item
):
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
content_parts.append(item)
cleaned_messages.append(
{"role": role, "content": content_parts}
)
cleaned_messages.append({"role": role, "content": content_parts})
else:
raise ValueError(
f"Unexpected content dictionary format: {item}"
@@ -117,29 +98,22 @@ class OpenAILLM(BaseLLM):
stream=False,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
request_params = {
"model": model,
"messages": messages,
"stream": stream,
**kwargs,
}
if tools:
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
if tools:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
return response.choices[0]
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
return response.choices[0].message.content
def _raw_gen_stream(
@@ -150,32 +124,24 @@ class OpenAILLM(BaseLLM):
stream=True,
tools=None,
engine=settings.AZURE_DEPLOYMENT_NAME,
response_format=None,
**kwargs,
):
messages = self._clean_messages_openai(messages)
request_params = {
"model": model,
"messages": messages,
"stream": stream,
**kwargs,
}
if tools:
request_params["tools"] = tools
if response_format:
request_params["response_format"] = response_format
response = self.client.chat.completions.create(**request_params)
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
tools=tools,
**kwargs,
)
else:
response = self.client.chat.completions.create(
model=model, messages=messages, stream=stream, **kwargs
)
for line in response:
if (
len(line.choices) > 0
and line.choices[0].delta.content is not None
and len(line.choices[0].delta.content) > 0
):
if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
yield line.choices[0].delta.content
elif len(line.choices) > 0:
yield line.choices[0]
@@ -183,66 +149,6 @@ class OpenAILLM(BaseLLM):
def _supports_tools(self):
return True
def _supports_structured_output(self):
return True
def prepare_structured_output_format(self, json_schema):
if not json_schema:
return None
try:
def add_additional_properties_false(schema_obj):
if isinstance(schema_obj, dict):
schema_copy = schema_obj.copy()
if schema_copy.get("type") == "object":
schema_copy["additionalProperties"] = False
# Ensure 'required' includes all properties for OpenAI strict mode
if "properties" in schema_copy:
schema_copy["required"] = list(
schema_copy["properties"].keys()
)
for key, value in schema_copy.items():
if key == "properties" and isinstance(value, dict):
schema_copy[key] = {
prop_name: add_additional_properties_false(prop_schema)
for prop_name, prop_schema in value.items()
}
elif key == "items" and isinstance(value, dict):
schema_copy[key] = add_additional_properties_false(value)
elif key in ["anyOf", "oneOf", "allOf"] and isinstance(
value, list
):
schema_copy[key] = [
add_additional_properties_false(sub_schema)
for sub_schema in value
]
return schema_copy
return schema_obj
processed_schema = add_additional_properties_false(json_schema)
result = {
"type": "json_schema",
"json_schema": {
"name": processed_schema.get("name", "response"),
"description": processed_schema.get(
"description", "Structured response"
),
"schema": processed_schema,
"strict": True,
},
}
return result
except Exception as e:
logging.error(f"Error preparing structured output format: {e}")
return None
def get_supported_attachment_types(self):
"""
Return a list of MIME types supported by OpenAI for file uploads.
@@ -251,12 +157,12 @@ class OpenAILLM(BaseLLM):
list: List of supported MIME types
"""
return [
"application/pdf",
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
"image/gif",
'application/pdf',
'image/png',
'image/jpeg',
'image/jpg',
'image/webp',
'image/gif'
]
def prepare_messages_with_attachments(self, messages, attachments=None):
@@ -296,46 +202,39 @@ class OpenAILLM(BaseLLM):
prepared_messages[user_message_index]["content"] = []
for attachment in attachments:
mime_type = attachment.get("mime_type")
mime_type = attachment.get('mime_type')
if mime_type and mime_type.startswith("image/"):
if mime_type and mime_type.startswith('image/'):
try:
base64_image = self._get_base64_image(attachment)
prepared_messages[user_message_index]["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}"
},
prepared_messages[user_message_index]["content"].append({
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}"
}
)
})
except Exception as e:
logging.error(
f"Error processing image attachment: {e}", exc_info=True
)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]",
}
)
logging.error(f"Error processing image attachment: {e}", exc_info=True)
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"[Image could not be processed: {attachment.get('path', 'unknown')}]"
})
# Handle PDFs using the file API
elif mime_type == "application/pdf":
elif mime_type == 'application/pdf':
try:
file_id = self._upload_file_to_openai(attachment)
prepared_messages[user_message_index]["content"].append(
{"type": "file", "file": {"file_id": file_id}}
)
prepared_messages[user_message_index]["content"].append({
"type": "file",
"file": {"file_id": file_id}
})
except Exception as e:
logging.error(f"Error uploading PDF to OpenAI: {e}", exc_info=True)
if "content" in attachment:
prepared_messages[user_message_index]["content"].append(
{
"type": "text",
"text": f"File content:\n\n{attachment['content']}",
}
)
if 'content' in attachment:
prepared_messages[user_message_index]["content"].append({
"type": "text",
"text": f"File content:\n\n{attachment['content']}"
})
return prepared_messages
@@ -349,13 +248,13 @@ class OpenAILLM(BaseLLM):
Returns:
str: Base64-encoded image data.
"""
file_path = attachment.get("path")
file_path = attachment.get('path')
if not file_path:
raise ValueError("No file path provided in attachment")
try:
with self.storage.get_file(file_path) as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
return base64.b64encode(image_file.read()).decode('utf-8')
except FileNotFoundError:
raise FileNotFoundError(f"File not found: {file_path}")
@@ -374,10 +273,10 @@ class OpenAILLM(BaseLLM):
"""
import logging
if "openai_file_id" in attachment:
return attachment["openai_file_id"]
if 'openai_file_id' in attachment:
return attachment['openai_file_id']
file_path = attachment.get("path")
file_path = attachment.get('path')
if not self.storage.file_exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
@@ -386,18 +285,19 @@ class OpenAILLM(BaseLLM):
file_id = self.storage.process_file(
file_path,
lambda local_path, **kwargs: self.client.files.create(
file=open(local_path, "rb"), purpose="assistants"
).id,
file=open(local_path, 'rb'),
purpose="assistants"
).id
)
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
attachments_collection = db["attachments"]
if "_id" in attachment:
if '_id' in attachment:
attachments_collection.update_one(
{"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
{"_id": attachment['_id']},
{"$set": {"openai_file_id": file_id}}
)
return file_id
@@ -408,7 +308,9 @@ class OpenAILLM(BaseLLM):
class AzureOpenAILLM(OpenAILLM):
def __init__(self, api_key, user_api_key, *args, **kwargs):
def __init__(
self, api_key, user_api_key, *args, **kwargs
):
super().__init__(api_key)
self.api_base = (settings.OPENAI_API_BASE,)
@@ -419,5 +321,5 @@ class AzureOpenAILLM(OpenAILLM):
self.client = AzureOpenAI(
api_key=api_key,
api_version=settings.OPENAI_API_VERSION,
azure_endpoint=settings.OPENAI_API_BASE,
azure_endpoint=settings.OPENAI_API_BASE
)

View File

@@ -136,8 +136,6 @@ def _log_to_mongodb(
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_logs_collection = db["stack_logs"]
log_entry = {
"endpoint": endpoint,
@@ -149,11 +147,6 @@ def _log_to_mongodb(
"stacks": stacks,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
# clean up text fields to be no longer than 10000 characters
for key, value in log_entry.items():
if isinstance(value, str) and len(value) > 10000:
log_entry[key] = value[:10000]
user_logs_collection.insert_one(log_entry)
logging.debug(f"Logged activity to MongoDB: {activity_id}")

View File

@@ -32,7 +32,16 @@ class Chunker:
header, body = "", text # No header, treat entire text as body
return header, body
def combine_documents(self, doc: Document, next_doc: Document) -> Document:
combined_text = doc.text + " " + next_doc.text
combined_token_count = len(self.encoding.encode(combined_text))
new_doc = Document(
text=combined_text,
doc_id=doc.doc_id,
embedding=doc.embedding,
extra_info={**(doc.extra_info or {}), "token_count": combined_token_count}
)
return new_doc
def split_document(self, doc: Document) -> List[Document]:
split_docs = []
@@ -73,11 +82,26 @@ class Chunker:
processed_docs.append(doc)
i += 1
elif token_count < self.min_tokens:
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
if i + 1 < len(documents):
next_doc = documents[i + 1]
next_tokens = self.encoding.encode(next_doc.text)
if token_count + len(next_tokens) <= self.max_tokens:
# Combine small documents
combined_doc = self.combine_documents(doc, next_doc)
processed_docs.append(combined_doc)
i += 2
else:
# Keep the small document as is if adding next_doc would exceed max_tokens
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
else:
# No next document to combine with; add the small document as is
doc.extra_info = doc.extra_info or {}
doc.extra_info["token_count"] = token_count
processed_docs.append(doc)
i += 1
else:
# Split large documents
processed_docs.extend(self.split_document(doc))

View File

@@ -1,18 +0,0 @@
"""
External knowledge base connectors for DocsGPT.
This module contains connectors for external knowledge bases and document storage systems
that require authentication and specialized handling, separate from simple web scrapers.
"""
from .base import BaseConnectorAuth, BaseConnectorLoader
from .connector_creator import ConnectorCreator
from .google_drive import GoogleDriveAuth, GoogleDriveLoader
__all__ = [
'BaseConnectorAuth',
'BaseConnectorLoader',
'ConnectorCreator',
'GoogleDriveAuth',
'GoogleDriveLoader'
]

View File

@@ -1,129 +0,0 @@
"""
Base classes for external knowledge base connectors.
This module provides minimal abstract base classes that define the essential
interface for external knowledge base connectors.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from application.parser.schema.base import Document
class BaseConnectorAuth(ABC):
"""
Abstract base class for connector authentication.
Defines the minimal interface that all connector authentication
implementations must follow.
"""
@abstractmethod
def get_authorization_url(self, state: Optional[str] = None) -> str:
"""
Generate authorization URL for OAuth flows.
Args:
state: Optional state parameter for CSRF protection
Returns:
Authorization URL
"""
pass
@abstractmethod
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
"""
Exchange authorization code for access tokens.
Args:
authorization_code: Authorization code from OAuth callback
Returns:
Dictionary containing token information
"""
pass
@abstractmethod
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
"""
Refresh an expired access token.
Args:
refresh_token: Refresh token
Returns:
Dictionary containing refreshed token information
"""
pass
@abstractmethod
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
"""
Check if a token is expired.
Args:
token_info: Token information dictionary
Returns:
True if token is expired, False otherwise
"""
pass
class BaseConnectorLoader(ABC):
"""
Abstract base class for connector loaders.
Defines the minimal interface that all connector loader
implementations must follow.
"""
@abstractmethod
def __init__(self, session_token: str):
"""
Initialize the connector loader.
Args:
session_token: Authentication session token
"""
pass
@abstractmethod
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
"""
Load documents from the external knowledge base.
Args:
inputs: Configuration dictionary containing:
- file_ids: Optional list of specific file IDs to load
- folder_ids: Optional list of folder IDs to browse/download
- limit: Maximum number of items to return
- list_only: If True, return metadata without content
- recursive: Whether to recursively process folders
Returns:
List of Document objects
"""
pass
@abstractmethod
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
"""
Download files/folders to a local directory.
Args:
local_dir: Local directory path to download files to
source_config: Configuration for what to download
Returns:
Dictionary containing download results:
- files_downloaded: Number of files downloaded
- directory_path: Path where files were downloaded
- empty_result: Whether no files were downloaded
- source_type: Type of connector
- config_used: Configuration that was used
- error: Error message if download failed (optional)
"""
pass

View File

@@ -1,81 +0,0 @@
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
class ConnectorCreator:
"""
Factory class for creating external knowledge base connectors and auth providers.
These are different from remote loaders as they typically require
authentication and connect to external document storage systems.
"""
connectors = {
"google_drive": GoogleDriveLoader,
}
auth_providers = {
"google_drive": GoogleDriveAuth,
}
@classmethod
def create_connector(cls, connector_type, *args, **kwargs):
"""
Create a connector instance for the specified type.
Args:
connector_type: Type of connector to create (e.g., 'google_drive')
*args, **kwargs: Arguments to pass to the connector constructor
Returns:
Connector instance
Raises:
ValueError: If connector type is not supported
"""
connector_class = cls.connectors.get(connector_type.lower())
if not connector_class:
raise ValueError(f"No connector class found for type {connector_type}")
return connector_class(*args, **kwargs)
@classmethod
def create_auth(cls, connector_type):
"""
Create an auth provider instance for the specified connector type.
Args:
connector_type: Type of connector auth to create (e.g., 'google_drive')
Returns:
Auth provider instance
Raises:
ValueError: If connector type is not supported for auth
"""
auth_class = cls.auth_providers.get(connector_type.lower())
if not auth_class:
raise ValueError(f"No auth class found for type {connector_type}")
return auth_class()
@classmethod
def get_supported_connectors(cls):
"""
Get list of supported connector types.
Returns:
List of supported connector type strings
"""
return list(cls.connectors.keys())
@classmethod
def is_supported(cls, connector_type):
"""
Check if a connector type is supported.
Args:
connector_type: Type of connector to check
Returns:
True if supported, False otherwise
"""
return connector_type.lower() in cls.connectors

View File

@@ -1,10 +0,0 @@
"""
Google Drive connector for DocsGPT.
This module provides authentication and document loading capabilities for Google Drive.
"""
from .auth import GoogleDriveAuth
from .loader import GoogleDriveLoader
__all__ = ['GoogleDriveAuth', 'GoogleDriveLoader']

View File

@@ -1,267 +0,0 @@
import logging
import datetime
from typing import Optional, Dict, Any
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import Flow
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from application.core.settings import settings
from application.parser.connectors.base import BaseConnectorAuth
class GoogleDriveAuth(BaseConnectorAuth):
"""
Handles Google OAuth 2.0 authentication for Google Drive access.
"""
SCOPES = [
'https://www.googleapis.com/auth/drive.file'
]
def __init__(self):
self.client_id = settings.GOOGLE_CLIENT_ID
self.client_secret = settings.GOOGLE_CLIENT_SECRET
self.redirect_uri = f"{settings.CONNECTOR_REDIRECT_BASE_URI}"
if not self.client_id or not self.client_secret:
raise ValueError("Google OAuth credentials not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET in settings.")
def get_authorization_url(self, state: Optional[str] = None) -> str:
try:
flow = Flow.from_client_config(
{
"web": {
"client_id": self.client_id,
"client_secret": self.client_secret,
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"redirect_uris": [self.redirect_uri]
}
},
scopes=self.SCOPES
)
flow.redirect_uri = self.redirect_uri
authorization_url, _ = flow.authorization_url(
access_type='offline',
prompt='consent',
include_granted_scopes='false',
state=state
)
return authorization_url
except Exception as e:
logging.error(f"Error generating authorization URL: {e}")
raise
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
try:
if not authorization_code:
raise ValueError("Authorization code is required")
flow = Flow.from_client_config(
{
"web": {
"client_id": self.client_id,
"client_secret": self.client_secret,
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"redirect_uris": [self.redirect_uri]
}
},
scopes=self.SCOPES
)
flow.redirect_uri = self.redirect_uri
flow.fetch_token(code=authorization_code)
credentials = flow.credentials
if not credentials.refresh_token:
logging.warning("OAuth flow did not return a refresh_token.")
if not credentials.token:
raise ValueError("OAuth flow did not return an access token")
if not credentials.token_uri:
credentials.token_uri = "https://oauth2.googleapis.com/token"
if not credentials.client_id:
credentials.client_id = self.client_id
if not credentials.client_secret:
credentials.client_secret = self.client_secret
if not credentials.refresh_token:
raise ValueError(
"No refresh token received. This typically happens when offline access wasn't granted. "
)
return {
'access_token': credentials.token,
'refresh_token': credentials.refresh_token,
'token_uri': credentials.token_uri,
'client_id': credentials.client_id,
'client_secret': credentials.client_secret,
'scopes': credentials.scopes,
'expiry': credentials.expiry.isoformat() if credentials.expiry else None
}
except Exception as e:
logging.error(f"Error exchanging code for tokens: {e}")
raise
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
try:
if not refresh_token:
raise ValueError("Refresh token is required")
credentials = Credentials(
token=None,
refresh_token=refresh_token,
token_uri="https://oauth2.googleapis.com/token",
client_id=self.client_id,
client_secret=self.client_secret
)
from google.auth.transport.requests import Request
credentials.refresh(Request())
return {
'access_token': credentials.token,
'refresh_token': refresh_token,
'token_uri': credentials.token_uri,
'client_id': credentials.client_id,
'client_secret': credentials.client_secret,
'scopes': credentials.scopes,
'expiry': credentials.expiry.isoformat() if credentials.expiry else None
}
except Exception as e:
logging.error(f"Error refreshing access token: {e}", exc_info=True)
raise
def create_credentials_from_token_info(self, token_info: Dict[str, Any]) -> Credentials:
from application.core.settings import settings
access_token = token_info.get('access_token')
if not access_token:
raise ValueError("No access token found in token_info")
credentials = Credentials(
token=access_token,
refresh_token=token_info.get('refresh_token'),
token_uri= 'https://oauth2.googleapis.com/token',
client_id=settings.GOOGLE_CLIENT_ID,
client_secret=settings.GOOGLE_CLIENT_SECRET,
scopes=token_info.get('scopes', ['https://www.googleapis.com/auth/drive.readonly'])
)
if not credentials.token:
raise ValueError("Credentials created without valid access token")
return credentials
def build_drive_service(self, credentials: Credentials):
try:
if not credentials:
raise ValueError("No credentials provided")
if not credentials.token and not credentials.refresh_token:
raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
needs_refresh = credentials.expired or not credentials.token
if needs_refresh:
if credentials.refresh_token:
try:
from google.auth.transport.requests import Request
credentials.refresh(Request())
except Exception as refresh_error:
raise ValueError(f"Failed to refresh credentials: {refresh_error}")
else:
raise ValueError("No access token or refresh token available. User must re-authorize with offline access.")
return build('drive', 'v3', credentials=credentials)
except HttpError as e:
raise ValueError(f"Failed to build Google Drive service: HTTP {e.resp.status}")
except Exception as e:
raise ValueError(f"Failed to build Google Drive service: {str(e)}")
def is_token_expired(self, token_info):
if 'expiry' in token_info and token_info['expiry']:
try:
from dateutil import parser
# Google Drive provides timezone-aware ISO8601 dates
expiry_dt = parser.parse(token_info['expiry'])
current_time = datetime.datetime.now(datetime.timezone.utc)
return current_time >= expiry_dt - datetime.timedelta(seconds=60)
except Exception:
return True
if 'access_token' in token_info and token_info['access_token']:
return False
return True
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
try:
from application.core.mongo_db import MongoDB
from application.core.settings import settings
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sessions_collection = db["connector_sessions"]
session = sessions_collection.find_one({"session_token": session_token})
if not session:
raise ValueError(f"Invalid session token: {session_token}")
if "token_info" not in session:
raise ValueError("Session missing token information")
token_info = session["token_info"]
if not token_info:
raise ValueError("Invalid token information")
required_fields = ["access_token", "refresh_token"]
missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
if missing_fields:
raise ValueError(f"Missing required token fields: {missing_fields}")
if 'client_id' not in token_info:
token_info['client_id'] = settings.GOOGLE_CLIENT_ID
if 'client_secret' not in token_info:
token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
if 'token_uri' not in token_info:
token_info['token_uri'] = 'https://oauth2.googleapis.com/token'
return token_info
except Exception as e:
raise ValueError(f"Failed to retrieve Google Drive token information: {str(e)}")
def validate_credentials(self, credentials: Credentials) -> bool:
"""
Validate Google Drive credentials by making a test API call.
Args:
credentials: Google credentials object
Returns:
True if credentials are valid, False otherwise
"""
try:
service = self.build_drive_service(credentials)
service.about().get(fields="user").execute()
return True
except HttpError as e:
logging.error(f"HTTP error validating credentials: {e}")
return False
except Exception as e:
logging.error(f"Error validating credentials: {e}")
return False

View File

@@ -1,559 +0,0 @@
"""
Google Drive loader for DocsGPT.
Loads documents from Google Drive using Google Drive API.
"""
import io
import logging
import os
from typing import List, Dict, Any, Optional
from googleapiclient.http import MediaIoBaseDownload
from googleapiclient.errors import HttpError
from application.parser.connectors.base import BaseConnectorLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
from application.parser.schema.base import Document
class GoogleDriveLoader(BaseConnectorLoader):
SUPPORTED_MIME_TYPES = {
'application/pdf': '.pdf',
'application/vnd.google-apps.document': '.docx',
'application/vnd.google-apps.presentation': '.pptx',
'application/vnd.google-apps.spreadsheet': '.xlsx',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
'application/msword': '.doc',
'application/vnd.ms-powerpoint': '.ppt',
'application/vnd.ms-excel': '.xls',
'text/plain': '.txt',
'text/csv': '.csv',
'text/html': '.html',
'text/markdown': '.md',
'text/x-rst': '.rst',
'application/json': '.json',
'application/epub+zip': '.epub',
'application/rtf': '.rtf',
'image/jpeg': '.jpg',
'image/jpg': '.jpg',
'image/png': '.png',
}
EXPORT_FORMATS = {
'application/vnd.google-apps.document': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'application/vnd.google-apps.presentation': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
'application/vnd.google-apps.spreadsheet': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
}
def __init__(self, session_token: str):
self.auth = GoogleDriveAuth()
self.session_token = session_token
token_info = self.auth.get_token_info_from_session(session_token)
self.credentials = self.auth.create_credentials_from_token_info(token_info)
try:
self.service = self.auth.build_drive_service(self.credentials)
except Exception as e:
logging.warning(f"Could not build Google Drive service: {e}")
self.service = None
self.next_page_token = None
def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
try:
file_id = file_metadata.get('id')
file_name = file_metadata.get('name', 'Unknown')
mime_type = file_metadata.get('mimeType', 'application/octet-stream')
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
return None
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
return None
# Google Drive provides timezone-aware ISO8601 dates
doc_metadata = {
'file_name': file_name,
'mime_type': mime_type,
'size': file_metadata.get('size', None),
'created_time': file_metadata.get('createdTime'),
'modified_time': file_metadata.get('modifiedTime'),
'parents': file_metadata.get('parents', []),
'source': 'google_drive'
}
if not load_content:
return Document(
text="",
doc_id=file_id,
extra_info=doc_metadata
)
content = self._download_file_content(file_id, mime_type)
if content is None:
logging.warning(f"Could not load content for file {file_name} ({file_id})")
return None
return Document(
text=content,
doc_id=file_id,
extra_info=doc_metadata
)
except Exception as e:
logging.error(f"Error processing file: {e}")
return None
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
session_token = inputs.get('session_token')
if session_token and session_token != self.session_token:
logging.warning("Session token in inputs differs from loader's session token. Using loader's session token.")
self.config = inputs
try:
documents: List[Document] = []
folder_id = inputs.get('folder_id')
file_ids = inputs.get('file_ids', [])
limit = inputs.get('limit', 100)
list_only = inputs.get('list_only', False)
load_content = not list_only
page_token = inputs.get('page_token')
search_query = inputs.get('search_query')
self.next_page_token = None
if file_ids:
# Specific files requested: load them
for file_id in file_ids:
try:
doc = self._load_file_by_id(file_id, load_content=load_content)
if doc:
if not search_query or (
search_query.lower() in doc.extra_info.get('file_name', '').lower()
):
documents.append(doc)
elif hasattr(self, '_credential_refreshed') and self._credential_refreshed:
self._credential_refreshed = False
logging.info(f"Retrying load of file {file_id} after credential refresh")
doc = self._load_file_by_id(file_id, load_content=load_content)
if doc and (
not search_query or
search_query.lower() in doc.extra_info.get('file_name', '').lower()
):
documents.append(doc)
except Exception as e:
logging.error(f"Error loading file {file_id}: {e}")
continue
else:
# Browsing mode: list immediate children of provided folder or root
parent_id = folder_id if folder_id else 'root'
documents = self._list_items_in_parent(
parent_id,
limit=limit,
load_content=load_content,
page_token=page_token,
search_query=search_query
)
logging.info(f"Loaded {len(documents)} documents from Google Drive")
return documents
except Exception as e:
logging.error(f"Error loading data from Google Drive: {e}", exc_info=True)
raise
def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
self._ensure_service()
try:
file_metadata = self.service.files().get(
fileId=file_id,
fields='id,name,mimeType,size,createdTime,modifiedTime,parents'
).execute()
return self._process_file(file_metadata, load_content=load_content)
except HttpError as e:
logging.error(f"HTTP error loading file {file_id}: {e.resp.status} - {e.content}")
if e.resp.status in [401, 403]:
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
try:
from google.auth.transport.requests import Request
self.credentials.refresh(Request())
self._ensure_service()
return None
except Exception as refresh_error:
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
else:
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
return None
except Exception as e:
logging.error(f"Error loading file {file_id}: {e}")
return None
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
self._ensure_service()
documents: List[Document] = []
try:
query = f"'{parent_id}' in parents and trashed=false"
if search_query:
safe_search = search_query.replace("'", "\\'")
query += f" and name contains '{safe_search}'"
next_token_out: Optional[str] = None
while True:
page_size = 100
if limit:
remaining = max(0, limit - len(documents))
if remaining == 0:
break
page_size = min(100, remaining)
results = self.service.files().list(
q=query,
fields='nextPageToken,files(id,name,mimeType,size,createdTime,modifiedTime,parents)',
pageToken=page_token,
pageSize=page_size,
orderBy='name'
).execute()
items = results.get('files', [])
for item in items:
mime_type = item.get('mimeType')
if mime_type == 'application/vnd.google-apps.folder':
doc_metadata = {
'file_name': item.get('name', 'Unknown'),
'mime_type': mime_type,
'size': item.get('size', None),
'created_time': item.get('createdTime'),
'modified_time': item.get('modifiedTime'),
'parents': item.get('parents', []),
'source': 'google_drive',
'is_folder': True
}
documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
else:
doc = self._process_file(item, load_content=load_content)
if doc:
documents.append(doc)
if limit and len(documents) >= limit:
self.next_page_token = results.get('nextPageToken')
return documents
page_token = results.get('nextPageToken')
next_token_out = page_token
if not page_token:
break
self.next_page_token = next_token_out
return documents
except Exception as e:
logging.error(f"Error listing items under parent {parent_id}: {e}")
return documents
def _download_file_content(self, file_id: str, mime_type: str) -> Optional[str]:
if not self.credentials.token:
logging.warning("No access token in credentials, attempting to refresh")
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
try:
from google.auth.transport.requests import Request
self.credentials.refresh(Request())
logging.info("Credentials refreshed successfully")
self._ensure_service()
except Exception as e:
logging.error(f"Failed to refresh credentials: {e}")
raise ValueError("Authentication failed and cannot be refreshed: missing or invalid refresh_token")
else:
logging.error("No access token and no refresh_token available")
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
if self.credentials.expired:
logging.warning("Credentials are expired, attempting to refresh")
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
try:
from google.auth.transport.requests import Request
self.credentials.refresh(Request())
logging.info("Credentials refreshed successfully")
self._ensure_service()
except Exception as e:
logging.error(f"Failed to refresh expired credentials: {e}")
raise ValueError("Authentication failed and cannot be refreshed: expired credentials")
else:
logging.error("Credentials expired and no refresh_token available")
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
try:
if mime_type in self.EXPORT_FORMATS:
export_mime_type = self.EXPORT_FORMATS[mime_type]
request = self.service.files().export_media(
fileId=file_id,
mimeType=export_mime_type
)
else:
request = self.service.files().get_media(fileId=file_id)
file_io = io.BytesIO()
downloader = MediaIoBaseDownload(file_io, request)
done = False
while done is False:
try:
_, done = downloader.next_chunk()
except HttpError as e:
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
return None
except Exception as e:
logging.error(f"Error during download of file {file_id}: {e}")
return None
content_bytes = file_io.getvalue()
try:
content = content_bytes.decode('utf-8')
except UnicodeDecodeError:
try:
content = content_bytes.decode('latin-1')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
return content
except HttpError as e:
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
if e.resp.status in [401, 403]:
logging.error(f"Authentication error downloading file {file_id}")
if hasattr(self.credentials, 'refresh_token') and self.credentials.refresh_token:
logging.info(f"Attempting to refresh credentials for file {file_id}")
try:
from google.auth.transport.requests import Request
self.credentials.refresh(Request())
logging.info("Credentials refreshed successfully")
self._credential_refreshed = True
self._ensure_service()
return None
except Exception as refresh_error:
logging.error(f"Error refreshing credentials: {refresh_error}")
raise ValueError(f"Authentication failed and could not be refreshed: {refresh_error}")
else:
logging.error("Cannot refresh credentials: missing refresh_token")
raise ValueError("Authentication failed and cannot be refreshed: missing refresh_token")
return None
except Exception as e:
logging.error(f"Error downloading file {file_id}: {e}")
return None
def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
try:
self._ensure_service()
return self._download_single_file(file_id, local_dir)
except Exception as e:
logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
return False
def _ensure_service(self):
if not self.service:
try:
self.service = self.auth.build_drive_service(self.credentials)
except Exception as e:
raise ValueError(f"Cannot access Google Drive: {e}")
def _download_single_file(self, file_id: str, local_dir: str) -> bool:
file_metadata = self.service.files().get(
fileId=file_id,
fields='name,mimeType'
).execute()
file_name = file_metadata['name']
mime_type = file_metadata['mimeType']
if mime_type not in self.SUPPORTED_MIME_TYPES and not mime_type.startswith('application/vnd.google-apps.'):
return False
os.makedirs(local_dir, exist_ok=True)
full_path = os.path.join(local_dir, file_name)
if mime_type in self.EXPORT_FORMATS:
export_mime_type = self.EXPORT_FORMATS[mime_type]
request = self.service.files().export_media(
fileId=file_id,
mimeType=export_mime_type
)
extension = self._get_extension_for_mime_type(export_mime_type)
if not full_path.endswith(extension):
full_path += extension
else:
request = self.service.files().get_media(fileId=file_id)
with open(full_path, 'wb') as f:
downloader = MediaIoBaseDownload(f, request)
done = False
while not done:
_, done = downloader.next_chunk()
return True
def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
files_downloaded = 0
try:
os.makedirs(local_dir, exist_ok=True)
query = f"'{folder_id}' in parents and trashed=false"
page_token = None
while True:
results = self.service.files().list(
q=query,
fields='nextPageToken, files(id, name, mimeType)',
pageToken=page_token,
pageSize=1000
).execute()
items = results.get('files', [])
logging.info(f"Found {len(items)} items in folder {folder_id}")
for item in items:
item_name = item['name']
item_id = item['id']
mime_type = item['mimeType']
if mime_type == 'application/vnd.google-apps.folder':
if recursive:
# Create subfolder and recurse
subfolder_path = os.path.join(local_dir, item_name)
os.makedirs(subfolder_path, exist_ok=True)
subfolder_files = self._download_folder_recursive(
item_id,
subfolder_path,
recursive
)
files_downloaded += subfolder_files
logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
else:
# Download file
success = self._download_single_file(item_id, local_dir)
if success:
files_downloaded += 1
logging.info(f"Downloaded file: {item_name}")
else:
logging.warning(f"Failed to download file: {item_name}")
page_token = results.get('nextPageToken')
if not page_token:
break
return files_downloaded
except Exception as e:
logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
return files_downloaded
def _get_extension_for_mime_type(self, mime_type: str) -> str:
extensions = {
'application/pdf': '.pdf',
'text/plain': '.txt',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
'text/html': '.html',
'text/markdown': '.md',
}
return extensions.get(mime_type, '.bin')
def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
try:
self._ensure_service()
return self._download_folder_recursive(folder_id, local_dir, recursive)
except Exception as e:
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
return 0
def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict:
if source_config is None:
source_config = {}
config = source_config if source_config else getattr(self, 'config', {})
files_downloaded = 0
try:
folder_ids = config.get('folder_ids', [])
file_ids = config.get('file_ids', [])
recursive = config.get('recursive', True)
self._ensure_service()
if file_ids:
if isinstance(file_ids, str):
file_ids = [file_ids]
for file_id in file_ids:
if self._download_file_to_directory(file_id, local_dir):
files_downloaded += 1
# Process folders
if folder_ids:
if isinstance(folder_ids, str):
folder_ids = [folder_ids]
for folder_id in folder_ids:
try:
folder_metadata = self.service.files().get(
fileId=folder_id,
fields='name'
).execute()
folder_name = folder_metadata.get('name', '')
folder_path = os.path.join(local_dir, folder_name)
os.makedirs(folder_path, exist_ok=True)
folder_files = self._download_folder_recursive(
folder_id,
folder_path,
recursive
)
files_downloaded += folder_files
logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
except Exception as e:
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
if not file_ids and not folder_ids:
raise ValueError("No folder_ids or file_ids provided for download")
return {
"files_downloaded": files_downloaded,
"directory_path": local_dir,
"empty_result": files_downloaded == 0,
"source_type": "google_drive",
"config_used": config
}
except Exception as e:
return {
"files_downloaded": files_downloaded,
"directory_path": local_dir,
"empty_result": True,
"source_type": "google_drive",
"config_used": config,
"error": str(e)
}

View File

@@ -6,21 +6,6 @@ from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
def sanitize_content(content: str) -> str:
"""
Remove NUL characters that can cause vector store ingestion to fail.
Args:
content (str): Raw content that may contain NUL characters
Returns:
str: Sanitized content with NUL characters removed
"""
if not content:
return content
return content.replace('\x00', '')
@retry(tries=10, delay=60)
def add_text_to_store_with_retry(store, doc, source_id):
"""
@@ -31,9 +16,6 @@ def add_text_to_store_with_retry(store, doc, source_id):
source_id: Unique identifier for the source.
"""
try:
# Sanitize content to remove NUL characters that cause ingestion failures
doc.page_content = sanitize_content(doc.page_content)
doc.metadata["source_id"] = str(source_id)
store.add_texts([doc.page_content], metadatas=[doc.metadata])
except Exception as e:
@@ -64,7 +46,7 @@ def embed_and_store_documents(docs, folder_name, source_id, task_status):
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
docs_init=docs_init,
source_id=source_id,
source_id=folder_name,
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
else:

View File

@@ -15,7 +15,6 @@ from application.parser.file.json_parser import JSONParser
from application.parser.file.pptx_parser import PPTXParser
from application.parser.file.image_parser import ImageParser
from application.parser.schema.base import Document
from application.utils import num_tokens_from_string
DEFAULT_FILE_EXTRACTOR: Dict[str, BaseParser] = {
".pdf": PDFParser(),
@@ -142,12 +141,11 @@ class SimpleDirectoryReader(BaseReader):
Returns:
List[Document]: A list of documents.
"""
data: Union[str, List[str]] = ""
data_list: List[str] = []
metadata_list = []
self.file_token_counts = {}
for input_file in self.input_files:
if input_file.suffix in self.file_extractor:
parser = self.file_extractor[input_file.suffix]
@@ -158,48 +156,24 @@ class SimpleDirectoryReader(BaseReader):
# do standard read
with open(input_file, "r", errors=self.errors) as f:
data = f.read()
# Calculate token count for this file
if isinstance(data, List):
file_tokens = sum(num_tokens_from_string(str(d)) for d in data)
else:
file_tokens = num_tokens_from_string(str(data))
full_path = str(input_file.resolve())
self.file_token_counts[full_path] = file_tokens
base_metadata = {
'title': input_file.name,
'token_count': file_tokens,
}
if hasattr(self, 'input_dir'):
try:
relative_path = str(input_file.relative_to(self.input_dir))
base_metadata['source'] = relative_path
except ValueError:
base_metadata['source'] = str(input_file)
else:
base_metadata['source'] = str(input_file)
# Prepare metadata for this file
if self.file_metadata is not None:
custom_metadata = self.file_metadata(input_file.name)
base_metadata.update(custom_metadata)
file_metadata = self.file_metadata(input_file.name)
else:
# Provide a default empty metadata
file_metadata = {'title': '', 'store': ''}
# TODO: Find a case with no metadata and check if breaks anything
if isinstance(data, List):
# Extend data_list with each item in the data list
data_list.extend([str(d) for d in data])
metadata_list.extend([base_metadata for _ in data])
# For each item in the data list, add the file's metadata to metadata_list
metadata_list.extend([file_metadata for _ in data])
else:
# Add the single piece of data to data_list
data_list.append(str(data))
metadata_list.append(base_metadata)
# Build directory structure if input_dir is provided
if hasattr(self, 'input_dir'):
self.directory_structure = self.build_directory_structure(self.input_dir)
logging.info("Directory structure built successfully")
else:
self.directory_structure = {}
# Add the file's metadata to metadata_list
metadata_list.append(file_metadata)
if concatenate:
return [Document("\n".join(data_list))]
@@ -207,48 +181,3 @@ class SimpleDirectoryReader(BaseReader):
return [Document(d, extra_info=m) for d, m in zip(data_list, metadata_list)]
else:
return [Document(d) for d in data_list]
def build_directory_structure(self, base_path):
"""Build a dictionary representing the directory structure.
Args:
base_path: The base path to start building the structure from.
Returns:
dict: A nested dictionary representing the directory structure.
"""
import mimetypes
def build_tree(path):
"""Helper function to recursively build the directory tree."""
result = {}
for item in path.iterdir():
if self.exclude_hidden and item.name.startswith('.'):
continue
if item.is_dir():
subtree = build_tree(item)
if subtree:
result[item.name] = subtree
else:
if self.required_exts is not None and item.suffix not in self.required_exts:
continue
full_path = str(item.resolve())
file_size_bytes = item.stat().st_size
mime_type = mimetypes.guess_type(item.name)[0] or "application/octet-stream"
file_info = {
"type": mime_type,
"size_bytes": file_size_bytes
}
if hasattr(self, 'file_token_counts') and full_path in self.file_token_counts:
file_info["token_count"] = self.file_token_counts[full_path]
result[item.name] = file_info
return result
return build_tree(Path(base_path))

View File

@@ -8,7 +8,6 @@ import requests
from typing import Dict, Union
from application.parser.file.base_parser import BaseParser
from application.core.settings import settings
class ImageParser(BaseParser):
@@ -19,13 +18,10 @@ class ImageParser(BaseParser):
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
if settings.PARSE_IMAGE_REMOTE:
doc2md_service = "https://llm.arc53.com/doc2md"
# alternatively you can use local vision capable LLM
with open(file, "rb") as file_loaded:
files = {'file': file_loaded}
response = requests.post(doc2md_service, files=files)
data = response.json()["markdown"]
else:
data = ""
doc2md_service = "https://llm.arc53.com/doc2md"
# alternatively you can use local vision capable LLM
with open(file, "rb") as file_loaded:
files = {'file': file_loaded}
response = requests.post(doc2md_service, files=files)
data = response.json()["markdown"]
return data

View File

@@ -6,16 +6,6 @@ from application.parser.remote.github_loader import GitHubLoader
class RemoteCreator:
"""
Factory class for creating remote content loaders.
These loaders fetch content from remote web sources like URLs,
sitemaps, web crawlers, social media platforms, etc.
For external knowledge base connectors (like Google Drive),
use ConnectorCreator instead.
"""
loaders = {
"url": WebLoader,
"sitemap": SitemapLoader,
@@ -28,5 +18,5 @@ class RemoteCreator:
def create_loader(cls, type, *args, **kwargs):
loader_class = cls.loaders.get(type.lower())
if not loader_class:
raise ValueError(f"No loader class found for type {type}")
raise ValueError(f"No LLM class found for type {type}")
return loader_class(*args, **kwargs)

View File

@@ -2,7 +2,6 @@ anthropic==0.49.0
boto3==1.38.18
beautifulsoup4==4.13.4
celery==5.4.0
cryptography==42.0.8
dataclasses-json==0.6.7
docx2txt==0.8
duckduckgo-search==7.5.2
@@ -12,12 +11,8 @@ esprima==4.0.1
esutils==1.0.1
Flask==3.1.1
faiss-cpu==1.9.0.post1
fastmcp==2.11.0
flask-restx==1.3.0
google-genai==1.3.0
google-api-python-client==2.179.0
google-auth-httplib2==0.2.0
google-auth-oauthlib==1.2.2
gTTS==2.5.4
gunicorn==23.0.0
javalang==0.13.0
@@ -57,13 +52,13 @@ prompt-toolkit==3.0.51
protobuf==5.29.3
psycopg2-binary==2.9.10
py==1.11.0
pydantic
pydantic-core
pydantic-settings
pydantic==2.10.6
pydantic-core==2.27.2
pydantic-settings==2.7.1
pymongo==4.11.3
pypdf==5.5.0
python-dateutil==2.9.0.post0
python-dotenv
python-dotenv==1.0.1
python-jose==3.4.0
python-pptx==1.0.2
redis==5.2.1
@@ -83,7 +78,7 @@ tzdata==2024.2
urllib3==2.3.0
vine==5.1.0
wcwidth==0.2.13
werkzeug>=3.1.0,<3.1.2
werkzeug==3.1.3
yarl==1.20.0
markdownify==1.1.0
tldextract==5.1.3

View File

@@ -5,6 +5,10 @@ class BaseRetriever(ABC):
def __init__(self):
pass
@abstractmethod
def gen(self, *args, **kwargs):
pass
@abstractmethod
def search(self, *args, **kwargs):
pass

View File

@@ -0,0 +1,112 @@
import json
from langchain_community.tools import BraveSearch
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever
class BraveRetSearch(BaseRetriever):
def __init__(
self,
source,
chat_history,
prompt,
chunks=2,
token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
decoded_token=None,
):
self.question = ""
self.source = source
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
self.token_limit = (
token_limit
if token_limit
< settings.LLM_TOKEN_LIMITS.get(
self.gpt_model, settings.DEFAULT_MAX_HISTORY
)
else settings.LLM_TOKEN_LIMITS.get(
self.gpt_model, settings.DEFAULT_MAX_HISTORY
)
)
self.user_api_key = user_api_key
self.decoded_token = decoded_token
def _get_data(self):
if self.chunks == 0:
docs = []
else:
search = BraveSearch.from_api_key(
api_key=settings.BRAVE_SEARCH_API_KEY,
search_kwargs={"count": int(self.chunks)},
)
results = search.run(self.question)
results = json.loads(results)
docs = []
for i in results:
try:
title = i["title"]
link = i["link"]
snippet = i["snippet"]
docs.append({"text": snippet, "title": title, "link": link})
except IndexError:
pass
if settings.LLM_PROVIDER == "llama.cpp":
docs = [docs[0]]
return docs
def gen(self):
docs = self._get_data()
# join all page_content together with a newline
docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 0:
for i in self.chat_history:
if "prompt" in i and "response" in i:
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
decoded_token=self.decoded_token,
)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self, query: str = ""):
if query:
self.question = query
return self._get_data()
def get_params(self):
return {
"question": self.question,
"source": self.source,
"chat_history": self.chat_history,
"prompt": self.prompt,
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
"user_api_key": self.user_api_key,
}

View File

@@ -1,6 +1,4 @@
import logging
import os
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever
@@ -22,20 +20,10 @@ class ClassicRAG(BaseRetriever):
api_key=settings.API_KEY,
decoded_token=None,
):
"""Initialize ClassicRAG retriever with vectorstore sources and LLM configuration"""
self.original_question = source.get("question", "")
self.original_question = ""
self.chat_history = chat_history if chat_history is not None else []
self.prompt = prompt
if isinstance(chunks, str):
try:
self.chunks = int(chunks)
except ValueError:
logging.warning(
f"Invalid chunks value '{chunks}', using default value 2"
)
self.chunks = 2
else:
self.chunks = chunks
self.chunks = chunks
self.gpt_model = gpt_model
self.token_limit = (
token_limit
@@ -56,52 +44,25 @@ class ClassicRAG(BaseRetriever):
user_api_key=self.user_api_key,
decoded_token=decoded_token,
)
if "active_docs" in source and source["active_docs"] is not None:
if isinstance(source["active_docs"], list):
self.vectorstores = source["active_docs"]
else:
self.vectorstores = [source["active_docs"]]
else:
self.vectorstores = []
self.vectorstore = source["active_docs"] if "active_docs" in source else None
self.question = self._rephrase_query()
self.decoded_token = decoded_token
self._validate_vectorstore_config()
def _validate_vectorstore_config(self):
"""Validate vectorstore IDs and remove any empty/invalid entries"""
if not self.vectorstores:
logging.warning("No vectorstores configured for retrieval")
return
invalid_ids = [
vs_id for vs_id in self.vectorstores if not vs_id or not vs_id.strip()
]
if invalid_ids:
logging.warning(f"Found invalid vectorstore IDs: {invalid_ids}")
self.vectorstores = [
vs_id for vs_id in self.vectorstores if vs_id and vs_id.strip()
]
def _rephrase_query(self):
"""Rephrase user query with chat history context for better retrieval"""
if (
not self.original_question
or not self.chat_history
or self.chat_history == []
or self.chunks == 0
or not self.vectorstores
or self.vectorstore is None
):
return self.original_question
prompt = f"""Given the following conversation history:
prompt = f"""Given the following conversation history:
{self.chat_history}
Rephrase the following user question to be a standalone search query
that captures all relevant context from the conversation:
"""
messages = [
@@ -118,75 +79,44 @@ class ClassicRAG(BaseRetriever):
return self.original_question
def _get_data(self):
"""Retrieve relevant documents from configured vectorstores"""
if self.chunks == 0 or not self.vectorstores:
return []
all_docs = []
chunks_per_source = max(1, self.chunks // len(self.vectorstores))
if self.chunks == 0 or self.vectorstore is None:
docs = []
else:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, self.vectorstore, settings.EMBEDDINGS_KEY
)
docs_temp = docsearch.search(self.question, k=self.chunks)
docs = [
{
"title": i.metadata.get(
"title", i.metadata.get("post_title", i.page_content)
).split("/")[-1],
"text": i.page_content,
"source": (
i.metadata.get("source")
if i.metadata.get("source")
else "local"
),
}
for i in docs_temp
]
for vectorstore_id in self.vectorstores:
if vectorstore_id:
try:
docsearch = VectorCreator.create_vectorstore(
settings.VECTOR_STORE, vectorstore_id, settings.EMBEDDINGS_KEY
)
docs_temp = docsearch.search(self.question, k=chunks_per_source)
return docs
for doc in docs_temp:
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
page_content = doc.page_content
metadata = doc.metadata
else:
page_content = doc.get("text", doc.get("page_content", ""))
metadata = doc.get("metadata", {})
title = metadata.get(
"title", metadata.get("post_title", page_content)
)
if not isinstance(title, str):
title = str(title)
title = title.split("/")[-1]
filename = (
metadata.get("filename")
or metadata.get("file_name")
or metadata.get("source")
)
if isinstance(filename, str):
filename = os.path.basename(filename) or filename
else:
filename = title
if not filename:
filename = title
source_path = metadata.get("source") or vectorstore_id
all_docs.append(
{
"title": title,
"text": page_content,
"source": source_path,
"filename": filename,
}
)
except Exception as e:
logging.error(
f"Error searching vectorstore {vectorstore_id}: {e}",
exc_info=True,
)
continue
return all_docs
def gen():
pass
def search(self, query: str = ""):
"""Search for documents using optional query override"""
if query:
self.original_question = query
self.question = self._rephrase_query()
return self._get_data()
def get_params(self):
"""Return current retriever configuration parameters"""
return {
"question": self.original_question,
"rephrased_question": self.question,
"sources": self.vectorstores,
"source": self.vectorstore,
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,

View File

@@ -0,0 +1,111 @@
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever
class DuckDuckSearch(BaseRetriever):
def __init__(
self,
source,
chat_history,
prompt,
chunks=2,
token_limit=150,
gpt_model="docsgpt",
user_api_key=None,
decoded_token=None,
):
self.question = ""
self.source = source
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
self.gpt_model = gpt_model
self.token_limit = (
token_limit
if token_limit
< settings.LLM_TOKEN_LIMITS.get(
self.gpt_model, settings.DEFAULT_MAX_HISTORY
)
else settings.LLM_TOKEN_LIMITS.get(
self.gpt_model, settings.DEFAULT_MAX_HISTORY
)
)
self.user_api_key = user_api_key
self.decoded_token = decoded_token
def _get_data(self):
if self.chunks == 0:
docs = []
else:
wrapper = DuckDuckGoSearchAPIWrapper(max_results=self.chunks)
search = DuckDuckGoSearchResults(api_wrapper=wrapper, output_format="list")
results = search.run(self.question)
docs = []
for i in results:
try:
docs.append(
{
"text": i.get("snippet", "").strip(),
"title": i.get("title", "").strip(),
"link": i.get("link", "").strip(),
}
)
except IndexError:
pass
if settings.LLM_PROVIDER == "llama.cpp":
docs = [docs[0]]
return docs
def gen(self):
docs = self._get_data()
# join all page_content together with a newline
docs_together = "\n".join([doc["text"] for doc in docs])
p_chat_combine = self.prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
for doc in docs:
yield {"source": doc}
if len(self.chat_history) > 0:
for i in self.chat_history:
if "prompt" in i and "response" in i:
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append(
{"role": "assistant", "content": i["response"]}
)
messages_combine.append({"role": "user", "content": self.question})
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
decoded_token=self.decoded_token,
)
completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self, query: str = ""):
if query:
self.question = query
return self._get_data()
def get_params(self):
return {
"question": self.question,
"source": self.source,
"chat_history": self.chat_history,
"prompt": self.prompt,
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
"user_api_key": self.user_api_key,
}

View File

@@ -1,9 +1,13 @@
from application.retriever.classic_rag import ClassicRAG
from application.retriever.duckduck_search import DuckDuckSearch
from application.retriever.brave_search import BraveRetSearch
class RetrieverCreator:
retrievers = {
"classic": ClassicRAG,
"duckduck_search": DuckDuckSearch,
"brave_search": BraveRetSearch,
"default": ClassicRAG,
}

View File

@@ -1,85 +0,0 @@
import base64
import json
import os
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from application.core.settings import settings
def _derive_key(user_id: str, salt: bytes) -> bytes:
app_secret = settings.ENCRYPTION_SECRET_KEY
password = f"{app_secret}#{user_id}".encode()
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
backend=default_backend(),
)
return kdf.derive(password)
def encrypt_credentials(credentials: dict, user_id: str) -> str:
if not credentials:
return ""
try:
salt = os.urandom(16)
iv = os.urandom(16)
key = _derive_key(user_id, salt)
json_str = json.dumps(credentials)
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
encryptor = cipher.encryptor()
padded_data = _pad_data(json_str.encode())
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
result = salt + iv + encrypted_data
return base64.b64encode(result).decode()
except Exception as e:
print(f"Warning: Failed to encrypt credentials: {e}")
return ""
def decrypt_credentials(encrypted_data: str, user_id: str) -> dict:
if not encrypted_data:
return {}
try:
data = base64.b64decode(encrypted_data.encode())
salt = data[:16]
iv = data[16:32]
encrypted_content = data[32:]
key = _derive_key(user_id, salt)
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
decryptor = cipher.decryptor()
decrypted_padded = decryptor.update(encrypted_content) + decryptor.finalize()
decrypted_data = _unpad_data(decrypted_padded)
return json.loads(decrypted_data.decode())
except Exception as e:
print(f"Warning: Failed to decrypt credentials: {e}")
return {}
def _pad_data(data: bytes) -> bytes:
block_size = 16
padding_len = block_size - (len(data) % block_size)
padding = bytes([padding_len]) * padding_len
return data + padding
def _unpad_data(data: bytes) -> bytes:
padding_len = data[-1]
return data[:-padding_len]

View File

@@ -1,5 +1,4 @@
"""Base storage class for file system abstraction."""
from abc import ABC, abstractmethod
from typing import BinaryIO, List, Callable
@@ -8,7 +7,7 @@ class BaseStorage(ABC):
"""Abstract base class for storage implementations."""
@abstractmethod
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
def save_file(self, file_data: BinaryIO, path: str) -> dict:
"""
Save a file to storage.
@@ -93,32 +92,3 @@ class BaseStorage(ABC):
List[str]: List of file paths
"""
pass
@abstractmethod
def is_directory(self, path: str) -> bool:
"""
Check if a path is a directory.
Args:
path: Path to check
Returns:
bool: True if the path is a directory
"""
pass
@abstractmethod
def remove_directory(self, directory: str) -> bool:
"""
Remove a directory and all its contents.
For local storage, this removes the directory and all files/subdirectories within it.
For S3 storage, this removes all objects with the directory path as a prefix.
Args:
directory: Directory path to remove
Returns:
bool: True if removal was successful, False otherwise
"""
pass

View File

@@ -101,40 +101,3 @@ class LocalStorage(BaseStorage):
raise FileNotFoundError(f"File not found: {full_path}")
return processor_func(local_path=full_path, **kwargs)
def is_directory(self, path: str) -> bool:
"""
Check if a path is a directory in local storage.
Args:
path: Path to check
Returns:
bool: True if the path is a directory, False otherwise
"""
full_path = self._get_full_path(path)
return os.path.isdir(full_path)
def remove_directory(self, directory: str) -> bool:
"""
Remove a directory and all its contents from local storage.
Args:
directory: Directory path to remove
Returns:
bool: True if removal was successful, False otherwise
"""
full_path = self._get_full_path(directory)
if not os.path.exists(full_path):
return False
if not os.path.isdir(full_path):
return False
try:
shutil.rmtree(full_path)
return True
except (OSError, PermissionError):
return False

View File

@@ -1,14 +1,13 @@
"""S3 storage implementation."""
import io
from typing import BinaryIO, List, Callable
import os
from typing import BinaryIO, Callable, List
import boto3
from application.core.settings import settings
from botocore.exceptions import ClientError
from application.storage.base import BaseStorage
from botocore.exceptions import ClientError
from application.core.settings import settings
class S3Storage(BaseStorage):
@@ -21,48 +20,38 @@ class S3Storage(BaseStorage):
Args:
bucket_name: S3 bucket name (optional, defaults to settings)
"""
self.bucket_name = bucket_name or getattr(
settings, "S3_BUCKET_NAME", "docsgpt-test-bucket"
)
self.bucket_name = bucket_name or getattr(settings, "S3_BUCKET_NAME", "docsgpt-test-bucket")
# Get credentials from settings
aws_access_key_id = getattr(settings, "SAGEMAKER_ACCESS_KEY", None)
aws_secret_access_key = getattr(settings, "SAGEMAKER_SECRET_KEY", None)
region_name = getattr(settings, "SAGEMAKER_REGION", None)
self.s3 = boto3.client(
"s3",
's3',
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=region_name,
region_name=region_name
)
def save_file(
self,
file_data: BinaryIO,
path: str,
storage_class: str = "INTELLIGENT_TIERING",
**kwargs,
) -> dict:
def save_file(self, file_data: BinaryIO, path: str) -> dict:
"""Save a file to S3 storage."""
self.s3.upload_fileobj(
file_data, self.bucket_name, path, ExtraArgs={"StorageClass": storage_class}
)
self.s3.upload_fileobj(file_data, self.bucket_name, path)
region = getattr(settings, "SAGEMAKER_REGION", None)
return {
"storage_type": "s3",
"bucket_name": self.bucket_name,
"uri": f"s3://{self.bucket_name}/{path}",
"region": region,
'storage_type': 's3',
'bucket_name': self.bucket_name,
'uri': f's3://{self.bucket_name}/{path}',
'region': region
}
def get_file(self, path: str) -> BinaryIO:
"""Get a file from S3 storage."""
if not self.file_exists(path):
raise FileNotFoundError(f"File not found: {path}")
file_obj = io.BytesIO()
self.s3.download_fileobj(self.bucket_name, path, file_obj)
file_obj.seek(0)
@@ -87,17 +76,18 @@ class S3Storage(BaseStorage):
def list_files(self, directory: str) -> List[str]:
"""List all files in a directory in S3 storage."""
# Ensure directory ends with a slash if it's not empty
if directory and not directory.endswith('/'):
directory += '/'
if directory and not directory.endswith("/"):
directory += "/"
result = []
paginator = self.s3.get_paginator("list_objects_v2")
paginator = self.s3.get_paginator('list_objects_v2')
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=directory)
for page in pages:
if "Contents" in page:
for obj in page["Contents"]:
result.append(obj["Key"])
if 'Contents' in page:
for obj in page['Contents']:
result.append(obj['Key'])
return result
def process_file(self, path: str, processor_func: Callable, **kwargs):
@@ -108,99 +98,23 @@ class S3Storage(BaseStorage):
path: Path to the file
processor_func: Function that processes the file
**kwargs: Additional arguments to pass to the processor function
Returns:
The result of the processor function
"""
import logging
import tempfile
import logging
if not self.file_exists(path):
raise FileNotFoundError(f"File not found in S3: {path}")
with tempfile.NamedTemporaryFile(
suffix=os.path.splitext(path)[1], delete=True
) as temp_file:
with tempfile.NamedTemporaryFile(suffix=os.path.splitext(path)[1], delete=True) as temp_file:
try:
# Download the file from S3 to the temporary file
self.s3.download_fileobj(self.bucket_name, path, temp_file)
temp_file.flush()
return processor_func(local_path=temp_file.name, **kwargs)
except Exception as e:
logging.error(f"Error processing S3 file {path}: {e}", exc_info=True)
raise
def is_directory(self, path: str) -> bool:
"""
Check if a path is a directory in S3 storage.
In S3, directories are virtual concepts. A path is considered a directory
if there are objects with the path as a prefix.
Args:
path: Path to check
Returns:
bool: True if the path is a directory, False otherwise
"""
# Ensure path ends with a slash if not empty
if path and not path.endswith('/'):
path += '/'
response = self.s3.list_objects_v2(
Bucket=self.bucket_name,
Prefix=path,
MaxKeys=1
)
return 'Contents' in response
def remove_directory(self, directory: str) -> bool:
"""
Remove a directory and all its contents from S3 storage.
In S3, this removes all objects with the directory path as a prefix.
Since S3 doesn't have actual directories, this effectively removes
all files within the virtual directory structure.
Args:
directory: Directory path to remove
Returns:
bool: True if removal was successful, False otherwise
"""
# Ensure directory ends with a slash if not empty
if directory and not directory.endswith('/'):
directory += '/'
try:
# Get all objects with the directory prefix
objects_to_delete = []
paginator = self.s3.get_paginator('list_objects_v2')
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=directory)
for page in pages:
if 'Contents' in page:
for obj in page['Contents']:
objects_to_delete.append({'Key': obj['Key']})
if not objects_to_delete:
return False
batch_size = 1000
for i in range(0, len(objects_to_delete), batch_size):
batch = objects_to_delete[i:i + batch_size]
response = self.s3.delete_objects(
Bucket=self.bucket_name,
Delete={'Objects': batch}
)
if 'Errors' in response and response['Errors']:
return False
return True
except ClientError:
return False

View File

@@ -1,13 +1,8 @@
import hashlib
import os
import re
import uuid
import tiktoken
from flask import jsonify, make_response
from werkzeug.utils import secure_filename
from application.core.settings import settings
_encoding = None
@@ -20,41 +15,6 @@ def get_encoding():
return _encoding
def get_gpt_model() -> str:
"""Get the appropriate GPT model based on provider"""
model_map = {
"openai": "gpt-4o-mini",
"anthropic": "claude-2",
"groq": "llama3-8b-8192",
"novita": "deepseek/deepseek-r1",
}
return settings.LLM_NAME or model_map.get(settings.LLM_PROVIDER, "")
def safe_filename(filename):
"""
Creates a safe filename that preserves the original extension.
Uses secure_filename, but ensures a proper filename is returned even with non-Latin characters.
Args:
filename (str): The original filename
Returns:
str: A safe filename that can be used for storage
"""
if not filename:
return str(uuid.uuid4())
_, extension = os.path.splitext(filename)
safe_name = secure_filename(filename)
# If secure_filename returns just the extension or an empty string
if not safe_name or safe_name == extension.lstrip("."):
return f"{str(uuid.uuid4())}{extension}"
return safe_name
def num_tokens_from_string(string: str) -> int:
encoding = get_encoding()
if isinstance(string, str):
@@ -79,6 +39,7 @@ def count_tokens_docs(docs):
docs_content = ""
for doc in docs:
docs_content += doc.page_content
tokens = num_tokens_from_string(docs_content)
return tokens
@@ -90,7 +51,7 @@ def check_required_fields(data, required_fields):
jsonify(
{
"success": False,
"message": f"Missing required fields: {', '.join(missing_fields)}",
"message": f"Missing fields: {', '.join(missing_fields)}",
}
),
400,
@@ -98,27 +59,6 @@ def check_required_fields(data, required_fields):
return None
def validate_required_fields(data, required_fields):
missing_fields = []
empty_fields = []
for field in required_fields:
if field not in data:
missing_fields.append(field)
elif not data[field]:
empty_fields.append(field)
errors = []
if missing_fields:
errors.append(f"Missing required fields: {', '.join(missing_fields)}")
if empty_fields:
errors.append(f"Empty values in required fields: {', '.join(empty_fields)}")
if errors:
return make_response(
jsonify({"success": False, "message": " | ".join(errors)}), 400
)
return None
def get_hash(data):
return hashlib.md5(data.encode(), usedforsecurity=False).hexdigest()
@@ -140,6 +80,7 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
if not history:
return []
trimmed_history = []
tokens_current_history = 0
@@ -148,15 +89,18 @@ def limit_chat_history(history, max_token_limit=None, gpt_model="docsgpt"):
if "prompt" in message and "response" in message:
tokens_batch += num_tokens_from_string(message["prompt"])
tokens_batch += num_tokens_from_string(message["response"])
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
tool_call_string = f"Tool: {tool_call.get('tool_name')} | Action: {tool_call.get('action_name')} | Args: {tool_call.get('arguments')} | Response: {tool_call.get('result')}"
tokens_batch += num_tokens_from_string(tool_call_string)
if tokens_current_history + tokens_batch < max_token_limit:
tokens_current_history += tokens_batch
trimmed_history.insert(0, message)
else:
break
return trimmed_history
@@ -165,14 +109,3 @@ def validate_function_name(function_name):
if not re.match(r"^[a-zA-Z0-9_-]+$", function_name):
return False
return True
def generate_image_url(image_path):
strategy = getattr(settings, "URL_STRATEGY", "backend")
if strategy == "s3":
bucket_name = getattr(settings, "S3_BUCKET_NAME", "docsgpt-test-bucket")
region_name = getattr(settings, "SAGEMAKER_REGION", "eu-central-1")
return f"https://{bucket_name}.s3.{region_name}.amazonaws.com/{image_path}"
else:
base_url = getattr(settings, "API_URL", "http://localhost:7091")
return f"{base_url}/api/images/{image_path}"

View File

@@ -1,28 +1,20 @@
import os
from abc import ABC, abstractmethod
from langchain_openai import OpenAIEmbeddings
import os
from sentence_transformers import SentenceTransformer
from langchain_openai import OpenAIEmbeddings
from application.core.settings import settings
class EmbeddingsWrapper:
def __init__(self, model_name, *args, **kwargs):
self.model = SentenceTransformer(
model_name,
config_kwargs={"allow_dangerous_deserialization": True},
*args,
**kwargs
)
self.model = SentenceTransformer(model_name, config_kwargs={'allow_dangerous_deserialization': True}, *args, **kwargs)
self.dimension = self.model.get_sentence_embedding_dimension()
def embed_query(self, query: str):
return self.model.encode(query).tolist()
def embed_documents(self, documents: list):
return self.model.encode(documents).tolist()
def __call__(self, text):
if isinstance(text, str):
return self.embed_query(text)
@@ -32,14 +24,15 @@ class EmbeddingsWrapper:
raise ValueError("Input must be a string or a list of strings")
class EmbeddingsSingleton:
_instances = {}
@staticmethod
def get_instance(embeddings_name, *args, **kwargs):
if embeddings_name not in EmbeddingsSingleton._instances:
EmbeddingsSingleton._instances[embeddings_name] = (
EmbeddingsSingleton._create_instance(embeddings_name, *args, **kwargs)
EmbeddingsSingleton._instances[embeddings_name] = EmbeddingsSingleton._create_instance(
embeddings_name, *args, **kwargs
)
return EmbeddingsSingleton._instances[embeddings_name]
@@ -47,15 +40,9 @@ class EmbeddingsSingleton:
def _create_instance(embeddings_name, *args, **kwargs):
embeddings_factory = {
"openai_text-embedding-ada-002": OpenAIEmbeddings,
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper(
"sentence-transformers/all-mpnet-base-v2"
),
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper(
"sentence-transformers/all-mpnet-base-v2"
),
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper(
"hkunlp/instructor-large"
),
"huggingface_sentence-transformers/all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_sentence-transformers-all-mpnet-base-v2": lambda: EmbeddingsWrapper("sentence-transformers/all-mpnet-base-v2"),
"huggingface_hkunlp/instructor-large": lambda: EmbeddingsWrapper("hkunlp/instructor-large"),
}
if embeddings_name in embeddings_factory:
@@ -63,63 +50,34 @@ class EmbeddingsSingleton:
else:
return EmbeddingsWrapper(embeddings_name, *args, **kwargs)
class BaseVectorStore(ABC):
def __init__(self):
pass
@abstractmethod
def search(self, *args, **kwargs):
"""Search for similar documents/chunks in the vectorstore"""
pass
@abstractmethod
def add_texts(self, texts, metadatas=None, *args, **kwargs):
"""Add texts with their embeddings to the vectorstore"""
pass
def delete_index(self, *args, **kwargs):
"""Delete the entire index/collection"""
pass
def save_local(self, *args, **kwargs):
"""Save vectorstore to local storage"""
pass
def get_chunks(self, *args, **kwargs):
"""Get all chunks from the vectorstore"""
pass
def add_chunk(self, text, metadata=None, *args, **kwargs):
"""Add a single chunk to the vectorstore"""
pass
def delete_chunk(self, chunk_id, *args, **kwargs):
"""Delete a specific chunk from the vectorstore"""
pass
def is_azure_configured(self):
return (
settings.OPENAI_API_BASE
and settings.OPENAI_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
)
return settings.OPENAI_API_BASE and settings.OPENAI_API_VERSION and settings.AZURE_DEPLOYMENT_NAME
def _get_embeddings(self, embeddings_name, embeddings_key=None):
if embeddings_name == "openai_text-embedding-ada-002":
if self.is_azure_configured():
os.environ["OPENAI_API_TYPE"] = "azure"
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name, model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
embeddings_name,
model=settings.AZURE_EMBEDDINGS_DEPLOYMENT_NAME
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name, openai_api_key=embeddings_key
embeddings_name,
openai_api_key=embeddings_key
)
elif embeddings_name == "huggingface_sentence-transformers/all-mpnet-base-v2":
if os.path.exists("./models/all-mpnet-base-v2"):
embedding_instance = EmbeddingsSingleton.get_instance(
embeddings_name="./models/all-mpnet-base-v2",
embeddings_name = "./models/all-mpnet-base-v2",
)
else:
embedding_instance = EmbeddingsSingleton.get_instance(
@@ -129,3 +87,4 @@ class BaseVectorStore(ABC):
embedding_instance = EmbeddingsSingleton.get_instance(embeddings_name)
return embedding_instance

View File

@@ -1,6 +1,5 @@
import os
import tempfile
import io
from langchain_community.vectorstores import FAISS
@@ -67,37 +66,8 @@ class FaissStore(BaseVectorStore):
def add_texts(self, *args, **kwargs):
return self.docsearch.add_texts(*args, **kwargs)
def _save_to_storage(self):
"""
Save the FAISS index to storage using temporary directory pattern.
Works consistently for both local and S3 storage.
"""
with tempfile.TemporaryDirectory() as temp_dir:
self.docsearch.save_local(temp_dir)
faiss_path = os.path.join(temp_dir, "index.faiss")
pkl_path = os.path.join(temp_dir, "index.pkl")
with open(faiss_path, "rb") as f_faiss:
faiss_data = f_faiss.read()
with open(pkl_path, "rb") as f_pkl:
pkl_data = f_pkl.read()
storage_path = get_vectorstore(self.source_id)
self.storage.save_file(io.BytesIO(faiss_data), f"{storage_path}/index.faiss")
self.storage.save_file(io.BytesIO(pkl_data), f"{storage_path}/index.pkl")
return True
def save_local(self, path=None):
if path:
os.makedirs(path, exist_ok=True)
self.docsearch.save_local(path)
self._save_to_storage()
return True
def save_local(self, *args, **kwargs):
return self.docsearch.save_local(*args, **kwargs)
def delete_index(self, *args, **kwargs):
return self.docsearch.delete(*args, **kwargs)
@@ -133,17 +103,13 @@ class FaissStore(BaseVectorStore):
return chunks
def add_chunk(self, text, metadata=None):
"""Add a new chunk and save to storage."""
metadata = metadata or {}
doc = Document(text=text, extra_info=metadata).to_langchain_format()
doc_id = self.docsearch.add_documents([doc])
self._save_to_storage()
self.save_local(self.path)
return doc_id
def delete_chunk(self, chunk_id):
"""Delete a chunk and save to storage."""
self.delete_index([chunk_id])
self._save_to_storage()
self.save_local(self.path)
return True

View File

@@ -1,303 +0,0 @@
import logging
from typing import List, Optional, Any, Dict
from application.core.settings import settings
from application.vectorstore.base import BaseVectorStore
from application.vectorstore.document_class import Document
class PGVectorStore(BaseVectorStore):
def __init__(
self,
source_id: str = "",
embeddings_key: str = "embeddings",
table_name: str = "documents",
vector_column: str = "embedding",
text_column: str = "text",
metadata_column: str = "metadata",
connection_string: str = None,
):
super().__init__()
# Store the source_id for use in add_chunk
self._source_id = str(source_id).replace("application/indexes/", "").rstrip("/")
self._embeddings_key = embeddings_key
self._table_name = table_name
self._vector_column = vector_column
self._text_column = text_column
self._metadata_column = metadata_column
self._embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)
# Use provided connection string or fall back to settings
self._connection_string = connection_string or getattr(settings, 'PGVECTOR_CONNECTION_STRING', None)
if not self._connection_string:
raise ValueError(
"PostgreSQL connection string is required. "
"Set PGVECTOR_CONNECTION_STRING in settings or pass connection_string parameter."
)
try:
import psycopg2
from psycopg2.extras import Json
import pgvector.psycopg2
except ImportError:
raise ImportError(
"Could not import required packages. "
"Please install with `pip install psycopg2-binary pgvector`."
)
self._psycopg2 = psycopg2
self._Json = Json
self._pgvector = pgvector.psycopg2
self._connection = None
self._ensure_table_exists()
def _get_connection(self):
"""Get or create database connection"""
if self._connection is None or self._connection.closed:
self._connection = self._psycopg2.connect(self._connection_string)
# Register pgvector types
self._pgvector.register_vector(self._connection)
return self._connection
def _ensure_table_exists(self):
"""Create table and enable pgvector extension if they don't exist"""
conn = self._get_connection()
cursor = conn.cursor()
try:
# Enable pgvector extension
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
# Get embedding dimension
embedding_dim = getattr(self._embedding, 'dimension', 1536) # Default to OpenAI dimension
# Create table with vector column
create_table_query = f"""
CREATE TABLE IF NOT EXISTS {self._table_name} (
id SERIAL PRIMARY KEY,
{self._text_column} TEXT NOT NULL,
{self._vector_column} vector({embedding_dim}),
{self._metadata_column} JSONB,
source_id TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
"""
cursor.execute(create_table_query)
# Create index for vector similarity search
index_query = f"""
CREATE INDEX IF NOT EXISTS {self._table_name}_{self._vector_column}_idx
ON {self._table_name} USING ivfflat ({self._vector_column} vector_cosine_ops)
WITH (lists = 100);
"""
cursor.execute(index_query)
# Create index for source_id filtering
source_index_query = f"""
CREATE INDEX IF NOT EXISTS {self._table_name}_source_id_idx
ON {self._table_name} (source_id);
"""
cursor.execute(source_index_query)
conn.commit()
except Exception as e:
conn.rollback()
logging.error(f"Error creating table: {e}")
raise
finally:
cursor.close()
def search(self, question: str, k: int = 2, *args, **kwargs) -> List[Document]:
"""Search for similar documents using vector similarity"""
query_vector = self._embedding.embed_query(question)
conn = self._get_connection()
cursor = conn.cursor()
try:
# Use cosine distance for similarity search with proper vector formatting
search_query = f"""
SELECT {self._text_column}, {self._metadata_column},
({self._vector_column} <=> %s::vector) as distance
FROM {self._table_name}
WHERE source_id = %s
ORDER BY {self._vector_column} <=> %s::vector
LIMIT %s;
"""
cursor.execute(search_query, (query_vector, self._source_id, query_vector, k))
results = cursor.fetchall()
documents = []
for text, metadata, distance in results:
metadata = metadata or {}
documents.append(Document(page_content=text, metadata=metadata))
return documents
except Exception as e:
logging.error(f"Error searching documents: {e}", exc_info=True)
return []
finally:
cursor.close()
def add_texts(
self,
texts: List[str],
metadatas: Optional[List[Dict[str, Any]]] = None,
*args,
**kwargs,
) -> List[str]:
"""Add texts with their embeddings to the vector store"""
if not texts:
return []
embeddings = self._embedding.embed_documents(texts)
metadatas = metadatas or [{}] * len(texts)
conn = self._get_connection()
cursor = conn.cursor()
try:
insert_query = f"""
INSERT INTO {self._table_name} ({self._text_column}, {self._vector_column}, {self._metadata_column}, source_id)
VALUES (%s, %s, %s, %s)
RETURNING id;
"""
inserted_ids = []
for text, embedding, metadata in zip(texts, embeddings, metadatas):
cursor.execute(
insert_query,
(text, embedding, self._Json(metadata), self._source_id)
)
inserted_id = cursor.fetchone()[0]
inserted_ids.append(str(inserted_id))
conn.commit()
return inserted_ids
except Exception as e:
conn.rollback()
logging.error(f"Error adding texts: {e}")
raise
finally:
cursor.close()
def delete_index(self, *args, **kwargs):
"""Delete all documents for this source_id"""
conn = self._get_connection()
cursor = conn.cursor()
try:
delete_query = f"DELETE FROM {self._table_name} WHERE source_id = %s;"
cursor.execute(delete_query, (self._source_id,))
conn.commit()
except Exception as e:
conn.rollback()
logging.error(f"Error deleting index: {e}")
raise
finally:
cursor.close()
def save_local(self, *args, **kwargs):
"""No-op for PostgreSQL - data is already persisted"""
pass
def get_chunks(self) -> List[Dict[str, Any]]:
"""Get all chunks for this source_id"""
conn = self._get_connection()
cursor = conn.cursor()
try:
select_query = f"""
SELECT id, {self._text_column}, {self._metadata_column}
FROM {self._table_name}
WHERE source_id = %s;
"""
cursor.execute(select_query, (self._source_id,))
results = cursor.fetchall()
chunks = []
for doc_id, text, metadata in results:
chunks.append({
"doc_id": str(doc_id),
"text": text,
"metadata": metadata or {}
})
return chunks
except Exception as e:
logging.error(f"Error getting chunks: {e}")
return []
finally:
cursor.close()
def add_chunk(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> str:
"""Add a single chunk to the vector store"""
metadata = metadata or {}
# Create a copy to avoid modifying the original metadata
final_metadata = metadata.copy()
# Ensure the source_id is in the metadata so the chunk can be found by filters
final_metadata["source_id"] = self._source_id
embeddings = self._embedding.embed_documents([text])
if not embeddings:
raise ValueError("Could not generate embedding for chunk")
conn = self._get_connection()
cursor = conn.cursor()
try:
insert_query = f"""
INSERT INTO {self._table_name} ({self._text_column}, {self._vector_column}, {self._metadata_column}, source_id)
VALUES (%s, %s, %s, %s)
RETURNING id;
"""
cursor.execute(
insert_query,
(text, embeddings[0], self._Json(final_metadata), self._source_id)
)
inserted_id = cursor.fetchone()[0]
conn.commit()
return str(inserted_id)
except Exception as e:
conn.rollback()
logging.error(f"Error adding chunk: {e}")
raise
finally:
cursor.close()
def delete_chunk(self, chunk_id: str) -> bool:
"""Delete a specific chunk by its ID"""
conn = self._get_connection()
cursor = conn.cursor()
try:
delete_query = f"DELETE FROM {self._table_name} WHERE id = %s AND source_id = %s;"
cursor.execute(delete_query, (int(chunk_id), self._source_id))
deleted_count = cursor.rowcount
conn.commit()
return deleted_count > 0
except Exception as e:
conn.rollback()
logging.error(f"Error deleting chunk: {e}")
return False
finally:
cursor.close()
def __del__(self):
"""Close database connection when object is destroyed"""
if hasattr(self, '_connection') and self._connection and not self._connection.closed:
self._connection.close()

View File

@@ -1,7 +1,5 @@
import logging
from application.vectorstore.base import BaseVectorStore
from application.core.settings import settings
from application.vectorstore.document_class import Document
class QdrantStore(BaseVectorStore):
@@ -9,22 +7,18 @@ class QdrantStore(BaseVectorStore):
from qdrant_client import models
from langchain_community.vectorstores.qdrant import Qdrant
# Store the source_id for use in add_chunk
self._source_id = str(source_id).replace("application/indexes/", "").rstrip("/")
self._filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.source_id",
match=models.MatchValue(value=self._source_id),
match=models.MatchValue(value=source_id.replace("application/indexes/", "").rstrip("/")),
)
]
)
embedding=self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key)
self._docsearch = Qdrant.construct_instance(
["TEXT_TO_OBTAIN_EMBEDDINGS_DIMENSION"],
embedding=embedding,
embedding=self._get_embeddings(settings.EMBEDDINGS_NAME, embeddings_key),
collection_name=settings.QDRANT_COLLECTION_NAME,
location=settings.QDRANT_LOCATION,
url=settings.QDRANT_URL,
@@ -38,32 +32,6 @@ class QdrantStore(BaseVectorStore):
path=settings.QDRANT_PATH,
distance_func=settings.QDRANT_DISTANCE_FUNC,
)
try:
collections = self._docsearch.client.get_collections()
collection_exists = settings.QDRANT_COLLECTION_NAME in [
collection.name for collection in collections.collections
]
if not collection_exists:
self._docsearch.client.recreate_collection(
collection_name=settings.QDRANT_COLLECTION_NAME,
vectors_config=models.VectorParams(size=embedding.client[1].word_embedding_dimension, distance=models.Distance.COSINE),
)
# Ensure the required index exists for metadata.source_id
try:
self._docsearch.client.create_payload_index(
collection_name=settings.QDRANT_COLLECTION_NAME,
field_name="metadata.source_id",
field_schema=models.PayloadSchemaType.KEYWORD,
)
except Exception as index_error:
# Index might already exist, which is fine
if "already exists" not in str(index_error).lower():
logging.warning(f"Could not create index for metadata.source_id: {index_error}")
except Exception as e:
logging.warning(f"Could not check for collection: {e}")
def search(self, *args, **kwargs):
return self._docsearch.similarity_search(filter=self._filter, *args, **kwargs)
@@ -78,59 +46,3 @@ class QdrantStore(BaseVectorStore):
return self._docsearch.client.delete(
collection_name=settings.QDRANT_COLLECTION_NAME, points_selector=self._filter
)
def get_chunks(self):
try:
chunks = []
offset = None
while True:
records, offset = self._docsearch.client.scroll(
collection_name=settings.QDRANT_COLLECTION_NAME,
scroll_filter=self._filter,
limit=10,
with_payload=True,
with_vectors=False,
offset=offset,
)
for record in records:
doc_id = record.id
text = record.payload.get("page_content")
metadata = record.payload.get("metadata")
chunks.append(
{"doc_id": doc_id, "text": text, "metadata": metadata}
)
if offset is None:
break
return chunks
except Exception as e:
logging.error(f"Error getting chunks: {e}", exc_info=True)
return []
def add_chunk(self, text, metadata=None):
import uuid
metadata = metadata or {}
# Create a copy to avoid modifying the original metadata
final_metadata = metadata.copy()
# Ensure the source_id is in the metadata so the chunk can be found by filters
final_metadata["source_id"] = self._source_id
doc = Document(page_content=text, metadata=final_metadata)
# Generate a unique ID for the document
doc_id = str(uuid.uuid4())
doc.id = doc_id
doc_ids = self._docsearch.add_documents([doc])
return doc_ids[0] if doc_ids else doc_id
def delete_chunk(self, chunk_id):
try:
self._docsearch.client.delete(
collection_name=settings.QDRANT_COLLECTION_NAME,
points_selector=[chunk_id],
)
return True
except Exception as e:
logging.error(f"Error deleting chunk: {e}", exc_info=True)
return False

View File

@@ -3,7 +3,6 @@ from application.vectorstore.elasticsearch import ElasticsearchStore
from application.vectorstore.milvus import MilvusStore
from application.vectorstore.mongodb import MongoDBVectorStore
from application.vectorstore.qdrant import QdrantStore
from application.vectorstore.pgvector import PGVectorStore
class VectorCreator:
@@ -13,7 +12,6 @@ class VectorCreator:
"mongodb": MongoDBVectorStore,
"qdrant": QdrantStore,
"milvus": MilvusStore,
"pgvector": PGVectorStore
}
@classmethod

View File

@@ -6,7 +6,6 @@ import os
import shutil
import string
import tempfile
from typing import Any, Dict
import zipfile
from collections import Counter
@@ -17,13 +16,11 @@ from bson.dbref import DBRef
from bson.objectid import ObjectId
from application.agents.agent_creator import AgentCreator
from application.api.answer.services.stream_processor import get_prompt
from application.api.answer.routes import get_prompt
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.parser.chunking import Chunker
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.embedding_pipeline import embed_and_store_documents
from application.parser.file.bulk import SimpleDirectoryReader
from application.parser.remote.remote_creator import RemoteCreator
@@ -38,22 +35,17 @@ db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
# Constants
MIN_TOKENS = 150
MAX_TOKENS = 1250
RECURSION_DEPTH = 2
# Define a function to extract metadata from a given filename.
def metadata_from_filename(title):
return {"title": title}
# Define a function to generate a random string of a given length.
def generate_random_string(length):
return "".join([string.ascii_letters[i % 52] for i in range(length)])
@@ -76,6 +68,7 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
if current_depth > max_depth:
logging.warning(f"Reached maximum recursion depth of {max_depth}")
return
try:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_to)
@@ -83,13 +76,12 @@ def extract_zip_recursive(zip_path, extract_to, current_depth=0, max_depth=5):
except Exception as e:
logging.error(f"Error extracting zip file {zip_path}: {e}", exc_info=True)
return
# Check for nested zip files and extract them
# Check for nested zip files and extract them
for root, dirs, files in os.walk(extract_to):
for file in files:
if file.endswith(".zip"):
# If a nested zip file is found, extract it recursively
file_path = os.path.join(root, file)
extract_zip_recursive(file_path, root, current_depth + 1, max_depth)
@@ -106,23 +98,11 @@ def download_file(url, params, dest_path):
def upload_index(full_path, file_data):
files = None
try:
if settings.VECTOR_STORE == "faiss":
faiss_path = full_path + "/index.faiss"
pkl_path = full_path + "/index.pkl"
if not os.path.exists(faiss_path):
logging.error(f"FAISS index file not found: {faiss_path}")
raise FileNotFoundError(f"FAISS index file not found: {faiss_path}")
if not os.path.exists(pkl_path):
logging.error(f"FAISS pickle file not found: {pkl_path}")
raise FileNotFoundError(f"FAISS pickle file not found: {pkl_path}")
files = {
"file_faiss": open(faiss_path, "rb"),
"file_pkl": open(pkl_path, "rb"),
"file_faiss": open(full_path + "/index.faiss", "rb"),
"file_pkl": open(full_path + "/index.pkl", "rb"),
}
response = requests.post(
urljoin(settings.API_URL, "/api/upload_index"),
@@ -134,11 +114,11 @@ def upload_index(full_path, file_data):
urljoin(settings.API_URL, "/api/upload_index"), data=file_data
)
response.raise_for_status()
except (requests.RequestException, FileNotFoundError) as e:
except requests.RequestException as e:
logging.error(f"Error uploading index: {e}")
raise
finally:
if settings.VECTOR_STORE == "faiss" and files is not None:
if settings.VECTOR_STORE == "faiss":
for file in files.values():
file.close()
@@ -159,7 +139,7 @@ def run_agent_logic(agent_config, input_data):
user_api_key = agent_config["key"]
agent_type = agent_config.get("agent_type", "classic")
decoded_token = {"sub": agent_config.get("user")}
prompt = get_prompt(prompt_id, db["prompts"])
prompt = get_prompt(prompt_id)
agent = AgentCreator.create_agent(
agent_type,
endpoint="webhook",
@@ -198,6 +178,7 @@ def run_agent_logic(agent_config, input_data):
tool_calls.extend(line["tool_calls"])
elif "thought" in line:
thought += line["thought"]
result = {
"answer": response_full,
"sources": source_log_docs,
@@ -212,10 +193,8 @@ def run_agent_logic(agent_config, input_data):
# Define the main function for ingesting and processing documents.
def ingest_worker(
self, directory, formats, job_name, file_path, filename, user, retriever="classic"
self, directory, formats, name_job, filename, user, retriever="classic"
):
"""
Ingest and process documents.
@@ -224,10 +203,9 @@ def ingest_worker(
self: Reference to the instance of the task.
directory (str): Specifies the directory for ingesting ('inputs' or 'temp').
formats (list of str): List of file extensions to consider for ingestion (e.g., [".rst", ".md"]).
job_name (str): Name of the job for this ingestion task (original, unsanitized).
file_path (str): Complete file path to use consistently throughout the pipeline.
filename (str): Original unsanitized filename provided by the user.
user (str): Identifier for the user initiating the ingestion (original, unsanitized).
name_job (str): Name of the job for this ingestion task.
filename (str): Name of the file to be ingested.
user (str): Identifier for the user initiating the ingestion.
retriever (str): Type of retriever to use for processing the documents.
Returns:
@@ -241,61 +219,35 @@ def ingest_worker(
storage = StorageCreator.get_storage()
logging.info(f"Ingest path: {file_path}", extra={"user": user, "job": job_name})
full_path = os.path.join(directory, user, name_job)
source_file_path = os.path.join(full_path, filename)
logging.info(f"Ingest file: {full_path}", extra={"user": user, "job": name_job})
# Create temporary working directory
with tempfile.TemporaryDirectory() as temp_dir:
try:
os.makedirs(temp_dir, exist_ok=True)
if storage.is_directory(file_path):
# Handle directory case
logging.info(f"Processing directory: {file_path}")
files_list = storage.list_files(file_path)
# Download file from storage to temp directory
temp_file_path = os.path.join(temp_dir, filename)
file_data = storage.get_file(source_file_path)
for storage_file_path in files_list:
if storage.is_directory(storage_file_path):
continue
# Create relative path structure in temp directory
rel_path = os.path.relpath(storage_file_path, file_path)
local_file_path = os.path.join(temp_dir, rel_path)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
# Download file
try:
file_data = storage.get_file(storage_file_path)
with open(local_file_path, "wb") as f:
f.write(file_data.read())
except Exception as e:
logging.error(
f"Error downloading file {storage_file_path}: {e}"
)
continue
else:
# Handle single file case
temp_filename = os.path.basename(file_path)
temp_file_path = os.path.join(temp_dir, temp_filename)
file_data = storage.get_file(file_path)
with open(temp_file_path, "wb") as f:
f.write(file_data.read())
# Handle zip files
if temp_filename.endswith(".zip"):
logging.info(f"Extracting zip file: {temp_filename}")
extract_zip_recursive(
temp_file_path,
temp_dir,
current_depth=0,
max_depth=RECURSION_DEPTH,
)
with open(temp_file_path, "wb") as f:
f.write(file_data.read())
self.update_state(state="PROGRESS", meta={"current": 1})
# Handle zip files
if filename.endswith(".zip"):
logging.info(f"Extracting zip file: {filename}")
extract_zip_recursive(
temp_file_path, temp_dir, current_depth=0, max_depth=RECURSION_DEPTH
)
if sample:
logging.info(f"Sample mode enabled. Using {limit} documents.")
reader = SimpleDirectoryReader(
input_dir=temp_dir,
input_files=input_files,
@@ -306,9 +258,6 @@ def ingest_worker(
)
raw_docs = reader.load_data()
directory_structure = getattr(reader, "directory_structure", {})
logging.info(f"Directory structure from reader: {directory_structure}")
chunker = Chunker(
chunking_strategy="classic_chunk",
max_tokens=MAX_TOKENS,
@@ -334,348 +283,31 @@ def ingest_worker(
for i in range(min(5, len(raw_docs))):
logging.info(f"Sample document {i}: {raw_docs[i]}")
file_data = {
"name": job_name,
"name": name_job,
"file": filename,
"user": user,
"tokens": tokens,
"retriever": retriever,
"id": str(id),
"type": "local",
"file_path": file_path,
"directory_structure": json.dumps(directory_structure),
}
upload_index(vector_store_path, file_data)
except Exception as e:
logging.error(f"Error in ingest_worker: {e}", exc_info=True)
raise
return {
"directory": directory,
"formats": formats,
"name_job": job_name, # Use original job_name
"name_job": name_job,
"filename": filename,
"user": user, # Use original user
"user": user,
"limited": False,
}
def reingest_source_worker(self, source_id, user):
"""
Re-ingestion worker that handles incremental updates by:
1. Adding chunks from newly added files
2. Removing chunks from deleted files
Args:
self: Task instance
source_id: ID of the source to re-ingest
user: User identifier
Returns:
dict: Information about the re-ingestion task
"""
try:
from application.vectorstore.vector_creator import VectorCreator
self.update_state(
state="PROGRESS",
meta={"current": 10, "status": "Initializing re-ingestion scan"},
)
source = sources_collection.find_one({"_id": ObjectId(source_id), "user": user})
if not source:
raise ValueError(f"Source {source_id} not found or access denied")
storage = StorageCreator.get_storage()
source_file_path = source.get("file_path", "")
self.update_state(
state="PROGRESS", meta={"current": 20, "status": "Scanning current files"}
)
with tempfile.TemporaryDirectory() as temp_dir:
# Download all files from storage to temp directory, preserving directory structure
if storage.is_directory(source_file_path):
files_list = storage.list_files(source_file_path)
for storage_file_path in files_list:
if storage.is_directory(storage_file_path):
continue
rel_path = os.path.relpath(storage_file_path, source_file_path)
local_file_path = os.path.join(temp_dir, rel_path)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
# Download file
try:
file_data = storage.get_file(storage_file_path)
with open(local_file_path, "wb") as f:
f.write(file_data.read())
except Exception as e:
logging.error(
f"Error downloading file {storage_file_path}: {e}"
)
continue
reader = SimpleDirectoryReader(
input_dir=temp_dir,
recursive=True,
required_exts=[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
exclude_hidden=True,
file_metadata=metadata_from_filename,
)
reader.load_data()
directory_structure = reader.directory_structure
logging.info(
f"Directory structure built with token counts: {directory_structure}"
)
try:
old_directory_structure = source.get("directory_structure") or {}
if isinstance(old_directory_structure, str):
try:
old_directory_structure = json.loads(old_directory_structure)
except Exception:
old_directory_structure = {}
def _flatten_directory_structure(struct, prefix=""):
files = set()
if isinstance(struct, dict):
for name, meta in struct.items():
current_path = (
os.path.join(prefix, name) if prefix else name
)
if isinstance(meta, dict) and (
"type" in meta and "size_bytes" in meta
):
files.add(current_path)
elif isinstance(meta, dict):
files |= _flatten_directory_structure(
meta, current_path
)
return files
old_files = _flatten_directory_structure(old_directory_structure)
new_files = _flatten_directory_structure(directory_structure)
added_files = sorted(new_files - old_files)
removed_files = sorted(old_files - new_files)
if added_files:
logging.info(f"Files added since last ingest: {added_files}")
else:
logging.info("No files added since last ingest.")
if removed_files:
logging.info(f"Files removed since last ingest: {removed_files}")
else:
logging.info("No files removed since last ingest.")
except Exception as e:
logging.error(
f"Error comparing directory structures: {e}", exc_info=True
)
added_files = []
removed_files = []
try:
if not added_files and not removed_files:
logging.info("No changes detected.")
return {
"source_id": source_id,
"user": user,
"status": "no_changes",
"added_files": [],
"removed_files": [],
}
vector_store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
source_id,
settings.EMBEDDINGS_KEY,
)
self.update_state(
state="PROGRESS",
meta={"current": 40, "status": "Processing file changes"},
)
# 1) Delete chunks from removed files
deleted = 0
if removed_files:
try:
for ch in vector_store.get_chunks() or []:
metadata = (
ch.get("metadata", {})
if isinstance(ch, dict)
else getattr(ch, "metadata", {})
)
raw_source = metadata.get("source")
source_file = str(raw_source) if raw_source else ""
if source_file in removed_files:
cid = ch.get("doc_id")
if cid:
try:
vector_store.delete_chunk(cid)
deleted += 1
except Exception as de:
logging.error(
f"Failed deleting chunk {cid}: {de}"
)
logging.info(
f"Deleted {deleted} chunks from {len(removed_files)} removed files"
)
except Exception as e:
logging.error(
f"Error during deletion of removed file chunks: {e}",
exc_info=True,
)
# 2) Add chunks from new files
added = 0
if added_files:
try:
# Build list of local files for added files only
added_local_files = []
for rel_path in added_files:
local_path = os.path.join(temp_dir, rel_path)
if os.path.isfile(local_path):
added_local_files.append(local_path)
if added_local_files:
reader_new = SimpleDirectoryReader(
input_files=added_local_files,
exclude_hidden=True,
errors="ignore",
file_metadata=metadata_from_filename,
)
raw_docs_new = reader_new.load_data()
chunker_new = Chunker(
chunking_strategy="classic_chunk",
max_tokens=MAX_TOKENS,
min_tokens=MIN_TOKENS,
duplicate_headers=False,
)
chunked_new = chunker_new.chunk(documents=raw_docs_new)
for (
file_path,
token_count,
) in reader_new.file_token_counts.items():
try:
rel_path = os.path.relpath(
file_path, start=temp_dir
)
path_parts = rel_path.split(os.sep)
current_dir = directory_structure
for part in path_parts[:-1]:
if part in current_dir and isinstance(
current_dir[part], dict
):
current_dir = current_dir[part]
else:
break
filename = path_parts[-1]
if filename in current_dir and isinstance(
current_dir[filename], dict
):
current_dir[filename][
"token_count"
] = token_count
logging.info(
f"Updated token count for {rel_path}: {token_count}"
)
except Exception as e:
logging.warning(
f"Could not update token count for {file_path}: {e}"
)
for d in chunked_new:
meta = dict(d.extra_info or {})
try:
raw_src = meta.get("source")
if isinstance(raw_src, str) and os.path.isabs(
raw_src
):
meta["source"] = os.path.relpath(
raw_src, start=temp_dir
)
except Exception:
pass
vector_store.add_chunk(d.text, metadata=meta)
added += 1
logging.info(
f"Added {added} chunks from {len(added_files)} new files"
)
except Exception as e:
logging.error(
f"Error during ingestion of new files: {e}", exc_info=True
)
# 3) Update source directory structure timestamp
try:
total_tokens = sum(reader.file_token_counts.values())
sources_collection.update_one(
{"_id": ObjectId(source_id)},
{
"$set": {
"directory_structure": directory_structure,
"date": datetime.datetime.now(),
"tokens": total_tokens,
}
},
)
except Exception as e:
logging.error(
f"Error updating directory_structure in DB: {e}", exc_info=True
)
self.update_state(
state="PROGRESS",
meta={"current": 100, "status": "Re-ingestion completed"},
)
return {
"source_id": source_id,
"user": user,
"status": "completed",
"added_files": added_files,
"removed_files": removed_files,
"chunks_added": added,
"chunks_deleted": deleted,
}
except Exception as e:
logging.error(
f"Error while processing file changes: {e}", exc_info=True
)
raise
except Exception as e:
logging.error(f"Error in reingest_source_worker: {e}", exc_info=True)
raise
def remote_worker(
self,
source_data,
@@ -691,6 +323,7 @@ def remote_worker(
full_path = os.path.join(directory, user, name_job)
if not os.path.exists(full_path):
os.makedirs(full_path)
self.update_state(state="PROGRESS", meta={"current": 1})
try:
logging.info("Initializing remote loader with type: %s", loader)
@@ -717,6 +350,7 @@ def remote_worker(
raise ValueError("doc_id must be provided for sync operation.")
id = ObjectId(doc_id)
embed_and_store_documents(docs, full_path, id, self)
self.update_state(state="PROGRESS", meta={"current": 100})
file_data = {
@@ -729,16 +363,16 @@ def remote_worker(
"remote_data": source_data,
"sync_frequency": sync_frequency,
}
if operation_mode == "sync":
file_data["last_sync"] = datetime.datetime.now()
upload_index(full_path, file_data)
except Exception as e:
logging.error("Error in remote_worker task: %s", str(e), exc_info=True)
raise
finally:
if os.path.exists(full_path):
shutil.rmtree(full_path)
logging.info("remote_worker task completed successfully")
return {"urls": source_data, "name_job": name_job, "user": user, "limited": False}
@@ -791,6 +425,7 @@ def sync_worker(self, frequency):
sync_counts[
"sync_success" if resp["status"] == "success" else "sync_failure"
] += 1
return {
key: sync_counts[key]
for key in ["total_sync_count", "sync_success", "sync_failure"]
@@ -829,9 +464,6 @@ def attachment_worker(self, file_info, user):
)
token_count = num_tokens_from_string(content)
if token_count > 100000:
content = content[:250000]
token_count = num_tokens_from_string(content)
self.update_state(
state="PROGRESS", meta={"current": 80, "status": "Storing in database"}
@@ -868,6 +500,7 @@ def attachment_worker(self, file_info, user):
"mime_type": mime_type,
"metadata": metadata,
}
except Exception as e:
logging.error(
f"Error processing file {filename}: {e}",
@@ -903,6 +536,7 @@ def agent_webhook_worker(self, agent_id, payload):
except Exception as e:
logging.error(f"Error processing agent webhook: {e}", exc_info=True)
return {"status": "error", "error": str(e)}
self.update_state(state="PROGRESS", meta={"current": 50})
try:
result = run_agent_logic(agent_config, input_data)
@@ -915,334 +549,3 @@ def agent_webhook_worker(self, agent_id, payload):
f"Webhook processed for agent {agent_id}", extra={"agent_id": agent_id}
)
return {"status": "success", "result": result}
def ingest_connector(
self,
job_name: str,
user: str,
source_type: str,
session_token=None,
file_ids=None,
folder_ids=None,
recursive=True,
retriever: str = "classic",
operation_mode: str = "upload",
doc_id=None,
sync_frequency: str = "never",
) -> Dict[str, Any]:
"""
Ingestion for internal knowledge bases (GoogleDrive, etc.).
Args:
job_name: Name of the ingestion job
user: User identifier
source_type: Type of remote source ("google_drive", "dropbox", etc.)
session_token: Authentication token for the service
file_ids: List of file IDs to download
folder_ids: List of folder IDs to download
recursive: Whether to recursively download folders
retriever: Type of retriever to use
operation_mode: "upload" for initial ingestion, "sync" for incremental sync
doc_id: Document ID for sync operations (required when operation_mode="sync")
sync_frequency: How often to sync ("never", "daily", "weekly", "monthly")
"""
logging.info(
f"Starting remote ingestion from {source_type} for user: {user}, job: {job_name}"
)
self.update_state(state="PROGRESS", meta={"current": 1})
with tempfile.TemporaryDirectory() as temp_dir:
try:
# Step 1: Initialize the appropriate loader
self.update_state(
state="PROGRESS",
meta={"current": 10, "status": "Initializing connector"},
)
if not session_token:
raise ValueError(f"{source_type} connector requires session_token")
if not ConnectorCreator.is_supported(source_type):
raise ValueError(
f"Unsupported connector type: {source_type}. Supported types: {ConnectorCreator.get_supported_connectors()}"
)
remote_loader = ConnectorCreator.create_connector(
source_type, session_token
)
# Create a clean config for storage
api_source_config = {
"file_ids": file_ids or [],
"folder_ids": folder_ids or [],
"recursive": recursive,
}
# Step 2: Download files to temp directory
self.update_state(
state="PROGRESS", meta={"current": 20, "status": "Downloading files"}
)
download_info = remote_loader.download_to_directory(
temp_dir, api_source_config
)
if download_info.get("empty_result", False) or not download_info.get(
"files_downloaded", 0
):
logging.warning(f"No files were downloaded from {source_type}")
# Create empty result directly instead of calling a separate method
return {
"name": job_name,
"user": user,
"tokens": 0,
"type": source_type,
"source_config": api_source_config,
"directory_structure": "{}",
}
# Step 3: Use SimpleDirectoryReader to process downloaded files
self.update_state(
state="PROGRESS", meta={"current": 40, "status": "Processing files"}
)
reader = SimpleDirectoryReader(
input_dir=temp_dir,
recursive=True,
required_exts=[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
exclude_hidden=True,
file_metadata=metadata_from_filename,
)
raw_docs = reader.load_data()
directory_structure = getattr(reader, "directory_structure", {})
# Step 4: Process documents (chunking, embedding, etc.)
self.update_state(
state="PROGRESS", meta={"current": 60, "status": "Processing documents"}
)
chunker = Chunker(
chunking_strategy="classic_chunk",
max_tokens=MAX_TOKENS,
min_tokens=MIN_TOKENS,
duplicate_headers=False,
)
raw_docs = chunker.chunk(documents=raw_docs)
# Preserve source information in document metadata
for doc in raw_docs:
if hasattr(doc, "extra_info") and doc.extra_info:
source = doc.extra_info.get("source")
if source and os.path.isabs(source):
# Convert absolute path to relative path
doc.extra_info["source"] = os.path.relpath(
source, start=temp_dir
)
docs = [Document.to_langchain_format(raw_doc) for raw_doc in raw_docs]
if operation_mode == "upload":
id = ObjectId()
elif operation_mode == "sync":
if not doc_id or not ObjectId.is_valid(doc_id):
logging.error(
"Invalid doc_id provided for sync operation: %s", doc_id
)
raise ValueError("doc_id must be provided for sync operation.")
id = ObjectId(doc_id)
else:
raise ValueError(f"Invalid operation_mode: {operation_mode}")
vector_store_path = os.path.join(temp_dir, "vector_store")
os.makedirs(vector_store_path, exist_ok=True)
self.update_state(
state="PROGRESS", meta={"current": 80, "status": "Storing documents"}
)
embed_and_store_documents(docs, vector_store_path, id, self)
tokens = count_tokens_docs(docs)
# Step 6: Upload index files
file_data = {
"user": user,
"name": job_name,
"tokens": tokens,
"retriever": retriever,
"id": str(id),
"type": "connector:file",
"remote_data": json.dumps(
{"provider": source_type, **api_source_config}
),
"directory_structure": json.dumps(directory_structure),
"sync_frequency": sync_frequency,
}
if operation_mode == "sync":
file_data["last_sync"] = datetime.datetime.now()
else:
file_data["last_sync"] = datetime.datetime.now()
upload_index(vector_store_path, file_data)
# Ensure we mark the task as complete
self.update_state(
state="PROGRESS", meta={"current": 100, "status": "Complete"}
)
logging.info(f"Remote ingestion completed: {job_name}")
return {
"user": user,
"name": job_name,
"tokens": tokens,
"type": source_type,
"id": str(id),
"status": "complete",
}
except Exception as e:
logging.error(f"Error during remote ingestion: {e}", exc_info=True)
raise
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
"""Worker to handle MCP OAuth flow asynchronously."""
logging.info(
"[MCP OAuth] Worker started for user_id=%s, config=%s", user_id, config
)
try:
import asyncio
from application.agents.tools.mcp_tool import MCPTool
task_id = self.request.id
logging.info("[MCP OAuth] Task ID: %s", task_id)
redis_client = get_redis_instance()
def update_status(status_data: Dict[str, Any]):
logging.info("[MCP OAuth] Updating status: %s", status_data)
status_key = f"mcp_oauth_status:{task_id}"
redis_client.setex(status_key, 600, json.dumps(status_data))
update_status(
{
"status": "in_progress",
"message": "Starting OAuth flow...",
"task_id": task_id,
}
)
tool_config = config.copy()
tool_config["oauth_task_id"] = task_id
logging.info("[MCP OAuth] Initializing MCPTool with config: %s", tool_config)
mcp_tool = MCPTool(tool_config, user_id)
async def run_oauth_discovery():
if not mcp_tool._client:
mcp_tool._setup_client()
return await mcp_tool._execute_with_client("list_tools")
update_status(
{
"status": "awaiting_redirect",
"message": "Waiting for OAuth redirect...",
"task_id": task_id,
}
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
logging.info("[MCP OAuth] Starting event loop for OAuth discovery...")
tools_response = loop.run_until_complete(run_oauth_discovery())
logging.info(
"[MCP OAuth] Tools response after async call: %s", tools_response
)
status_key = f"mcp_oauth_status:{task_id}"
redis_status = redis_client.get(status_key)
if redis_status:
logging.info(
"[MCP OAuth] Redis status after async call: %s", redis_status
)
else:
logging.warning(
"[MCP OAuth] No Redis status found after async call for key: %s",
status_key,
)
tools = mcp_tool.get_actions_metadata()
update_status(
{
"status": "completed",
"message": f"OAuth completed successfully. Found {len(tools)} tools.",
"tools": tools,
"tools_count": len(tools),
"task_id": task_id,
}
)
logging.info(
"[MCP OAuth] OAuth flow completed successfully for task_id=%s", task_id
)
return {"success": True, "tools": tools, "tools_count": len(tools)}
except Exception as e:
error_msg = f"OAuth flow failed: {str(e)}"
logging.error(
"[MCP OAuth] Exception in OAuth discovery: %s", error_msg, exc_info=True
)
update_status(
{
"status": "error",
"message": error_msg,
"error": str(e),
"task_id": task_id,
}
)
return {"success": False, "error": error_msg}
finally:
logging.info("[MCP OAuth] Closing event loop for task_id=%s", task_id)
loop.close()
except Exception as e:
error_msg = f"Failed to initialize OAuth flow: {str(e)}"
logging.error(
"[MCP OAuth] Exception during initialization: %s", error_msg, exc_info=True
)
update_status(
{
"status": "error",
"message": error_msg,
"error": str(e),
"task_id": task_id,
}
)
return {"success": False, "error": error_msg}
def mcp_oauth_status(self, task_id: str) -> Dict[str, Any]:
"""Check the status of an MCP OAuth flow."""
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
status_data = redis_client.get(status_key)
if status_data:
return json.loads(status_data)
return {"status": "not_found", "message": "Status not found"}

View File

@@ -1,75 +0,0 @@
name: docsgpt-oss
services:
frontend:
image: arc53/docsgpt-fe:develop
environment:
- VITE_API_HOST=http://localhost:7091
- VITE_API_STREAMING=$VITE_API_STREAMING
- VITE_GOOGLE_CLIENT_ID=$VITE_GOOGLE_CLIENT_ID
ports:
- "5173:5173"
depends_on:
- backend
backend:
user: root
image: arc53/docsgpt:develop
environment:
- API_KEY=$API_KEY
- EMBEDDINGS_KEY=$API_KEY
- LLM_PROVIDER=$LLM_PROVIDER
- LLM_NAME=$LLM_NAME
- CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/1
- MONGO_URI=mongodb://mongo:27017/docsgpt
- CACHE_REDIS_URL=redis://redis:6379/2
- OPENAI_BASE_URL=$OPENAI_BASE_URL
ports:
- "7091:7091"
volumes:
- ../application/indexes:/app/indexes
- ../application/inputs:/app/inputs
- ../application/vectors:/app/vectors
depends_on:
- redis
- mongo
worker:
user: root
image: arc53/docsgpt:develop
command: celery -A application.app.celery worker -l INFO -B
environment:
- API_KEY=$API_KEY
- EMBEDDINGS_KEY=$API_KEY
- LLM_PROVIDER=$LLM_PROVIDER
- LLM_NAME=$LLM_NAME
- CELERY_BROKER_URL=redis://redis:6379/0
- CELERY_RESULT_BACKEND=redis://redis:6379/1
- MONGO_URI=mongodb://mongo:27017/docsgpt
- API_URL=http://backend:7091
- CACHE_REDIS_URL=redis://redis:6379/2
volumes:
- ../application/indexes:/app/indexes
- ../application/inputs:/app/inputs
- ../application/vectors:/app/vectors
depends_on:
- redis
- mongo
redis:
image: redis:6-alpine
ports:
- 6379:6379
mongo:
image: mongo:6
ports:
- 27017:27017
volumes:
- mongodb_data_container:/data/db
volumes:
mongodb_data_container:

View File

@@ -7,7 +7,6 @@ services:
environment:
- VITE_API_HOST=http://localhost:7091
- VITE_API_STREAMING=$VITE_API_STREAMING
- VITE_GOOGLE_CLIENT_ID=$VITE_GOOGLE_CLIENT_ID
ports:
- "5173:5173"
depends_on:
@@ -29,9 +28,9 @@ services:
ports:
- "7091:7091"
volumes:
- ../application/indexes:/app/indexes
- ../application/indexes:/app/application/indexes
- ../application/inputs:/app/inputs
- ../application/vectors:/app/vectors
- ../application/vectors:/app/application/vectors
depends_on:
- redis
- mongo
@@ -51,9 +50,9 @@ services:
- API_URL=http://backend:7091
- CACHE_REDIS_URL=redis://redis:6379/2
volumes:
- ../application/indexes:/app/indexes
- ../application/indexes:/app/application/indexes
- ../application/inputs:/app/inputs
- ../application/vectors:/app/vectors
- ../application/vectors:/app/application/vectors
depends_on:
- redis
- mongo

5417
docs/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -9,8 +9,8 @@
"@vercel/analytics": "^1.1.1",
"docsgpt-react": "^0.5.1",
"next": "^15.3.3",
"nextra": "^2.13.2",
"nextra-theme-docs": "^2.13.2",
"nextra": "^4.2.17",
"nextra-theme-docs": "^4.2.17",
"react": "^18.2.0",
"react-dom": "^18.2.0"
}

View File

@@ -2,13 +2,5 @@
"basics": {
"title": "🤖 Agent Basics",
"href": "/Agents/basics"
},
"api": {
"title": "🔌 Agent API",
"href": "/Agents/api"
},
"webhooks": {
"title": "🪝 Agent Webhooks",
"href": "/Agents/webhooks"
}
}
}

View File

@@ -1,227 +0,0 @@
---
title: Interacting with Agents via API
description: Learn how to programmatically interact with DocsGPT Agents using the streaming and non-streaming API endpoints.
---
import { Callout, Tabs } from 'nextra/components';
# Interacting with Agents via API
DocsGPT Agents can be accessed programmatically through a dedicated API, allowing you to integrate their specialized capabilities into your own applications, scripts, and workflows. This guide covers the two primary methods for interacting with an agent: the streaming API for real-time responses and the non-streaming API for a single, consolidated answer.
When you use an API key generated for a specific agent, you do not need to pass `prompt`, `tools` etc. The agent's configuration (including its prompt, selected tools, and knowledge sources) is already associated with its unique API key.
### API Endpoints
- **Non-Streaming:** `http://localhost:7091/api/answer`
- **Streaming:** `http://localhost:7091/stream`
<Callout type="info">
For DocsGPT Cloud, use `https://gptcloud.arc53.com/` as the base URL.
</Callout>
For more technical details, you can explore the API swagger documentation available for the cloud version or your local instance.
---
## Non-Streaming API (`/api/answer`)
This is a standard synchronous endpoint. It waits for the agent to fully process the request and returns a single JSON object with the complete answer. This is the simplest method and is ideal for backend processes where a real-time feed is not required.
### Request
- **Endpoint:** `/api/answer`
- **Method:** `POST`
- **Payload:**
- `question` (string, required): The user's query or input for the agent.
- `api_key` (string, required): The unique API key for the agent you wish to interact with.
- `history` (string, optional): A JSON string representing the conversation history, e.g., `[{\"prompt\": \"first question\", \"answer\": \"first answer\"}]`.
### Response
A single JSON object containing:
- `answer`: The complete, final answer from the agent.
- `sources`: A list of sources the agent consulted.
- `conversation_id`: The unique ID for the interaction.
### Examples
<Tabs items={['cURL', 'Python', 'JavaScript']}>
<Tabs.Tab>
```bash
curl -X POST http://localhost:7091/api/answer \
-H "Content-Type: application/json" \
-d '{
"question": "your question here",
"api_key": "your_agent_api_key"
}'
```
</Tabs.Tab>
<Tabs.Tab>
```python
import requests
API_URL = "http://localhost:7091/api/answer"
API_KEY = "your_agent_api_key"
QUESTION = "your question here"
response = requests.post(
API_URL,
json={"question": QUESTION, "api_key": API_KEY}
)
if response.status_code == 200:
print(response.json())
else:
print(f"Error: {response.status_code}")
print(response.text)
```
</Tabs.Tab>
<Tabs.Tab>
```javascript
const apiUrl = 'http://localhost:7091/api/answer';
const apiKey = 'your_agent_api_key';
const question = 'your question here';
async function getAnswer() {
try {
const response = await fetch(apiUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ question, api_key: apiKey }),
});
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const data = await response.json();
console.log(data);
} catch (error) {
console.error("Failed to fetch answer:", error);
}
}
getAnswer();
```
</Tabs.Tab>
</Tabs>
---
## Streaming API (`/stream`)
The `/stream` endpoint uses Server-Sent Events (SSE) to push data in real-time. This is ideal for applications where you want to display the response as it's being generated, such as in a live chatbot interface.
### Request
- **Endpoint:** `/stream`
- **Method:** `POST`
- **Payload:** Same as the non-streaming API.
### Response (SSE Stream)
The stream consists of multiple `data:` events, each containing a JSON object. Your client should listen for these events and process them based on their `type`.
**Event Types:**
- `answer`: A chunk of the agent's final answer.
- `source`: A document or source used by the agent.
- `thought`: A reasoning step from the agent (for ReAct agents).
- `id`: The unique `conversation_id` for the interaction.
- `error`: An error message.
- `end`: A final message indicating the stream has concluded.
### Examples
<Tabs items={['cURL', 'Python', 'JavaScript']}>
<Tabs.Tab>
```bash
curl -X POST http://localhost:7091/stream \
-H "Content-Type: application/json" \
-H "Accept: text/event-stream" \
-d '{
"question": "your question here",
"api_key": "your_agent_api_key"
}'
```
</Tabs.Tab>
<Tabs.Tab>
```python
import requests
import json
API_URL = "http://localhost:7091/stream"
payload = {
"question": "your question here",
"api_key": "your_agent_api_key"
}
with requests.post(API_URL, json=payload, stream=True) as r:
for line in r.iter_lines():
if line:
decoded_line = line.decode('utf-8')
if decoded_line.startswith('data: '):
try:
data = json.loads(decoded_line[6:])
print(data)
except json.JSONDecodeError:
pass
```
</Tabs.Tab>
<Tabs.Tab>
```javascript
const apiUrl = 'http://localhost:7091/stream';
const apiKey = 'your_agent_api_key';
const question = 'your question here';
async function getStream() {
try {
const response = await fetch(apiUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': 'text/event-stream'
},
// Corrected line: 'apiKey' is changed to 'api_key'
body: JSON.stringify({ question, api_key: apiKey }),
});
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value, { stream: true });
// Note: This parsing method assumes each chunk contains whole lines.
// For a more robust production implementation, buffer the chunks
// and process them line by line.
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
try {
const data = JSON.parse(line.substring(6));
console.log(data);
} catch (e) {
console.error("Failed to parse JSON from SSE event:", e);
}
}
}
}
} catch (error) {
console.error("Failed to fetch stream:", error);
}
}
getStream();
```
</Tabs.Tab>
</Tabs>

View File

@@ -1,152 +0,0 @@
---
title: Triggering Agents with Webhooks
description: Learn how to automate and integrate DocsGPT Agents using webhooks for asynchronous task execution.
---
import { Callout, Tabs } from 'nextra/components';
# Triggering Agents with Webhooks
Agent Webhooks provide a powerful mechanism to trigger an agent's execution from external systems. Unlike the direct API which provides an immediate response, webhooks are designed for **asynchronous** operations. When you call a webhook, DocsGPT enqueues the agent's task for background processing and immediately returns a `task_id`. You then use this ID to poll for the result.
This workflow is ideal for integrating with services that expect a quick initial response (e.g., form submissions) or for triggering long-running tasks without tying up a client connection.
Each agent has its own unique webhook URL, which can be generated from the agent's edit page in the DocsGPT UI. This URL includes a secure token for authentication.
### API Endpoints
- **Webhook URL:** `http://localhost:7091/api/webhooks/agents/{AGENT_WEBHOOK_TOKEN}`
- **Task Status URL:** `http://localhost:7091/api/task_status`
<Callout type="info">
For DocsGPT Cloud, use `https://gptcloud.arc53.com/` as the base URL.
</Callout>
For more technical details, you can explore the API swagger documentation available for the cloud version or your local instance.
---
## The Webhook Workflow
The process involves two main steps: triggering the task and polling for the result.
### Step 1: Trigger the Webhook
Send an HTTP `POST` request to the agent's unique webhook URL with the required payload. The structure of this payload should match what the agent's prompt and tools are designed to handle.
- **Method:** `POST`
- **Response:** A JSON object with a `task_id`. `{"task_id": "a1b2c3d4-e5f6-..."}`
<Tabs items={['cURL', 'Python', 'JavaScript']}>
<Tabs.Tab>
```bash
curl -X POST \
http://localhost:7091/api/webhooks/agents/your_webhook_token \
-H "Content-Type: application/json" \
-d '{"question": "Your message to agent"}'
```
</Tabs.Tab>
<Tabs.Tab>
```python
import requests
WEBHOOK_URL = "http://localhost:7091/api/webhooks/agents/your_webhook_token"
payload = {"question": "Your message to agent"}
try:
response = requests.post(WEBHOOK_URL, json=payload)
response.raise_for_status()
task_id = response.json().get("task_id")
print(f"Task successfully created with ID: {task_id}")
except requests.exceptions.RequestException as e:
print(f"Error triggering webhook: {e}")
```
</Tabs.Tab>
<Tabs.Tab>
```javascript
const webhookUrl = 'http://localhost:7091/api/webhooks/agents/your_webhook_token';
const payload = { question: 'Your message to agent' };
async function triggerWebhook() {
try {
const response = await fetch(webhookUrl, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload)
});
if (!response.ok) throw new Error(`HTTP error! ${response.status}`);
const data = await response.json();
console.log(`Task successfully created with ID: ${data.task_id}`);
return data.task_id;
} catch (error) {
console.error('Error triggering webhook:', error);
}
}
triggerWebhook();
```
</Tabs.Tab>
</Tabs>
### Step 2: Poll for the Result
Once you have the `task_id`, periodically send a `GET` request to the `/api/task_status` endpoint until the task `status` is `SUCCESS` or `FAILURE`.
- **`status`**: The current state of the task (`PENDING`, `STARTED`, `SUCCESS`, `FAILURE`).
- **`result`**: The final output from the agent, available when the status is `SUCCESS` or `FAILURE`.
<Tabs items={['cURL', 'Python', 'JavaScript']}>
<Tabs.Tab>
```bash
# Replace the task_id with the one you received
curl http://localhost:7091/api/task_status?task_id=YOUR_TASK_ID
```
</Tabs.Tab>
<Tabs.Tab>
```python
import requests
import time
STATUS_URL = "http://localhost:7091/api/task_status"
task_id = "YOUR_TASK_ID"
while True:
response = requests.get(STATUS_URL, params={"task_id": task_id})
data = response.json()
status = data.get("status")
print(f"Current task status: {status}")
if status in ["SUCCESS", "FAILURE"]:
print("Final Result:")
print(data.get("result"))
break
time.sleep(2)
```
</Tabs.Tab>
<Tabs.Tab>
```javascript
const statusUrl = 'http://localhost:7091/api/task_status';
const taskId = 'YOUR_TASK_ID';
const sleep = (ms) => new Promise(resolve => setTimeout(resolve, ms));
async function pollForResult() {
while (true) {
const response = await fetch(`${statusUrl}?task_id=${taskId}`);
const data = await response.json();
const status = data.status;
console.log(`Current task status: ${status}`);
if (status === 'SUCCESS' || status === 'FAILURE') {
console.log('Final Result:', data.result);
break;
}
await sleep(2000);
}
}
pollForResult();
```
</Tabs.Tab>
</Tabs>

View File

@@ -37,33 +37,33 @@ While modifying `settings.py` offers more flexibility, it's generally recommende
Here are some of the most fundamental settings you'll likely want to configure:
- **`LLM_PROVIDER`**: This setting determines which Large Language Model (LLM) provider DocsGPT will use. It tells DocsGPT which API to interact with.
- **`LLM_PROVIDER`**: This setting determines which Large Language Model (LLM) provider DocsGPT will use. It tells DocsGPT which API to interact with.
- **Common values:**
- `docsgpt`: Use the DocsGPT Public API Endpoint (simple and free, as offered in `setup.sh` option 1).
- `openai`: Use OpenAI's API (requires an API key).
- `google`: Use Google's Vertex AI or Gemini models.
- `anthropic`: Use Anthropic's Claude models.
- `groq`: Use Groq's models.
- `huggingface`: Use HuggingFace Inference API.
- `azure_openai`: Use Azure OpenAI Service.
- `openai` (when using local inference engines like Ollama, Llama.cpp, TGI, etc.): This signals DocsGPT to use an OpenAI-compatible API format, even if the actual LLM is running locally.
- **Common values:**
- `docsgpt`: Use the DocsGPT Public API Endpoint (simple and free, as offered in `setup.sh` option 1).
- `openai`: Use OpenAI's API (requires an API key).
- `google`: Use Google's Vertex AI or Gemini models.
- `anthropic`: Use Anthropic's Claude models.
- `groq`: Use Groq's models.
- `huggingface`: Use HuggingFace Inference API.
- `azure_openai`: Use Azure OpenAI Service.
- `openai` (when using local inference engines like Ollama, Llama.cpp, TGI, etc.): This signals DocsGPT to use an OpenAI-compatible API format, even if the actual LLM is running locally.
- **`LLM_NAME`**: Specifies the specific model to use from the chosen LLM provider. The available models depend on the `LLM_PROVIDER` you've selected.
- **`LLM_NAME`**: Specifies the specific model to use from the chosen LLM provider. The available models depend on the `LLM_PROVIDER` you've selected.
- **Examples:**
- For `LLM_PROVIDER=openai`: `gpt-4o`
- For `LLM_PROVIDER=google`: `gemini-2.0-flash`
- For local models (e.g., Ollama): `llama3.2:1b` (or any model name available in your setup).
- **Examples:**
- For `LLM_PROVIDER=openai`: `gpt-4o`
- For `LLM_PROVIDER=google`: `gemini-2.0-flash`
- For local models (e.g., Ollama): `llama3.2:1b` (or any model name available in your setup).
- **`EMBEDDINGS_NAME`**: This setting defines which embedding model DocsGPT will use to generate vector embeddings for your documents. Embeddings are numerical representations of text that allow DocsGPT to understand the semantic meaning of your documents for efficient search and retrieval.
- **`EMBEDDINGS_NAME`**: This setting defines which embedding model DocsGPT will use to generate vector embeddings for your documents. Embeddings are numerical representations of text that allow DocsGPT to understand the semantic meaning of your documents for efficient search and retrieval.
- **Default value:** `huggingface_sentence-transformers/all-mpnet-base-v2` (a good general-purpose embedding model).
- **Other options:** You can explore other embedding models from Hugging Face Sentence Transformers or other providers if needed.
- **Default value:** `huggingface_sentence-transformers/all-mpnet-base-v2` (a good general-purpose embedding model).
- **Other options:** You can explore other embedding models from Hugging Face Sentence Transformers or other providers if needed.
- **`API_KEY`**: Required for most cloud-based LLM providers. This is your authentication key to access the LLM provider's API. You'll need to obtain this key from your chosen provider's platform.
- **`API_KEY`**: Required for most cloud-based LLM providers. This is your authentication key to access the LLM provider's API. You'll need to obtain this key from your chosen provider's platform.
- **`OPENAI_BASE_URL`**: Specifically used when `LLM_PROVIDER` is set to `openai` but you are connecting to a local inference engine (like Ollama, Llama.cpp, etc.) that exposes an OpenAI-compatible API. This setting tells DocsGPT where to find your local LLM server.
- **`OPENAI_BASE_URL`**: Specifically used when `LLM_PROVIDER` is set to `openai` but you are connecting to a local inference engine (like Ollama, Llama.cpp, etc.) that exposes an OpenAI-compatible API. This setting tells DocsGPT where to find your local LLM server.
## Configuration Examples
@@ -93,82 +93,51 @@ OPENAI_BASE_URL=http://host.docker.internal:11434/v1 # Default Ollama API URL wi
EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2 # You can also run embeddings locally if needed
```
In this case, even though you are using Ollama locally, `LLM_PROVIDER` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server.
In this case, even though you are using Ollama locally, `LLM_PROVIDER` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server.
## Authentication Settings
DocsGPT includes a JWT (JSON Web Token) based authentication feature for managing sessions or securing local deployments while allowing access.
### `AUTH_TYPE` Overview
- **`AUTH_TYPE`**: This setting in your `.env` file or `settings.py` determines the authentication method.
- **Possible values:**
- `None` (or not set): No authentication is used.
- `simple_jwt`: A single, long-lived JWT token is generated and used for all authenticated requests. This is useful for securing a local deployment with a shared secret.
- `session_jwt`: Unique JWT tokens are generated for sessions, typically for individual users or temporary access.
- If `AUTH_TYPE` is set to `simple_jwt` or `session_jwt`, then a `JWT_SECRET_KEY` is required.
- **`JWT_SECRET_KEY`**: This is a crucial secret key used to sign and verify JWTs.
- It can be set directly in your `.env` file or `settings.py`.
- **Automatic Key Generation**: If `AUTH_TYPE` is `simple_jwt` or `session_jwt` and `JWT_SECRET_KEY` is _not_ set in your environment variables or `settings.py`, DocsGPT will attempt to:
1. Read the key from a file named `.jwt_secret_key` in the project's root directory.
2. If the file doesn't exist, it will generate a new 32-byte random key, save it to `.jwt_secret_key`, and use it for the session. This ensures that the key persists across application restarts.
- **Security Note**: It's vital to keep this key secure. If you set it manually, choose a strong, random string.
The `AUTH_TYPE` setting in your `.env` file or `settings.py` determines the authentication method used by DocsGPT. This allows you to control how users authenticate with your DocsGPT instance.
**How it works:**
| Value | Description |
| ------------- | ------------------------------------------------------------------------------------------- |
| `None` | No authentication is used. Anyone can access the app. |
| `simple_jwt` | A single, long-lived JWT token is generated at startup. All requests use this shared token. |
| `session_jwt` | Unique JWT tokens are generated for each session/user. |
- When `AUTH_TYPE` is set to `simple_jwt`, a token is generated at startup (if not already present or configured) and printed to the console. This token should be included in the `Authorization` header of your API requests as a Bearer token (e.g., `Authorization: Bearer YOUR_SIMPLE_JWT_TOKEN`).
- When `AUTH_TYPE` is set to `session_jwt`:
- Clients can request a new token from the `/api/generate_token` endpoint.
- This token should then be included in the `Authorization` header for subsequent requests.
- The backend verifies the JWT token provided in the `Authorization` header for protected routes.
- The `/api/config` endpoint can be used to check the current `auth_type` and whether authentication is required.
#### How to Configure
**Frontend Token Input for `simple_jwt`:**
Add the following to your `.env` file (or set in `settings.py`):
```env
# No authentication (default)
AUTH_TYPE=None
# OR: Simple JWT (shared token)
AUTH_TYPE=simple_jwt
JWT_SECRET_KEY=your_secret_key_here
# OR: Session JWT (per-user/session tokens)
AUTH_TYPE=session_jwt
JWT_SECRET_KEY=your_secret_key_here
```
- If `AUTH_TYPE` is set to `simple_jwt` or `session_jwt`, a `JWT_SECRET_KEY` is required.
- If `JWT_SECRET_KEY` is not set, DocsGPT will generate one and store it in `.jwt_secret_key` in the project root.
#### How Each Method Works
- **None**: No authentication. All API and UI access is open.
- **simple_jwt**:
- A single JWT token is generated at startup and printed to the console.
- Use this token in the `Authorization` header for all API requests:
```http
Authorization: Bearer <SIMPLE_JWT_TOKEN>
```
- The frontend will prompt for this token if not already set.
- **session_jwt**:
- Clients can request a new token from `/api/generate_token`.
- Use the received token in the `Authorization` header for subsequent requests.
- Each user/session gets a unique token.
#### Security Notes
- Always keep your `JWT_SECRET_KEY` secure and private.
- If you set it manually, use a strong, random string.
- If not set, DocsGPT will generate a secure key and persist it in `.jwt_secret_key`.
#### Checking Current Auth Type
- Use the `/api/config` endpoint to check the current `auth_type` and whether authentication is required.
#### Frontend Token Input for `simple_jwt`
If you have configured `AUTH_TYPE=simple_jwt`, the DocsGPT frontend will prompt you to enter the JWT token if it's not already set or is invalid. Paste the `SIMPLE_JWT_TOKEN` (printed to your console when the backend starts) into this field to access the application.
<img
src="/jwt-input.png"
alt="Frontend prompt for JWT Token"
style={{
width: "500px",
maxWidth: "100%",
display: "block",
margin: "1em auto",
}}
<img
src="/jwt-input.png"
alt="Frontend prompt for JWT Token"
style={{
width: '500px',
maxWidth: '100%',
display: 'block',
margin: '1em auto'
}}
/>
If you have configured `AUTH_TYPE=simple_jwt`, the DocsGPT frontend will prompt you to enter the JWT token if it's not already set or is invalid. You'll need to paste the `SIMPLE_JWT_TOKEN` (which is printed to your console when the backend starts) into this field to access the application.
## Exploring More Settings
These are just the basic settings to get you started. The `settings.py` file contains many more advanced options that you can explore to further customize DocsGPT, such as:
@@ -178,4 +147,4 @@ These are just the basic settings to get you started. The `settings.py` file con
- Cache settings (`CACHE_REDIS_URL`)
- And many more!
For a complete list of available settings and their descriptions, refer to the `settings.py` file in `application/core`. Remember to restart your Docker containers after making changes to your `.env` file or `settings.py` for the changes to take effect.
For a complete list of available settings and their descriptions, refer to the `settings.py` file in `application/core`. Remember to restart your Docker containers after making changes to your `.env` file or `settings.py` for the changes to take effect.

View File

@@ -1,6 +0,0 @@
{
"google-drive-connector": {
"title": "🔗 Google Drive",
"href": "/Guides/Integrations/google-drive-connector"
}
}

View File

@@ -1,212 +0,0 @@
---
title: Google Drive Connector
description: Connect your Google Drive as an external knowledge base to upload and process files directly from your Google Drive account.
---
import { Callout } from 'nextra/components'
import { Steps } from 'nextra/components'
# Google Drive Connector
The Google Drive Connector allows you to seamlessly connect your Google Drive account as an external knowledge base. This integration enables you to upload and process files directly from your Google Drive without manually downloading and uploading them to DocsGPT.
## Features
- **Direct File Access**: Browse and select files directly from your Google Drive
- **Comprehensive File Support**: Supports all major document formats including:
- Google Workspace files (Docs, Sheets, Slides)
- Microsoft Office files (.docx, .xlsx, .pptx, .doc, .ppt, .xls)
- PDF documents
- Text files (.txt, .md, .rst, .html, .rtf)
- Data files (.csv, .json)
- Image files (.png, .jpg, .jpeg)
- E-books (.epub)
- **Secure Authentication**: Uses OAuth 2.0 for secure access to your Google Drive
- **Real-time Sync**: Process files directly from Google Drive without local downloads
<Callout type="info" emoji="">
The Google Drive Connector requires proper configuration of Google API credentials. Follow the setup instructions below to enable this feature.
</Callout>
## Prerequisites
Before setting up the Google Drive Connector, you'll need:
1. A Google Cloud Platform (GCP) project
2. Google Drive API enabled
3. OAuth 2.0 credentials configured
4. DocsGPT instance with proper environment variables
## Setup Instructions
<Steps>
### Step 1: Create a Google Cloud Project
1. Go to the [Google Cloud Console](https://console.cloud.google.com/)
2. Create a new project or select an existing one
3. Note down your Project ID for later use
### Step 2: Enable Google Drive API
1. In the Google Cloud Console, navigate to **APIs & Services** > **Library**
2. Search for "Google Drive API"
3. Click on "Google Drive API" and click **Enable**
### Step 3: Create OAuth 2.0 Credentials
1. Go to **APIs & Services** > **Credentials**
2. Click **Create Credentials** > **OAuth client ID**
3. If prompted, configure the OAuth consent screen:
- Choose **External** user type (unless you're using Google Workspace)
- Fill in the required fields (App name, User support email, Developer contact)
- Add your domain to **Authorized domains** if deploying publicly
4. For Application type, select **Web application**
5. Add your DocsGPT frontend URL to **Authorized JavaScript origins**:
- For local development: `http://localhost:3000`
- For production: `https://yourdomain.com`
6. Add your DocsGPT callback URL to **Authorized redirect URIs**:
- For local development: `http://localhost:7091/api/connectors/callback?provider=google_drive`
- For production: `https://yourdomain.com/api/connectors/callback?provider=google_drive`
7. Click **Create** and note down the **Client ID** and **Client Secret**
### Step 4: Configure Backend Environment Variables
Add the following environment variables to your backend configuration:
**For Docker deployment**, add to your `.env` file in the root directory:
```env
# Google Drive Connector Configuration
GOOGLE_CLIENT_ID=your_google_client_id_here
GOOGLE_CLIENT_SECRET=your_google_client_secret_here
```
**For manual deployment**, set these environment variables in your system or application configuration.
### Step 5: Configure Frontend Environment Variables
Add the following environment variables to your frontend `.env` file:
```env
# Google Drive Frontend Configuration
VITE_GOOGLE_CLIENT_ID=your_google_client_id_here
```
<Callout type="warning" emoji="⚠️">
Make sure to use the same Google Client ID in both backend and frontend configurations.
</Callout>
### Step 6: Restart Your Application
After configuring the environment variables:
1. **For Docker**: Restart your Docker containers
```bash
docker-compose down
docker-compose up -d
```
2. **For manual deployment**: Restart both backend and frontend services
</Steps>
## Using the Google Drive Connector
Once configured, you can use the Google Drive Connector to upload files:
<Steps>
### Step 1: Access the Upload Interface
1. Navigate to the DocsGPT interface
2. Go to the upload/training section
3. You should now see "Google Drive" as an available upload option
### Step 2: Connect Your Google Account
1. Select "Google Drive" as your upload method
2. Click "Connect to Google Drive"
3. You'll be redirected to Google's OAuth consent screen
4. Grant the necessary permissions to DocsGPT
5. You'll be redirected back to DocsGPT with a successful connection
### Step 3: Select Files
1. Once connected, click "Select Files"
2. The Google Drive picker will open
3. Browse your Google Drive and select the files you want to process
4. Click "Select" to confirm your choices
### Step 4: Process Files
1. Review your selected files
2. Click "Train" or "Upload" to process the files
3. DocsGPT will download and process the files from your Google Drive
4. Once processing is complete, the files will be available in your knowledge base
</Steps>
## Supported File Types
The Google Drive Connector supports the following file types:
| File Type | Extensions | Description |
|-----------|------------|-------------|
| **Google Workspace** | - | Google Docs, Sheets, Slides (automatically converted) |
| **Microsoft Office** | .docx, .xlsx, .pptx | Modern Office formats |
| **Legacy Office** | .doc, .ppt, .xls | Older Office formats |
| **PDF Documents** | .pdf | Portable Document Format |
| **Text Files** | .txt, .md, .rst, .html, .rtf | Various text formats |
| **Data Files** | .csv, .json | Structured data formats |
| **Images** | .png, .jpg, .jpeg | Image files (with OCR if enabled) |
| **E-books** | .epub | Electronic publication format |
## Troubleshooting
### Common Issues
**"Google Drive option not appearing"**
- Verify that `VITE_GOOGLE_CLIENT_ID` is set in frontend environment
- Check that `VITE_GOOGLE_CLIENT_ID` environment variable is present in your frontend configuration
- Check browser console for any JavaScript errors
- Ensure the frontend has been restarted after adding environment variables
**"Authentication failed"**
- Verify that your OAuth 2.0 credentials are correctly configured
- Check that the redirect URI `http://<your-domain>/api/connectors/callback?provider=google_drive` is correctly added in GCP console
- Ensure the Google Drive API is enabled in your GCP project
**"Permission denied" errors**
- Verify that the OAuth consent screen is properly configured
- Check that your Google account has access to the files you're trying to select
- Ensure the required scopes are granted during authentication
**"Files not processing"**
- Check that the backend environment variables are correctly set
- Verify that the OAuth credentials have the necessary permissions
- Check the backend logs for any error messages
### Environment Variable Checklist
**Backend (.env in root directory):**
- ✅ `GOOGLE_CLIENT_ID`
- ✅ `GOOGLE_CLIENT_SECRET`
**Frontend (.env in frontend directory):**
- ✅ `VITE_GOOGLE_CLIENT_ID`
### Security Considerations
- Keep your Google Client Secret secure and never expose it in frontend code
- Regularly rotate your OAuth credentials
- Use HTTPS in production to protect authentication tokens
- Ensure proper OAuth consent screen configuration for production use
<Callout type="tip" emoji="💡">
For production deployments, make sure to add your actual domain to the OAuth consent screen and authorized origins/redirect URIs.
</Callout>

View File

@@ -20,8 +20,5 @@
"Architecture": {
"title": "🏗️ Architecture",
"href": "/Guides/Architecture"
},
"Integrations": {
"title": "🔗 Integrations"
}
}

View File

@@ -60,7 +60,7 @@ const config = {
GitHub
</a>
{' | '}
<a href="https://blog.docsgpt.cloud/" target="_blank">
<a href="https://www.blog.docsgpt.cloud/" target="_blank">
Blog
</a>
</div>

View File

@@ -5388,9 +5388,9 @@
}
},
"node_modules/dompurify": {
"version": "3.2.4",
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.2.4.tgz",
"integrity": "sha512-ysFSFEDVduQpyhzAob/kkuJjf5zWkZD8/A9ywSp1byueyuCfHamrCBa14/Oc2iiB0e51B+NpxSl5gmzn+Ms/mg==",
"version": "3.2.6",
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.2.6.tgz",
"integrity": "sha512-/2GogDQlohXPZe6D6NOgQvXLPSYBqIWMnZ8zzOhn09REE4eyAzb+Hed3jhoM9OkuaJ8P6ZGTTVWQKAi8ieIzfQ==",
"license": "(MPL-2.0 OR Apache-2.0)",
"optionalDependencies": {
"@types/trusted-types": "^2.0.7"

View File

@@ -5,8 +5,6 @@
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0,viewport-fit=cover" />
<meta name="apple-mobile-web-app-capable" content="yes">
<meta name="theme-color" content="#fbfbfb" media="(prefers-color-scheme: light)" />
<meta name="theme-color" content="#161616" media="(prefers-color-scheme: dark)" />
<title>DocsGPT</title>
<link rel="shortcut icon" type="image/x-icon" href="/favicon.ico" />
</head>

File diff suppressed because it is too large Load Diff

View File

@@ -19,20 +19,19 @@
]
},
"dependencies": {
"@reduxjs/toolkit": "^2.8.2",
"@reduxjs/toolkit": "^2.5.1",
"chart.js": "^4.4.4",
"clsx": "^2.1.1",
"copy-to-clipboard": "^3.3.3",
"i18next": "^24.2.0",
"i18next-browser-languagedetector": "^8.0.2",
"lodash": "^4.17.21",
"mermaid": "^11.6.0",
"prop-types": "^15.8.1",
"react": "^19.1.0",
"react": "^18.2.0",
"react-chartjs-2": "^5.3.0",
"react-dom": "^19.0.0",
"react-dropzone": "^14.3.8",
"react-google-drive-picker": "^1.2.2",
"react-copy-to-clipboard": "^5.1.0",
"react-dom": "^18.3.1",
"react-dropzone": "^14.3.5",
"react-helmet": "^6.1.0",
"react-i18next": "^15.4.0",
"react-markdown": "^9.0.1",
"react-redux": "^9.2.0",
@@ -40,19 +39,18 @@
"react-syntax-highlighter": "^15.6.1",
"rehype-katex": "^7.0.1",
"remark-gfm": "^4.0.0",
"remark-math": "^6.0.0",
"tailwind-merge": "^3.3.1"
"remark-math": "^6.0.0"
},
"devDependencies": {
"@tailwindcss/postcss": "^4.1.10",
"@types/lodash": "^4.17.20",
"@types/mermaid": "^9.1.0",
"@types/react": "^19.1.8",
"@types/react-dom": "^19.0.0",
"@types/react": "^18.0.27",
"@types/react-dom": "^18.3.0",
"@types/react-helmet": "^6.1.11",
"@types/react-syntax-highlighter": "^15.5.13",
"@typescript-eslint/eslint-plugin": "^5.51.0",
"@typescript-eslint/parser": "^5.62.0",
"@vitejs/plugin-react": "^4.3.4",
"autoprefixer": "^10.4.13",
"eslint": "^8.57.1",
"eslint-config-prettier": "^10.1.5",
"eslint-config-standard-with-typescript": "^34.0.0",
@@ -66,8 +64,8 @@
"lint-staged": "^15.3.0",
"postcss": "^8.4.49",
"prettier": "^3.5.3",
"prettier-plugin-tailwindcss": "^0.6.13",
"tailwindcss": "^4.1.11",
"prettier-plugin-tailwindcss": "^0.6.11",
"tailwindcss": "^3.4.17",
"typescript": "^5.8.3",
"vite": "^6.3.5",
"vite-plugin-svgr": "^4.3.0"

View File

@@ -1,5 +1,6 @@
module.exports = {
plugins: {
'@tailwindcss/postcss': {},
tailwindcss: {},
autoprefixer: {},
},
}

View File

@@ -1 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 122.88 122.88"><defs><style>.a{fill:#d53;}.b{fill:#fff;}.c{fill:#ddd;}.d{fill:#fc0;}.e{fill:#6b5;}.f{fill:#4a4;}.g{fill:#148;}</style></defs><title>duckduckgo</title><path class="a" d="M122.88,61.44a61.44,61.44,0,1,0-61.44,61.44,61.44,61.44,0,0,0,61.44-61.44Z"/><path class="b" d="M114.37,61.44a52.92,52.92,0,1,0-15.5,37.43,52.76,52.76,0,0,0,15.5-37.43Zm-13.12-39.8A56.29,56.29,0,1,1,61.44,5.15a56.12,56.12,0,0,1,39.81,16.49Z"/><path class="c" d="M43.24,30.15C26.17,34.13,32.43,58,32.43,58l10.81,52.9,4,1.71-4-82.49Zm-4-10.24H34.7L41,22.19s-6.26,0-6.26,4C48.36,25.6,54.61,29,54.61,29l-15.36-9.1Zm0,0Z"/><path class="b" d="M75.66,115.48S62,93.87,62,79.64c0-26.73,17.63-4,17.63-25S62,28.44,62,28.44c-8.53-10.8-25-8.53-25-8.53l4,2.28s-4,1.13-5.12,2.27,10.81-1.7,15.93,2.85C30.72,29,34.13,46.08,34.13,46.08l11.95,68.27,29.58,1.13Zm0,0Z"/><path class="d" d="M75.66,60.87l21.62-5.69C116.62,58,80.78,68.84,78.51,68.27c-17.07-2.85-12,11.37,8.53,6.82s5.12,11.38-13.65,5.12c-26.74-7.39-12.52-20.48,2.27-19.34Z"/><path class="e" d="M70,105.81l1.14-1.7c12.52,4.55,13.09,6.25,12.52-5.12s0-11.38-13.09-1.71c0-2.84-7.39-1.71-8.53,0-11.95-5.12-13.09-6.83-12.52,1.14,1.14,16.5.57,13.65,11.95,8l8.53-.57Zm0,0Z"/><path class="f" d="M60.87,99.56v6.82c.57,1.14,9.67,1.14,9.67-1.14s-4.55,1.71-7.39.57S62,98.42,62,98.42l-1.14,1.14Zm0,0Z"/><path class="g" d="M48.36,43.24c-2.85-3.42-10.24-.57-8.54,4,.57-2.28,4.55-5.69,8.54-4Zm18.2,0c.57-3.42,6.26-4,8-.57a8,8,0,0,0-8,.57Zm-18.77,9.1a1.14,1.14,0,1,1,0,.57v-.57Zm-4.55,2.27a4,4,0,1,0,0-.57v.57Zm29.58-4a1.14,1.14,0,1,1,0,.57v-.57ZM69.4,52.91a3.42,3.42,0,1,0,0-.57v.57Zm0,0Z"/></svg>

Before

Width:  |  Height:  |  Size: 1.6 KiB

View File

@@ -1,4 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="64" height="64" color="#000000" fill="none">
<path d="M3.49994 11.7501L11.6717 3.57855C12.7762 2.47398 14.5672 2.47398 15.6717 3.57855C16.7762 4.68312 16.7762 6.47398 15.6717 7.57855M15.6717 7.57855L9.49994 13.7501M15.6717 7.57855C16.7762 6.47398 18.5672 6.47398 19.6717 7.57855C20.7762 8.68312 20.7762 10.474 19.6717 11.5785L12.7072 18.543C12.3167 18.9335 12.3167 19.5667 12.7072 19.9572L13.9999 21.2499" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
<path d="M17.4999 9.74921L11.3282 15.921C10.2237 17.0255 8.43272 17.0255 7.32823 15.921C6.22373 14.8164 6.22373 13.0255 7.32823 11.921L13.4999 5.74939" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"></path>
</svg>

Before

Width:  |  Height:  |  Size: 831 B

Binary file not shown.

View File

@@ -19,9 +19,9 @@ export default function Hero({
}>;
return (
<div className="text-black-1000 dark:text-bright-gray flex h-full w-full flex-col items-center justify-between">
<div className="flex h-full w-full flex-col items-center justify-between text-black-1000 dark:text-bright-gray">
{/* Header Section */}
<div className="flex grow flex-col items-center justify-center pt-8 md:pt-0">
<div className="flex flex-grow flex-col items-center justify-center pt-8 md:pt-0">
<div className="mb-4 flex items-center">
<span className="text-4xl font-semibold">DocsGPT</span>
<img className="mb-1 inline w-14" src={DocsGPT3} alt="docsgpt" />
@@ -29,7 +29,7 @@ export default function Hero({
</div>
{/* Demo Buttons Section */}
<div className="mb-3 w-full max-w-full md:mb-3">
<div className="mb-8 w-full max-w-full md:mb-16">
<div className="grid grid-cols-1 gap-3 text-xs md:grid-cols-1 md:gap-4 lg:grid-cols-2">
{demos?.map(
(demo: { header: string; query: string }, key: number) =>
@@ -38,9 +38,9 @@ export default function Hero({
<button
key={key}
onClick={() => handleQuestion({ question: demo.query })}
className={`border-dark-gray text-just-black hover:bg-cultured dark:border-dim-gray dark:text-chinese-white dark:hover:bg-charleston-green w-full rounded-[66px] border bg-transparent px-6 py-[14px] text-left transition-colors ${key >= 2 ? 'hidden md:block' : ''} // Show only 2 buttons on mobile`}
className={`w-full rounded-[66px] border border-dark-gray bg-transparent px-6 py-[14px] text-left text-just-black transition-colors hover:bg-cultured dark:border-dim-gray dark:text-chinese-white dark:hover:bg-charleston-green ${key >= 2 ? 'hidden md:block' : ''} // Show only 2 buttons on mobile`}
>
<p className="text-black-1000 dark:text-bright-gray mb-2 font-semibold">
<p className="mb-2 font-semibold text-black-1000 dark:text-bright-gray">
{demo.header}
</p>
<span className="line-clamp-2 text-gray-700 opacity-60 dark:text-gray-300">

View File

@@ -10,7 +10,7 @@ import Add from './assets/add.svg';
import DocsGPT3 from './assets/cute_docsgpt3.svg';
import Discord from './assets/discord.svg';
import Expand from './assets/expand.svg';
import Github from './assets/git_nav.svg';
import Github from './assets/github.svg';
import Hamburger from './assets/hamburger.svg';
import openNewChat from './assets/openNewChat.svg';
import Pin from './assets/pin.svg';
@@ -81,27 +81,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
useState<ActiveState>('INACTIVE');
const [recentAgents, setRecentAgents] = useState<Agent[]>([]);
const navRef = useRef<HTMLDivElement>(null);
useEffect(() => {
function handleClickOutside(event: MouseEvent) {
if (
navRef.current &&
!navRef.current.contains(event.target as Node) &&
(isMobile || isTablet) &&
navOpen
) {
setNavOpen(false);
}
}
const navRef = useRef(null);
//event listener only for mobile/tablet when nav is open
if ((isMobile || isTablet) && navOpen) {
document.addEventListener('mousedown', handleClickOutside);
return () => {
document.removeEventListener('mousedown', handleClickOutside);
};
}
}, [navOpen, isMobile, isTablet, setNavOpen]);
async function fetchRecentAgents() {
try {
const response = await userService.getPinnedAgents(token);
@@ -293,7 +274,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
return (
<>
{!navOpen && (
<div className="absolute top-3 left-3 z-20 hidden transition-all duration-25 lg:block">
<div className="duration-25 absolute left-3 top-3 z-20 hidden transition-all lg:block">
<div className="flex items-center gap-3">
<button
onClick={() => {
@@ -321,7 +302,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
/>
</button>
)}
<div className="text-gray-4000 text-[20px] font-medium">
<div className="text-[20px] font-medium text-[#949494]">
DocsGPT
</div>
</div>
@@ -330,8 +311,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
<div
ref={navRef}
className={`${
!navOpen && '-ml-96 md:-ml-72'
} bg-lotion dark:border-r-purple-taupe dark:bg-chinese-black fixed top-0 z-20 flex h-full w-72 flex-col border-r border-b-0 transition-all duration-20 dark:text-white`}
!navOpen && '-ml-96 md:-ml-[18rem]'
} duration-20 fixed top-0 z-20 flex h-full w-72 flex-col border-b-0 border-r-[1px] bg-lotion transition-all dark:border-r-purple-taupe dark:bg-chinese-black dark:text-white`}
>
<div
className={'visible mt-2 flex h-[6vh] w-full justify-between md:h-12'}
@@ -375,7 +356,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
className={({ isActive }) =>
`${
isActive ? 'bg-transparent' : ''
} group border-silver hover:border-rainy-gray dark:border-purple-taupe sticky mx-4 mt-4 flex cursor-pointer gap-2.5 rounded-3xl border p-3 hover:bg-transparent dark:text-white`
} group sticky mx-4 mt-4 flex cursor-pointer gap-2.5 rounded-3xl border border-silver p-3 hover:border-rainy-gray hover:bg-transparent dark:border-purple-taupe dark:text-white`
}
>
<img
@@ -383,16 +364,16 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
alt="Create new chat"
className="opacity-80 group-hover:opacity-100"
/>
<p className="text-dove-gray dark:text-chinese-silver dark:group-hover:text-bright-gray text-sm group-hover:text-neutral-600">
<p className="text-sm text-dove-gray group-hover:text-neutral-600 dark:text-chinese-silver dark:group-hover:text-bright-gray">
{t('newChat')}
</p>
</NavLink>
<div
id="conversationsMainDiv"
className="mb-auto h-[78vh] overflow-x-hidden overflow-y-auto dark:text-white"
className="mb-auto h-[78vh] overflow-y-auto overflow-x-hidden dark:text-white"
>
{conversations?.loading && !isDeletingConversation && (
<div className="absolute top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 transform">
<div className="absolute left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2 transform">
<img
src={isDarkTheme ? SpinnerDark : Spinner}
className="animate-spin cursor-pointer bg-transparent"
@@ -403,14 +384,14 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
{recentAgents?.length > 0 ? (
<div>
<div className="mx-4 my-auto mt-2 flex h-6 items-center">
<p className="mt-1 ml-4 text-sm font-semibold">Agents</p>
<p className="ml-4 mt-1 text-sm font-semibold">Agents</p>
</div>
<div className="agents-container">
<div>
{recentAgents.map((agent, idx) => (
<div
key={idx}
className={`group hover:bg-bright-gray dark:hover:bg-dark-charcoal mx-4 my-auto mt-4 flex h-9 cursor-pointer items-center justify-between rounded-3xl pl-4 ${
className={`group mx-4 my-auto mt-4 flex h-9 cursor-pointer items-center justify-between rounded-3xl pl-4 hover:bg-bright-gray dark:hover:bg-dark-charcoal ${
agent.id === selectedAgent?.id && !conversationId
? 'bg-bright-gray dark:bg-dark-charcoal'
: ''
@@ -420,16 +401,12 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
<div className="flex items-center gap-2">
<div className="flex w-6 justify-center">
<img
src={
agent.image && agent.image.trim() !== ''
? agent.image
: Robot
}
src={agent.image ?? Robot}
alt="agent-logo"
className="h-6 w-6 rounded-full object-contain"
className="h-6 w-6"
/>
</div>
<p className="text-eerie-black dark:text-bright-gray overflow-hidden text-sm leading-6 text-ellipsis whitespace-nowrap">
<p className="overflow-hidden overflow-ellipsis whitespace-nowrap text-sm leading-6 text-eerie-black dark:text-bright-gray">
{agent.name}
</p>
</div>
@@ -453,7 +430,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
))}
</div>
<div
className="hover:bg-bright-gray dark:hover:bg-dark-charcoal mx-4 my-auto mt-2 flex h-9 cursor-pointer items-center gap-2 rounded-3xl pl-4"
className="mx-4 my-auto mt-2 flex h-9 cursor-pointer items-center gap-2 rounded-3xl pl-4 hover:bg-bright-gray dark:hover:bg-dark-charcoal"
onClick={() => {
dispatch(setSelectedAgent(null));
if (isMobile || isTablet) {
@@ -469,7 +446,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
className="h-[18px] w-[18px]"
/>
</div>
<p className="text-eerie-black dark:text-bright-gray overflow-hidden text-sm leading-6 text-ellipsis whitespace-nowrap">
<p className="overflow-hidden overflow-ellipsis whitespace-nowrap text-sm leading-6 text-eerie-black dark:text-bright-gray">
{t('manageAgents')}
</p>
</div>
@@ -477,7 +454,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
</div>
) : (
<div
className="hover:bg-bright-gray dark:hover:bg-dark-charcoal mx-4 my-auto mt-2 flex h-9 cursor-pointer items-center gap-2 rounded-3xl pl-4"
className="mx-4 my-auto mt-2 flex h-9 cursor-pointer items-center gap-2 rounded-3xl pl-4 hover:bg-bright-gray dark:hover:bg-dark-charcoal"
onClick={() => {
if (isMobile || isTablet) {
setNavOpen(false);
@@ -493,7 +470,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
className="h-[18px] w-[18px]"
/>
</div>
<p className="text-eerie-black dark:text-bright-gray overflow-hidden text-sm leading-6 text-ellipsis whitespace-nowrap">
<p className="overflow-hidden overflow-ellipsis whitespace-nowrap text-sm leading-6 text-eerie-black dark:text-bright-gray">
{t('manageAgents')}
</p>
</div>
@@ -501,7 +478,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
{conversations?.data && conversations.data.length > 0 ? (
<div className="mt-7">
<div className="mx-4 my-auto mt-2 flex h-6 items-center justify-between gap-4 rounded-3xl">
<p className="mt-1 ml-4 text-sm font-semibold">{t('chats')}</p>
<p className="ml-4 mt-1 text-sm font-semibold">{t('chats')}</p>
</div>
<div className="conversations-container">
{conversations.data?.map((conversation) => (
@@ -526,8 +503,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
<></>
)}
</div>
<div className="text-eerie-black flex h-auto flex-col justify-end dark:text-white">
<div className="dark:border-b-purple-taupe flex flex-col gap-2 border-b py-2">
<div className="flex h-auto flex-col justify-end text-eerie-black dark:text-white">
<div className="flex flex-col gap-2 border-b-[1px] py-2 dark:border-b-purple-taupe">
<NavLink
onClick={() => {
if (isMobile || isTablet) {
@@ -537,7 +514,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
}}
to="/settings"
className={({ isActive }) =>
`mx-4 my-auto flex h-9 cursor-pointer items-center gap-4 rounded-3xl hover:bg-gray-100 dark:hover:bg-[#28292E] ${
`mx-4 my-auto flex h-9 cursor-pointer gap-4 rounded-3xl hover:bg-gray-100 dark:hover:bg-[#28292E] ${
isActive ? 'bg-gray-3000 dark:bg-transparent' : ''
}`
}
@@ -545,16 +522,14 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
<img
src={SettingGear}
alt="Settings"
width={21}
height={21}
className="my-auto ml-2 filter dark:invert"
className="w- ml-2 filter dark:invert"
/>
<p className="text-eerie-black text-sm dark:text-white">
<p className="my-auto text-sm text-eerie-black dark:text-white">
{t('settings.label')}
</p>
</NavLink>
</div>
<div className="text-eerie-black flex flex-col justify-end dark:text-white">
<div className="flex flex-col justify-end text-eerie-black dark:text-white">
<div className="flex items-center justify-between py-1">
<Help />
@@ -568,8 +543,6 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
>
<img
src={Discord}
width={24}
height={24}
alt="Join Discord community"
className="m-2 w-6 self-center filter dark:invert"
/>
@@ -583,10 +556,8 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
>
<img
src={Twitter}
width={20}
height={20}
alt="Follow us on Twitter"
className="m-2 self-center filter dark:invert"
className="m-2 w-5 self-center filter dark:invert"
/>
</NavLink>
<NavLink
@@ -599,9 +570,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
<img
src={Github}
alt="View on GitHub"
width={28}
height={28}
className="m-2 self-center filter dark:invert"
className="m-2 w-6 self-center filter dark:invert"
/>
</NavLink>
</div>
@@ -609,7 +578,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
</div>
</div>
</div>
<div className="dark:border-b-purple-taupe dark:bg-chinese-black sticky z-10 h-16 w-full border-b-2 bg-gray-50 lg:hidden">
<div className="sticky z-10 h-16 w-full border-b-2 bg-gray-50 dark:border-b-purple-taupe dark:bg-chinese-black lg:hidden">
<div className="ml-6 flex h-full items-center gap-6">
<button
className="h-6 w-6 lg:hidden"
@@ -621,7 +590,7 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {
className="w-7 filter dark:invert"
/>
</button>
<div className="text-gray-4000 text-[20px] font-medium">DocsGPT</div>
<div className="text-[20px] font-medium text-[#949494]">DocsGPT</div>
</div>
</div>
<DeleteConvModal

View File

@@ -54,7 +54,7 @@ export default function AgentCard({
return (
<div
className={`relative flex h-44 w-48 flex-col justify-between rounded-[1.2rem] bg-[#F6F6F6] px-6 py-5 hover:bg-[#ECECEC] dark:bg-[#383838] dark:hover:bg-[#383838]/80 ${
className={`relative flex h-44 w-48 flex-col justify-between rounded-[1.2rem] bg-[#F6F6F6] px-6 py-5 hover:bg-[#ECECEC] dark:bg-[#383838] hover:dark:bg-[#383838]/80 ${
agent.status === 'published' ? 'cursor-pointer' : ''
}`}
onClick={handleCardClick}
@@ -65,7 +65,7 @@ export default function AgentCard({
e.stopPropagation();
setIsMenuOpen(true);
}}
className="absolute top-4 right-4 z-10 cursor-pointer"
className="absolute right-4 top-4 z-10 cursor-pointer"
>
<img src={ThreeDots} alt="options" className="h-[19px] w-[19px]" />
{menuOptions && (
@@ -83,9 +83,9 @@ export default function AgentCard({
<div className="w-full">
<div className="flex w-full items-center gap-1 px-1">
<img
src={agent.image && agent.image.trim() !== '' ? agent.image : Robot}
src={agent.image ?? Robot}
alt={`${agent.name}`}
className="h-7 w-7 rounded-full object-contain"
className="h-7 w-7 rounded-full"
/>
{agent.status === 'draft' && (
<p className="text-xs text-black opacity-50 dark:text-[#E0E0E0]">
@@ -96,11 +96,11 @@ export default function AgentCard({
<div className="mt-2">
<p
title={agent.name}
className="truncate px-1 text-[13px] leading-relaxed font-semibold text-[#020617] capitalize dark:text-[#E0E0E0]"
className="truncate px-1 text-[13px] font-semibold capitalize leading-relaxed text-[#020617] dark:text-[#E0E0E0]"
>
{agent.name}
</p>
<p className="dark:text-sonic-silver-light mt-1 h-20 overflow-auto px-1 text-[12px] leading-relaxed text-[#64748B]">
<p className="mt-1 h-20 overflow-auto px-1 text-[12px] leading-relaxed text-[#64748B] dark:text-sonic-silver-light">
{agent.description}
</p>
</div>

View File

@@ -44,12 +44,12 @@ export default function AgentLogs() {
>
<img src={ArrowLeft} alt="left-arrow" className="h-3 w-3" />
</button>
<p className="text-eerie-black dark:text-bright-gray mt-px text-sm font-semibold">
<p className="mt-px text-sm font-semibold text-eerie-black dark:text-bright-gray">
Back to all agents
</p>
</div>
<div className="mt-5 flex w-full flex-wrap items-center justify-between gap-2 px-4">
<h1 className="text-eerie-black m-0 text-[40px] font-bold dark:text-white">
<h1 className="m-0 text-[40px] font-bold text-[#212121] dark:text-white">
Agent Logs
</h1>
</div>

View File

@@ -6,23 +6,24 @@ import ConversationMessages from '../conversation/ConversationMessages';
import { Query } from '../conversation/conversationModels';
import {
addQuery,
fetchPreviewAnswer,
handlePreviewAbort,
fetchAnswer,
handleAbort,
resendQuery,
resetPreview,
selectPreviewQueries,
selectPreviewStatus,
} from './agentPreviewSlice';
resetConversation,
selectQueries,
selectStatus,
} from '../conversation/conversationSlice';
import { selectSelectedAgent } from '../preferences/preferenceSlice';
import { AppDispatch } from '../store';
export default function AgentPreview() {
const dispatch = useDispatch<AppDispatch>();
const queries = useSelector(selectPreviewQueries);
const status = useSelector(selectPreviewStatus);
const queries = useSelector(selectQueries);
const status = useSelector(selectStatus);
const selectedAgent = useSelector(selectSelectedAgent);
const [input, setInput] = useState('');
const [lastQueryReturnedErr, setLastQueryReturnedErr] = useState(false);
const fetchStream = useRef<any>(null);
@@ -30,7 +31,7 @@ export default function AgentPreview() {
const handleFetchAnswer = useCallback(
({ question, index }: { question: string; index?: number }) => {
fetchStream.current = dispatch(
fetchPreviewAnswer({ question, indx: index }),
fetchAnswer({ question, indx: index, isPreview: true }),
);
},
[dispatch],
@@ -94,11 +95,11 @@ export default function AgentPreview() {
};
useEffect(() => {
dispatch(resetPreview());
dispatch(resetConversation());
return () => {
if (fetchStream.current) fetchStream.current.abort();
handlePreviewAbort();
dispatch(resetPreview());
handleAbort();
dispatch(resetConversation());
};
}, [dispatch]);
@@ -110,7 +111,7 @@ export default function AgentPreview() {
}, [queries]);
return (
<div>
<div className="dark:bg-raisin-black flex h-full flex-col items-center justify-between gap-2 overflow-y-hidden">
<div className="flex h-full flex-col items-center justify-between gap-2 overflow-y-hidden dark:bg-raisin-black">
<div className="h-[512px] w-full overflow-y-auto">
<ConversationMessages
handleQuestion={handleQuestion}
@@ -128,7 +129,7 @@ export default function AgentPreview() {
showToolButton={selectedAgent ? false : true}
autoFocus={false}
/>
<p className="text-gray-4000 dark:text-sonic-silver w-full self-center bg-transparent pt-2 text-center text-xs md:inline">
<p className="w-full self-center bg-transparent pt-2 text-center text-xs text-gray-4000 dark:text-sonic-silver md:inline">
This is a preview of the agent. You can publish it to start using it
in conversations.
</p>

View File

@@ -1,5 +1,4 @@
import isEqual from 'lodash/isEqual';
import React, { useCallback, useEffect, useRef, useState } from 'react';
import React, { useEffect, useRef, useState } from 'react';
import { useDispatch, useSelector } from 'react-redux';
import { useNavigate, useParams } from 'react-router-dom';
@@ -7,9 +6,7 @@ import userService from '../api/services/userService';
import ArrowLeft from '../assets/arrow-left.svg';
import SourceIcon from '../assets/source.svg';
import Dropdown from '../components/Dropdown';
import { FileUpload } from '../components/FileUpload';
import MultiSelectPopup, { OptionType } from '../components/MultiSelectPopup';
import Spinner from '../components/Spinner';
import AgentDetailsModal from '../modals/AgentDetailsModal';
import ConfirmationModal from '../modals/ConfirmationModal';
import { ActiveState, Doc, Prompt } from '../models/misc';
@@ -20,7 +17,6 @@ import {
setSelectedAgent,
} from '../preferences/preferenceSlice';
import PromptsModal from '../preferences/PromptsModal';
import Prompts from '../settings/Prompts';
import { UserToolType } from '../settings/types';
import AgentPreview from './AgentPreview';
import { Agent } from './types';
@@ -45,16 +41,13 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
description: '',
image: '',
source: '',
sources: [],
chunks: '',
retriever: '',
prompt_id: 'default',
prompt_id: '',
tools: [],
agent_type: '',
status: '',
json_schema: undefined,
});
const [imageFile, setImageFile] = useState<File | null>(null);
const [prompts, setPrompts] = useState<
{ name: string; id: string; type: string }[]
>([]);
@@ -71,44 +64,34 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
useState<ActiveState>('INACTIVE');
const [agentDetails, setAgentDetails] = useState<ActiveState>('INACTIVE');
const [addPromptModal, setAddPromptModal] = useState<ActiveState>('INACTIVE');
const [hasChanges, setHasChanges] = useState(false);
const [draftLoading, setDraftLoading] = useState(false);
const [publishLoading, setPublishLoading] = useState(false);
const [jsonSchemaText, setJsonSchemaText] = useState('');
const [jsonSchemaValid, setJsonSchemaValid] = useState(true);
const [isJsonSchemaExpanded, setIsJsonSchemaExpanded] = useState(false);
const initialAgentRef = useRef<Agent | null>(null);
const sourceAnchorButtonRef = useRef<HTMLButtonElement>(null);
const toolAnchorButtonRef = useRef<HTMLButtonElement>(null);
const modeConfig = {
new: {
heading: 'New Agent',
buttonText: 'Publish',
buttonText: 'Create Agent',
showDelete: false,
showSaveDraft: true,
showLogs: false,
showAccessDetails: false,
trackChanges: false,
},
edit: {
heading: 'Edit Agent',
buttonText: 'Save',
buttonText: 'Save Changes',
showDelete: true,
showSaveDraft: false,
showLogs: true,
showAccessDetails: true,
trackChanges: true,
},
draft: {
heading: 'New Agent (Draft)',
buttonText: 'Publish',
buttonText: 'Publish Draft',
showDelete: true,
showSaveDraft: true,
showLogs: false,
showAccessDetails: false,
trackChanges: false,
},
};
const chunks = ['0', '2', '4', '6', '8', '10'];
@@ -118,24 +101,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
];
const isPublishable = () => {
const hasRequiredFields =
agent.name && agent.description && agent.prompt_id && agent.agent_type;
const isJsonSchemaValidOrEmpty =
jsonSchemaText.trim() === '' || jsonSchemaValid;
return hasRequiredFields && isJsonSchemaValidOrEmpty;
return (
agent.name && agent.description && agent.prompt_id && agent.agent_type
);
};
const isJsonSchemaInvalid = () => {
return jsonSchemaText.trim() !== '' && !jsonSchemaValid;
};
const handleUpload = useCallback((files: File[]) => {
if (files && files.length > 0) {
const file = files[0];
setImageFile(file);
}
}, []);
const handleCancel = () => {
if (selectedAgent) dispatch(setSelectedAgent(null));
navigate('/agents');
@@ -148,184 +118,42 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
};
const handleSaveDraft = async () => {
const formData = new FormData();
formData.append('name', agent.name);
formData.append('description', agent.description);
if (selectedSourceIds.size > 1) {
const sourcesArray = Array.from(selectedSourceIds)
.map((id) => {
const sourceDoc = sourceDocs?.find(
(source) =>
source.id === id || source.retriever === id || source.name === id,
const response =
effectiveMode === 'new'
? await userService.createAgent({ ...agent, status: 'draft' }, token)
: await userService.updateAgent(
agent.id || '',
{ ...agent, status: 'draft' },
token,
);
if (sourceDoc?.name === 'Default' && !sourceDoc?.id) {
return 'default';
}
return sourceDoc?.id || id;
})
.filter(Boolean);
formData.append('sources', JSON.stringify(sourcesArray));
formData.append('source', '');
} else if (selectedSourceIds.size === 1) {
const singleSourceId = Array.from(selectedSourceIds)[0];
const sourceDoc = sourceDocs?.find(
(source) =>
source.id === singleSourceId ||
source.retriever === singleSourceId ||
source.name === singleSourceId,
);
let finalSourceId;
if (sourceDoc?.name === 'Default' && !sourceDoc?.id)
finalSourceId = 'default';
else finalSourceId = sourceDoc?.id || singleSourceId;
formData.append('source', String(finalSourceId));
formData.append('sources', JSON.stringify([]));
} else {
formData.append('source', '');
formData.append('sources', JSON.stringify([]));
}
formData.append('chunks', agent.chunks);
formData.append('retriever', agent.retriever);
formData.append('prompt_id', agent.prompt_id);
formData.append('agent_type', agent.agent_type);
formData.append('status', 'draft');
if (imageFile) formData.append('image', imageFile);
if (agent.tools && agent.tools.length > 0)
formData.append('tools', JSON.stringify(agent.tools));
else formData.append('tools', '[]');
if (agent.json_schema) {
formData.append('json_schema', JSON.stringify(agent.json_schema));
}
try {
setDraftLoading(true);
const response =
effectiveMode === 'new'
? await userService.createAgent(formData, token)
: await userService.updateAgent(agent.id || '', formData, token);
if (!response.ok) throw new Error('Failed to create agent draft');
const data = await response.json();
const updatedAgent = {
...agent,
id: data.id || agent.id,
image: data.image || agent.image,
};
setAgent(updatedAgent);
if (effectiveMode === 'new') setEffectiveMode('draft');
} catch (error) {
console.error('Error saving draft:', error);
throw new Error('Failed to save draft');
} finally {
setDraftLoading(false);
if (!response.ok) throw new Error('Failed to create agent draft');
const data = await response.json();
if (effectiveMode === 'new') {
setEffectiveMode('draft');
setAgent((prev) => ({ ...prev, id: data.id }));
}
};
const handlePublish = async () => {
const formData = new FormData();
formData.append('name', agent.name);
formData.append('description', agent.description);
if (selectedSourceIds.size > 1) {
const sourcesArray = Array.from(selectedSourceIds)
.map((id) => {
const sourceDoc = sourceDocs?.find(
(source) =>
source.id === id || source.retriever === id || source.name === id,
const response =
effectiveMode === 'new'
? await userService.createAgent(
{ ...agent, status: 'published' },
token,
)
: await userService.updateAgent(
agent.id || '',
{ ...agent, status: 'published' },
token,
);
if (sourceDoc?.name === 'Default' && !sourceDoc?.id) {
return 'default';
}
return sourceDoc?.id || id;
})
.filter(Boolean);
formData.append('sources', JSON.stringify(sourcesArray));
formData.append('source', '');
} else if (selectedSourceIds.size === 1) {
const singleSourceId = Array.from(selectedSourceIds)[0];
const sourceDoc = sourceDocs?.find(
(source) =>
source.id === singleSourceId ||
source.retriever === singleSourceId ||
source.name === singleSourceId,
);
let finalSourceId;
if (sourceDoc?.name === 'Default' && !sourceDoc?.id)
finalSourceId = 'default';
else finalSourceId = sourceDoc?.id || singleSourceId;
formData.append('source', String(finalSourceId));
formData.append('sources', JSON.stringify([]));
} else {
formData.append('source', '');
formData.append('sources', JSON.stringify([]));
}
formData.append('chunks', agent.chunks);
formData.append('retriever', agent.retriever);
formData.append('prompt_id', agent.prompt_id);
formData.append('agent_type', agent.agent_type);
formData.append('status', 'published');
if (imageFile) formData.append('image', imageFile);
if (agent.tools && agent.tools.length > 0)
formData.append('tools', JSON.stringify(agent.tools));
else formData.append('tools', '[]');
if (agent.json_schema) {
formData.append('json_schema', JSON.stringify(agent.json_schema));
}
try {
setPublishLoading(true);
const response =
effectiveMode === 'new'
? await userService.createAgent(formData, token)
: await userService.updateAgent(agent.id || '', formData, token);
if (!response.ok) throw new Error('Failed to publish agent');
const data = await response.json();
const updatedAgent = {
...agent,
id: data.id || agent.id,
key: data.key || agent.key,
status: 'published',
image: data.image || agent.image,
};
setAgent(updatedAgent);
initialAgentRef.current = updatedAgent;
if (effectiveMode === 'new' || effectiveMode === 'draft') {
setEffectiveMode('edit');
setAgentDetails('ACTIVE');
}
setImageFile(null);
} catch (error) {
console.error('Error publishing agent:', error);
throw new Error('Failed to publish agent');
} finally {
setPublishLoading(false);
}
};
const validateAndSetJsonSchema = (text: string) => {
setJsonSchemaText(text);
if (text.trim() === '') {
setAgent({ ...agent, json_schema: undefined });
setJsonSchemaValid(true);
return;
}
try {
const parsed = JSON.parse(text);
setAgent({ ...agent, json_schema: parsed });
setJsonSchemaValid(true);
} catch (error) {
setJsonSchemaValid(false);
if (!response.ok) throw new Error('Failed to publish agent');
const data = await response.json();
if (data.id) setAgent((prev) => ({ ...prev, id: data.id }));
if (data.key) setAgent((prev) => ({ ...prev, key: data.key }));
if (effectiveMode === 'new' || effectiveMode === 'draft') {
setEffectiveMode('edit');
setAgent((prev) => ({ ...prev, status: 'published' }));
setAgentDetails('ACTIVE');
}
};
@@ -362,99 +190,37 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
throw new Error('Failed to fetch agent');
}
const data = await response.json();
if (data.sources && data.sources.length > 0) {
const mappedSources = data.sources.map((sourceId: string) => {
if (sourceId === 'default') {
const defaultSource = sourceDocs?.find(
(source) => source.name === 'Default',
);
return defaultSource?.retriever || 'classic';
}
return sourceId;
});
setSelectedSourceIds(new Set(mappedSources));
} else if (data.source) {
if (data.source === 'default') {
const defaultSource = sourceDocs?.find(
(source) => source.name === 'Default',
);
setSelectedSourceIds(
new Set([defaultSource?.retriever || 'classic']),
);
} else {
setSelectedSourceIds(new Set([data.source]));
}
} else if (data.retriever) {
if (data.source) setSelectedSourceIds(new Set([data.source]));
else if (data.retriever)
setSelectedSourceIds(new Set([data.retriever]));
}
if (data.tools) setSelectedToolIds(new Set(data.tools));
if (data.status === 'draft') setEffectiveMode('draft');
if (data.json_schema) {
const jsonText = JSON.stringify(data.json_schema, null, 2);
setJsonSchemaText(jsonText);
setJsonSchemaValid(true);
}
setAgent(data);
initialAgentRef.current = data;
};
getAgent();
}
}, [agentId, mode, token]);
useEffect(() => {
const selectedSources = Array.from(selectedSourceIds)
.map((id) =>
sourceDocs?.find(
(source) =>
source.id === id || source.retriever === id || source.name === id,
),
)
.filter(Boolean);
if (selectedSources.length > 0) {
// Handle multiple sources
if (selectedSources.length > 1) {
// Multiple sources selected - store in sources array
const sourceIds = selectedSources
.map((source) => source?.id)
.filter((id): id is string => Boolean(id));
const selectedSource = Array.from(selectedSourceIds).map((id) =>
sourceDocs?.find(
(source) =>
source.id === id || source.retriever === id || source.name === id,
),
);
if (selectedSource[0]?.model === embeddingsName) {
if (selectedSource[0] && 'id' in selectedSource[0]) {
setAgent((prev) => ({
...prev,
sources: sourceIds,
source: '', // Clear single source for multiple sources
source: selectedSource[0]?.id || 'default',
retriever: '',
}));
} else {
// Single source selected - maintain backward compatibility
const selectedSource = selectedSources[0];
if (selectedSource?.model === embeddingsName) {
if (selectedSource && 'id' in selectedSource) {
setAgent((prev) => ({
...prev,
source: selectedSource?.id || 'default',
sources: [], // Clear sources array for single source
retriever: '',
}));
} else {
setAgent((prev) => ({
...prev,
source: '',
sources: [], // Clear sources array
retriever: selectedSource?.retriever || 'classic',
}));
}
}
}
} else {
// No sources selected
setAgent((prev) => ({
...prev,
source: '',
sources: [],
retriever: '',
}));
} else
setAgent((prev) => ({
...prev,
source: '',
retriever: selectedSource[0]?.retriever || 'classic',
}));
}
}, [selectedSourceIds]);
@@ -472,26 +238,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
useEffect(() => {
if (isPublishable()) dispatch(setSelectedAgent(agent));
if (!modeConfig[effectiveMode].trackChanges) {
setHasChanges(true);
return;
}
if (!initialAgentRef.current) {
setHasChanges(false);
return;
}
const initialJsonSchemaText = initialAgentRef.current.json_schema
? JSON.stringify(initialAgentRef.current.json_schema, null, 2)
: '';
const isChanged =
!isEqual(agent, initialAgentRef.current) ||
imageFile !== null ||
jsonSchemaText !== initialJsonSchemaText;
setHasChanges(isChanged);
}, [agent, dispatch, effectiveMode, imageFile, jsonSchemaText]);
}, [agent, dispatch]);
return (
<div className="p-4 md:p-12">
<div className="flex items-center gap-3 px-4">
@@ -501,24 +248,24 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
>
<img src={ArrowLeft} alt="left-arrow" className="h-3 w-3" />
</button>
<p className="text-eerie-black dark:text-bright-gray mt-px text-sm font-semibold">
<p className="mt-px text-sm font-semibold text-eerie-black dark:text-bright-gray">
Back to all agents
</p>
</div>
<div className="mt-5 flex w-full flex-wrap items-center justify-between gap-2 px-4">
<h1 className="text-eerie-black m-0 text-[40px] font-bold dark:text-white">
<h1 className="m-0 text-[40px] font-bold text-[#212121] dark:text-white">
{modeConfig[effectiveMode].heading}
</h1>
<div className="flex flex-wrap items-center gap-1">
<button
className="text-purple-30 dark:text-light-gray mr-4 rounded-3xl py-2 text-sm font-medium dark:bg-transparent"
className="mr-4 rounded-3xl py-2 text-sm font-medium text-purple-30 dark:bg-transparent dark:text-light-gray"
onClick={handleCancel}
>
Cancel
</button>
{modeConfig[effectiveMode].showDelete && agent.id && (
<button
className="group border-red-2000 text-red-2000 hover:bg-red-2000 flex items-center gap-2 rounded-3xl border border-solid px-5 py-2 text-sm font-medium transition-colors hover:text-white"
className="group flex items-center gap-2 rounded-3xl border border-solid border-red-2000 px-5 py-2 text-sm font-medium text-red-2000 transition-colors hover:bg-red-2000 hover:text-white"
onClick={() => setDeleteConfirmation('ACTIVE')}
>
<span className="block h-4 w-4 bg-[url('/src/assets/red-trash.svg')] bg-contain bg-center bg-no-repeat transition-all group-hover:bg-[url('/src/assets/white-trash.svg')]" />
@@ -527,24 +274,15 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
)}
{modeConfig[effectiveMode].showSaveDraft && (
<button
disabled={isJsonSchemaInvalid()}
className={`border-violets-are-blue text-violets-are-blue hover:bg-violets-are-blue w-28 rounded-3xl border border-solid py-2 text-sm font-medium transition-colors hover:text-white ${
isJsonSchemaInvalid() ? 'cursor-not-allowed opacity-30' : ''
}`}
className="hover:bg-vi</button>olets-are-blue rounded-3xl border border-solid border-violets-are-blue px-5 py-2 text-sm font-medium text-violets-are-blue transition-colors hover:bg-violets-are-blue hover:text-white"
onClick={handleSaveDraft}
>
<span className="flex items-center justify-center transition-all duration-200">
{draftLoading ? (
<Spinner size="small" color="#976af3" />
) : (
'Save Draft'
)}
</span>
Save Draft
</button>
)}
{modeConfig[effectiveMode].showAccessDetails && (
<button
className="group border-violets-are-blue text-violets-are-blue hover:bg-violets-are-blue flex items-center gap-2 rounded-3xl border border-solid px-5 py-2 text-sm font-medium transition-colors hover:text-white"
className="group flex items-center gap-2 rounded-3xl border border-solid border-violets-are-blue px-5 py-2 text-sm font-medium text-violets-are-blue transition-colors hover:bg-violets-are-blue hover:text-white"
onClick={() => navigate(`/agents/logs/${agent.id}`)}
>
<span className="block h-5 w-5 bg-[url('/src/assets/monitoring-purple.svg')] bg-contain bg-center bg-no-repeat transition-all group-hover:bg-[url('/src/assets/monitoring-white.svg')]" />
@@ -553,24 +291,18 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
)}
{modeConfig[effectiveMode].showAccessDetails && (
<button
className="hover:bg-vi</button>olets-are-blue border-violets-are-blue text-violets-are-blue hover:bg-violets-are-blue rounded-3xl border border-solid px-5 py-2 text-sm font-medium transition-colors hover:text-white"
className="hover:bg-vi</button>olets-are-blue rounded-3xl border border-solid border-violets-are-blue px-5 py-2 text-sm font-medium text-violets-are-blue transition-colors hover:bg-violets-are-blue hover:text-white"
onClick={() => setAgentDetails('ACTIVE')}
>
Access Details
</button>
)}
<button
disabled={!isPublishable() || !hasChanges}
className={`${!isPublishable() || !hasChanges ? 'cursor-not-allowed opacity-30' : ''} bg-purple-30 hover:bg-violets-are-blue flex w-28 items-center justify-center rounded-3xl py-2 text-sm font-medium text-white`}
disabled={!isPublishable()}
className={`${!isPublishable() && 'cursor-not-allowed opacity-30'} rounded-3xl bg-purple-30 px-5 py-2 text-sm font-medium text-white hover:bg-violets-are-blue`}
onClick={handlePublish}
>
<span className="flex items-center justify-center transition-all duration-200">
{publishLoading ? (
<Spinner size="small" color="white" />
) : (
modeConfig[effectiveMode].buttonText
)}
</span>
Publish
</button>
</div>
</div>
@@ -579,35 +311,20 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<h2 className="text-lg font-semibold">Meta</h2>
<input
className="border-silver text-jet dark:bg-raisin-black dark:text-bright-gray dark:placeholder:text-silver mt-3 w-full rounded-3xl border bg-white px-5 py-3 text-sm outline-hidden placeholder:text-gray-400 dark:border-[#7E7E7E]"
className="mt-3 w-full rounded-3xl border border-silver bg-white px-5 py-3 text-sm text-jet outline-none placeholder:text-gray-400 dark:border-[#7E7E7E] dark:bg-[#222327] dark:text-bright-gray placeholder:dark:text-silver"
type="text"
value={agent.name}
placeholder="Agent name"
onChange={(e) => setAgent({ ...agent, name: e.target.value })}
/>
<textarea
className="border-silver text-jet dark:bg-raisin-black dark:text-bright-gray dark:placeholder:text-silver mt-3 h-32 w-full rounded-xl border bg-white px-5 py-4 text-sm outline-hidden placeholder:text-gray-400 dark:border-[#7E7E7E]"
className="mt-3 h-32 w-full rounded-3xl border border-silver bg-white px-5 py-4 text-sm text-jet outline-none placeholder:text-gray-400 dark:border-[#7E7E7E] dark:bg-[#222327] dark:text-bright-gray placeholder:dark:text-silver"
placeholder="Describe your agent"
value={agent.description}
onChange={(e) =>
setAgent({ ...agent, description: e.target.value })
}
/>
<div className="mt-3">
<FileUpload
showPreview
className="dark:bg-raisin-black"
onUpload={handleUpload}
onRemove={() => setImageFile(null)}
uploadText={[
{ text: 'Click to upload', colorClass: 'text-[#7D54D1]' },
{
text: ' or drag and drop',
colorClass: 'text-[#525252]',
},
]}
/>
</div>
</div>
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<h2 className="text-lg font-semibold">Source</h2>
@@ -616,11 +333,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
<button
ref={sourceAnchorButtonRef}
onClick={() => setIsSourcePopupOpen(!isSourcePopupOpen)}
className={`border-silver dark:bg-raisin-black w-full truncate rounded-3xl border bg-white px-5 py-3 text-left text-sm dark:border-[#7E7E7E] ${
selectedSourceIds.size > 0
? 'text-jet dark:text-bright-gray'
: 'dark:text-silver text-gray-400'
}`}
className="w-full truncate rounded-3xl border border-silver bg-white px-5 py-3 text-left text-sm text-gray-400 dark:border-[#7E7E7E] dark:bg-[#222327] dark:text-silver"
>
{selectedSourceIds.size > 0
? Array.from(selectedSourceIds)
@@ -635,7 +348,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
)
.filter(Boolean)
.join(', ')
: 'Select sources'}
: 'Select source'}
</button>
<MultiSelectPopup
isOpen={isSourcePopupOpen}
@@ -651,10 +364,12 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
selectedIds={selectedSourceIds}
onSelectionChange={(newSelectedIds: Set<string | number>) => {
setSelectedSourceIds(newSelectedIds);
setIsSourcePopupOpen(false);
}}
title="Select Sources"
title="Select Source"
searchPlaceholder="Search sources..."
noOptionsMessage="No sources available"
noOptionsMessage="No source available"
singleSelect={true}
/>
</div>
<div className="mt-3">
@@ -666,47 +381,49 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
}
size="w-full"
rounded="3xl"
buttonDarkBackgroundColor="[#222327]"
border="border"
buttonClassName="bg-white dark:bg-[#222327] border-silver dark:border-[#7E7E7E]"
optionsClassName="bg-white dark:bg-[#383838] border-silver dark:border-[#7E7E7E]"
darkBorderColor="[#7E7E7E]"
placeholder="Chunks per query"
placeholderClassName="text-gray-400 dark:text-silver"
placeholderTextColor="gray-400"
darkPlaceholderTextColor="silver"
contentSize="text-sm"
/>
</div>
</div>
</div>
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<div className="flex flex-wrap items-end gap-1">
<div className="min-w-20 grow basis-full sm:basis-0">
<Prompts
prompts={prompts}
selectedPrompt={
prompts.find((prompt) => prompt.id === agent.prompt_id) ||
prompts[0]
<h2 className="text-lg font-semibold">Prompt</h2>
<div className="mt-3 flex flex-wrap items-center gap-1">
<div className="min-w-20 flex-grow basis-full sm:basis-0">
<Dropdown
options={prompts.map((prompt) => ({
label: prompt.name,
value: prompt.id,
}))}
selectedValue={
agent.prompt_id
? prompts.filter(
(prompt) => prompt.id === agent.prompt_id,
)[0]?.name || null
: null
}
onSelectPrompt={(name, id, type) =>
setAgent({ ...agent, prompt_id: id })
onSelect={(option: { label: string; value: string }) =>
setAgent({ ...agent, prompt_id: option.value })
}
setPrompts={setPrompts}
title="Prompt"
titleClassName="text-lg font-semibold dark:text-[#E0E0E0]"
showAddButton={false}
dropdownProps={{
size: 'w-full',
rounded: '3xl',
border: 'border',
buttonClassName:
'bg-white dark:bg-[#222327] border-silver dark:border-[#7E7E7E]',
optionsClassName:
'bg-white dark:bg-[#383838] border-silver dark:border-[#7E7E7E]',
placeholderClassName: 'text-gray-400 dark:text-silver',
contentSize: 'text-sm',
}}
size="w-full"
rounded="3xl"
buttonDarkBackgroundColor="[#222327]"
border="border"
darkBorderColor="[#7E7E7E]"
placeholder="Select a prompt"
placeholderTextColor="gray-400"
darkPlaceholderTextColor="silver"
contentSize="text-sm"
/>
</div>
<button
className="border-violets-are-blue text-violets-are-blue hover:bg-violets-are-blue w-20 shrink-0 basis-full rounded-3xl border-2 border-solid px-5 py-[11px] text-sm transition-colors hover:text-white sm:basis-auto"
className="w-20 flex-shrink-0 basis-full rounded-3xl border-2 border-solid border-violets-are-blue px-5 py-[11px] text-sm text-violets-are-blue transition-colors hover:bg-violets-are-blue hover:text-white sm:basis-auto"
onClick={() => setAddPromptModal('ACTIVE')}
>
Add
@@ -719,11 +436,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
<button
ref={toolAnchorButtonRef}
onClick={() => setIsToolsPopupOpen(!isToolsPopupOpen)}
className={`border-silver dark:bg-raisin-black w-full truncate rounded-3xl border bg-white px-5 py-3 text-left text-sm dark:border-[#7E7E7E] ${
selectedToolIds.size > 0
? 'text-jet dark:text-bright-gray'
: 'dark:text-silver text-gray-400'
}`}
className="w-full truncate rounded-3xl border border-silver bg-white px-5 py-3 text-left text-sm text-gray-400 dark:border-[#7E7E7E] dark:bg-[#222327] dark:text-silver"
>
{selectedToolIds.size > 0
? Array.from(selectedToolIds)
@@ -765,87 +478,16 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
}
size="w-full"
rounded="3xl"
buttonDarkBackgroundColor="[#222327]"
border="border"
buttonClassName="bg-white dark:bg-[#222327] border-silver dark:border-[#7E7E7E]"
optionsClassName="bg-white dark:bg-[#383838] border-silver dark:border-[#7E7E7E]"
darkBorderColor="[#7E7E7E]"
placeholder="Select type"
placeholderClassName="text-gray-400 dark:text-silver"
placeholderTextColor="gray-400"
darkPlaceholderTextColor="silver"
contentSize="text-sm"
/>
</div>
</div>
<div className="rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<button
onClick={() => setIsJsonSchemaExpanded(!isJsonSchemaExpanded)}
className="flex w-full items-center justify-between text-left focus:outline-none"
>
<div>
<h2 className="text-lg font-semibold">Advanced</h2>
</div>
<div className="ml-4 flex items-center">
<svg
className={`h-5 w-5 transform transition-transform duration-200 ${
isJsonSchemaExpanded ? 'rotate-180' : ''
}`}
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</button>
{isJsonSchemaExpanded && (
<div className="mt-3">
<div>
<h2 className="text-sm font-medium">JSON response schema</h2>
<p className="mt-1 text-xs text-gray-600 dark:text-gray-400">
Define a JSON schema to enforce structured output format
</p>
</div>
<textarea
value={jsonSchemaText}
onChange={(e) => validateAndSetJsonSchema(e.target.value)}
placeholder={`{
"type": "object",
"properties": {
"name": {"type": "string"},
"email": {"type": "string"}
},
"required": ["name", "email"],
"additionalProperties": false
}`}
rows={9}
className={`border-silver text-jet dark:bg-raisin-black dark:text-bright-gray mt-2 w-full rounded-2xl border bg-white px-4 py-3 font-mono text-sm outline-hidden dark:border-[#7E7E7E]`}
/>
{jsonSchemaText.trim() !== '' && (
<div
className={`mt-2 flex items-center gap-2 text-sm ${
jsonSchemaValid
? 'text-green-600 dark:text-green-400'
: 'text-red-600 dark:text-red-400'
}`}
>
<span
className={`h-4 w-4 bg-contain bg-center bg-no-repeat ${
jsonSchemaValid
? "bg-[url('/src/assets/circle-check.svg')]"
: "bg-[url('/src/assets/circle-x.svg')]"
}`}
/>
{jsonSchemaValid
? 'Valid JSON'
: 'Invalid JSON - fix to enable saving'}
</div>
)}
</div>
)}
</div>
</div>
<div className="col-span-3 flex flex-col gap-3 rounded-[30px] bg-[#F6F6F6] px-6 py-3 dark:bg-[#383838] dark:text-[#E0E0E0]">
<h2 className="text-lg font-semibold">Preview</h2>
@@ -886,7 +528,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
function AgentPreviewArea() {
const selectedAgent = useSelector(selectSelectedAgent);
return (
<div className="dark:bg-raisin-black h-full w-full rounded-[30px] border border-[#F6F6F6] bg-white max-[1180px]:h-192 dark:border-[#7E7E7E]">
<div className="h-full w-full rounded-[30px] border border-[#F6F6F6] bg-white dark:border-[#7E7E7E] dark:bg-[#222327] max-[1180px]:h-[48rem]">
{selectedAgent?.status === 'published' ? (
<div className="flex h-full w-full flex-col justify-end overflow-auto rounded-[30px]">
<AgentPreview />
@@ -894,7 +536,7 @@ function AgentPreviewArea() {
) : (
<div className="flex h-full w-full flex-col items-center justify-center gap-2">
<span className="block h-12 w-12 bg-[url('/src/assets/science-spark.svg')] bg-contain bg-center bg-no-repeat transition-all dark:bg-[url('/src/assets/science-spark-dark.svg')]" />{' '}
<p className="dark:text-gray-4000 text-xs text-[#18181B]">
<p className="text-xs text-[#18181B] dark:text-[#949494]">
Published agents can be previewed here
</p>
</div>
@@ -945,7 +587,7 @@ function AddPromptModal({
setNewPromptContent('');
onSelect?.(newPromptName, newPrompt.id, newPromptContent);
} catch (error) {
console.error('Error adding prompt:', error);
console.error(error);
}
};
return (

View File

@@ -57,7 +57,9 @@ export default function SharedAgent() {
const handleFetchAnswer = useCallback(
({ question, index }: { question: string; index?: number }) => {
fetchStream.current = dispatch(fetchAnswer({ question, indx: index }));
fetchStream.current = dispatch(
fetchAnswer({ question, indx: index, isPreview: false }),
);
},
[dispatch],
);
@@ -143,7 +145,7 @@ export default function SharedAgent() {
alt="No agent found"
className="mx-auto mb-6 h-32 w-32"
/>
<p className="dark:text-gray-4000 text-center text-lg text-[#71717A]">
<p className="text-center text-lg text-[#71717A] dark:text-[#949494]">
No agent found. Please ensure the agent is shared.
</p>
</div>
@@ -151,17 +153,13 @@ export default function SharedAgent() {
);
return (
<div className="relative h-full w-full">
<div className="absolute top-5 left-4 hidden items-center gap-3 sm:flex">
<div className="absolute left-4 top-5 hidden items-center gap-3 sm:flex">
<img
src={
sharedAgent.image && sharedAgent.image.trim() !== ''
? sharedAgent.image
: Robot
}
src={sharedAgent.image ?? Robot}
alt="agent-logo"
className="h-6 w-6 rounded-full object-contain"
className="h-6 w-6"
/>
<h2 className="text-eerie-black text-lg font-semibold dark:text-[#E0E0E0]">
<h2 className="text-lg font-semibold text-[#212121] dark:text-[#E0E0E0]">
{sharedAgent.name}
</h2>
</div>
@@ -188,7 +186,7 @@ export default function SharedAgent() {
showToolButton={sharedAgent ? false : true}
autoFocus={false}
/>
<p className="text-gray-4000 dark:text-sonic-silver hidden w-screen self-center bg-transparent py-2 text-center text-xs md:inline md:w-full">
<p className="hidden w-[100vw] self-center bg-transparent py-2 text-center text-xs text-gray-4000 dark:text-sonic-silver md:inline md:w-full">
{t('tagline')}
</p>
</div>

View File

@@ -3,19 +3,16 @@ import { Agent } from './types';
export default function SharedAgentCard({ agent }: { agent: Agent }) {
return (
<div className="border-dark-gray dark:border-grey flex w-full max-w-[720px] flex-col rounded-3xl border p-6 shadow-xs sm:w-fit sm:min-w-[480px]">
<div className="flex w-full max-w-[720px] flex-col rounded-3xl border border-dark-gray p-6 shadow-sm dark:border-grey sm:w-fit sm:min-w-[480px]">
<div className="flex items-center gap-3">
<div className="flex h-12 w-12 items-center justify-center overflow-hidden rounded-full p-1">
<img
src={agent.image && agent.image.trim() !== '' ? agent.image : Robot}
className="h-full w-full rounded-full object-contain"
/>
<img src={Robot} className="h-full w-full object-contain" />
</div>
<div className="flex max-h-[92px] w-[80%] flex-col gap-px">
<h2 className="text-eerie-black text-base font-semibold sm:text-lg dark:text-[#E0E0E0]">
<h2 className="text-base font-semibold text-[#212121] dark:text-[#E0E0E0] sm:text-lg">
{agent.name}
</h2>
<p className="dark:text-gray-4000 overflow-y-auto text-xs text-wrap break-all text-[#71717A] sm:text-sm">
<p className="overflow-y-auto text-wrap break-all text-xs text-[#71717A] dark:text-[#949494] sm:text-sm">
{agent.description}
</p>
</div>
@@ -23,12 +20,12 @@ export default function SharedAgentCard({ agent }: { agent: Agent }) {
{agent.shared_metadata && (
<div className="mt-4 flex items-center gap-8">
{agent.shared_metadata?.shared_by && (
<p className="text-eerie-black text-xs font-light sm:text-sm dark:text-[#E0E0E0]">
<p className="text-xs font-light text-[#212121] dark:text-[#E0E0E0] sm:text-sm">
by {agent.shared_metadata.shared_by}
</p>
)}
{agent.shared_metadata?.shared_at && (
<p className="dark:text-gray-4000 text-xs font-light text-[#71717A] sm:text-sm">
<p className="text-xs font-light text-[#71717A] dark:text-[#949494] sm:text-sm">
Shared on{' '}
{new Date(agent.shared_metadata.shared_at).toLocaleString(
'en-US',
@@ -47,14 +44,14 @@ export default function SharedAgentCard({ agent }: { agent: Agent }) {
)}
{agent.tool_details && agent.tool_details.length > 0 && (
<div className="mt-8">
<p className="text-eerie-black text-sm font-semibold sm:text-base dark:text-[#E0E0E0]">
<p className="text-sm font-semibold text-[#212121] dark:text-[#E0E0E0] sm:text-base">
Connected Tools
</p>
<div className="mt-2 flex flex-wrap gap-2">
{agent.tool_details.map((tool, index) => (
<span
key={index}
className="bg-bright-gray text-eerie-black dark:bg-dark-charcoal flex items-center gap-1 rounded-full px-3 py-1 text-xs font-light dark:text-[#E0E0E0]"
className="flex items-center gap-1 rounded-full bg-bright-gray px-3 py-1 text-xs font-light text-[#212121] dark:bg-dark-charcoal dark:text-[#E0E0E0]"
>
<img
src={`/toolIcons/tool_${tool.name}.svg`}

View File

@@ -1,336 +0,0 @@
import { createAsyncThunk, createSlice, PayloadAction } from '@reduxjs/toolkit';
import {
Answer,
ConversationState,
Query,
Status,
} from '../conversation/conversationModels';
import {
handleFetchAnswer,
handleFetchAnswerSteaming,
} from '../conversation/conversationHandlers';
import {
selectCompletedAttachments,
clearAttachments,
} from '../upload/uploadSlice';
import store from '../store';
const initialState: ConversationState = {
queries: [],
status: 'idle',
conversationId: null,
};
const API_STREAMING = import.meta.env.VITE_API_STREAMING === 'true';
let abortController: AbortController | null = null;
export function handlePreviewAbort() {
if (abortController) {
abortController.abort();
abortController = null;
}
}
export const fetchPreviewAnswer = createAsyncThunk<
Answer,
{ question: string; indx?: number }
>(
'agentPreview/fetchAnswer',
async ({ question, indx }, { dispatch, getState }) => {
if (abortController) abortController.abort();
abortController = new AbortController();
const { signal } = abortController;
const state = getState() as RootState;
const attachmentIds = selectCompletedAttachments(state)
.filter((a) => a.id)
.map((a) => a.id) as string[];
if (attachmentIds.length > 0) {
dispatch(clearAttachments());
}
if (state.preference) {
if (API_STREAMING) {
await handleFetchAnswerSteaming(
question,
signal,
state.preference.token,
state.preference.selectedDocs!,
null, // No conversation ID for previews
state.preference.prompt.id,
state.preference.chunks,
state.preference.token_limit,
(event: MessageEvent) => {
const data = JSON.parse(event.data);
const targetIndex = indx ?? state.agentPreview.queries.length - 1;
if (data.type === 'end') {
dispatch(agentPreviewSlice.actions.setStatus('idle'));
} else if (data.type === 'thought') {
dispatch(
updateThought({
index: targetIndex,
query: { thought: data.thought },
}),
);
} else if (data.type === 'source') {
dispatch(
updateStreamingSource({
index: targetIndex,
query: { sources: data.source ?? [] },
}),
);
} else if (data.type === 'tool_call') {
dispatch(
updateToolCall({
index: targetIndex,
tool_call: data.data,
}),
);
} else if (data.type === 'error') {
dispatch(agentPreviewSlice.actions.setStatus('failed'));
dispatch(
agentPreviewSlice.actions.raiseError({
index: targetIndex,
message: data.error,
}),
);
} else if (data.type === 'structured_answer') {
dispatch(
updateStreamingQuery({
index: targetIndex,
query: {
response: data.answer,
structured: data.structured,
schema: data.schema,
},
}),
);
} else {
dispatch(
updateStreamingQuery({
index: targetIndex,
query: { response: data.answer },
}),
);
}
},
indx,
state.preference.selectedAgent?.id,
attachmentIds,
false, // Don't save preview conversations
);
} else {
// Non-streaming implementation
const answer = await handleFetchAnswer(
question,
signal,
state.preference.token,
state.preference.selectedDocs!,
null, // No conversation ID for previews
state.preference.prompt.id,
state.preference.chunks,
state.preference.token_limit,
state.preference.selectedAgent?.id,
attachmentIds,
false, // Don't save preview conversations
);
if (answer) {
const sourcesPrepped = answer.sources.map(
(source: { title: string }) => {
if (source && source.title) {
const titleParts = source.title.split('/');
return {
...source,
title: titleParts[titleParts.length - 1],
};
}
return source;
},
);
const targetIndex = indx ?? state.agentPreview.queries.length - 1;
dispatch(
updateQuery({
index: targetIndex,
query: {
response: answer.answer,
thought: answer.thought,
sources: sourcesPrepped,
tool_calls: answer.toolCalls,
},
}),
);
dispatch(agentPreviewSlice.actions.setStatus('idle'));
}
}
}
return {
conversationId: null,
title: null,
answer: '',
query: question,
result: '',
thought: '',
sources: [],
tool_calls: [],
};
},
);
export const agentPreviewSlice = createSlice({
name: 'agentPreview',
initialState,
reducers: {
addQuery(state, action: PayloadAction<Query>) {
state.queries.push(action.payload);
},
resendQuery(
state,
action: PayloadAction<{ index: number; prompt: string; query?: Query }>,
) {
state.queries = [
...state.queries.splice(0, action.payload.index),
action.payload,
];
},
updateStreamingQuery(
state,
action: PayloadAction<{
index: number;
query: Partial<Query>;
}>,
) {
const { index, query } = action.payload;
if (state.status === 'idle') return;
if (query.response != undefined) {
state.queries[index].response =
(state.queries[index].response || '') + query.response;
}
if (query.structured !== undefined) {
state.queries[index].structured = query.structured;
}
if (query.schema !== undefined) {
state.queries[index].schema = query.schema;
}
},
updateThought(
state,
action: PayloadAction<{
index: number;
query: Partial<Query>;
}>,
) {
const { index, query } = action.payload;
if (query.thought != undefined) {
state.queries[index].thought =
(state.queries[index].thought || '') + query.thought;
}
},
updateStreamingSource(
state,
action: PayloadAction<{
index: number;
query: Partial<Query>;
}>,
) {
const { index, query } = action.payload;
if (!state.queries[index].sources) {
state.queries[index].sources = query?.sources;
} else if (query.sources) {
state.queries[index].sources!.push(...query.sources);
}
},
updateToolCall(state, action) {
const { index, tool_call } = action.payload;
if (!state.queries[index].tool_calls) {
state.queries[index].tool_calls = [];
}
const existingIndex = state.queries[index].tool_calls.findIndex(
(call) => call.call_id === tool_call.call_id,
);
if (existingIndex !== -1) {
const existingCall = state.queries[index].tool_calls[existingIndex];
state.queries[index].tool_calls[existingIndex] = {
...existingCall,
...tool_call,
};
} else state.queries[index].tool_calls.push(tool_call);
},
updateQuery(
state,
action: PayloadAction<{ index: number; query: Partial<Query> }>,
) {
const { index, query } = action.payload;
state.queries[index] = {
...state.queries[index],
...query,
};
},
setStatus(state, action: PayloadAction<Status>) {
state.status = action.payload;
},
raiseError(
state,
action: PayloadAction<{
index: number;
message: string;
}>,
) {
const { index, message } = action.payload;
state.queries[index].error = message;
},
resetPreview: (state) => {
state.queries = initialState.queries;
state.status = initialState.status;
state.conversationId = initialState.conversationId;
handlePreviewAbort();
},
},
extraReducers(builder) {
builder
.addCase(fetchPreviewAnswer.pending, (state) => {
state.status = 'loading';
})
.addCase(fetchPreviewAnswer.rejected, (state, action) => {
if (action.meta.aborted) {
state.status = 'idle';
return state;
}
state.status = 'failed';
state.queries[state.queries.length - 1].error = 'Something went wrong';
});
},
});
type RootState = ReturnType<typeof store.getState>;
export const selectPreviewQueries = (state: RootState) =>
state.agentPreview.queries;
export const selectPreviewStatus = (state: RootState) =>
state.agentPreview.status;
export const {
addQuery,
updateQuery,
resendQuery,
updateStreamingQuery,
updateThought,
updateStreamingSource,
updateToolCall,
setStatus,
raiseError,
resetPreview,
} = agentPreviewSlice.actions;
export default agentPreviewSlice.reducer;

View File

@@ -111,10 +111,10 @@ function AgentsList() {
}, [token]);
return (
<div className="p-4 md:p-12">
<h1 className="text-eerie-black mb-0 text-[40px] font-bold dark:text-[#E0E0E0]">
<h1 className="mb-0 text-[40px] font-bold text-[#212121] dark:text-[#E0E0E0]">
Agents
</h1>
<p className="dark:text-gray-4000 mt-5 text-[15px] text-[#71717A]">
<p className="mt-5 text-[15px] text-[#71717A] dark:text-[#949494]">
Discover and create custom versions of DocsGPT that combine
instructions, extra knowledge, and any combination of skills
</p>
@@ -206,7 +206,7 @@ function AgentSection({
</div>
{sectionConfig[section].showNewAgentButton && (
<button
className="bg-purple-30 hover:bg-violets-are-blue rounded-full px-4 py-2 text-sm text-white"
className="rounded-full bg-purple-30 px-4 py-2 text-sm text-white hover:bg-violets-are-blue"
onClick={() => navigate('/agents/new')}
>
New Agent
@@ -235,7 +235,7 @@ function AgentSection({
<p>{sectionConfig[section].emptyStateDescription}</p>
{sectionConfig[section].showNewAgentButton && (
<button
className="bg-purple-30 hover:bg-violets-are-blue ml-2 rounded-full px-4 py-2 text-sm text-white"
className="ml-2 rounded-full bg-purple-30 px-4 py-2 text-sm text-white hover:bg-violets-are-blue"
onClick={() => navigate('/agents/new')}
>
New Agent
@@ -324,21 +324,17 @@ function AgentCard({
iconWidth: 14,
iconHeight: 14,
},
...(agent.status === 'published'
? [
{
icon: agent.pinned ? UnPin : Pin,
label: agent.pinned ? 'Unpin' : 'Pin agent',
onClick: (e: SyntheticEvent) => {
e.stopPropagation();
togglePin();
},
variant: 'primary' as const,
iconWidth: 18,
iconHeight: 18,
},
]
: []),
{
icon: agent.pinned ? UnPin : Pin,
label: agent.pinned ? 'Unpin' : 'Pin agent',
onClick: (e: SyntheticEvent) => {
e.stopPropagation();
togglePin();
},
variant: 'primary',
iconWidth: 18,
iconHeight: 18,
},
{
icon: Trash,
label: 'Delete',
@@ -410,7 +406,7 @@ function AgentCard({
};
return (
<div
className={`relative flex h-44 w-full flex-col justify-between rounded-[1.2rem] bg-[#F6F6F6] px-6 py-5 hover:bg-[#ECECEC] md:w-48 dark:bg-[#383838] dark:hover:bg-[#383838]/80 ${agent.status === 'published' && 'cursor-pointer'}`}
className={`relative flex h-44 w-full flex-col justify-between rounded-[1.2rem] bg-[#F6F6F6] px-6 py-5 hover:bg-[#ECECEC] dark:bg-[#383838] hover:dark:bg-[#383838]/80 md:w-48 ${agent.status === 'published' && 'cursor-pointer'}`}
onClick={(e) => {
e.stopPropagation();
handleClick();
@@ -422,7 +418,7 @@ function AgentCard({
e.stopPropagation();
setIsMenuOpen(true);
}}
className="absolute top-4 right-4 z-10 cursor-pointer"
className="absolute right-4 top-4 z-10 cursor-pointer"
>
<img src={ThreeDots} alt={'use-agent'} className="h-[19px] w-[19px]" />
<ContextMenu
@@ -430,16 +426,16 @@ function AgentCard({
setIsOpen={setIsMenuOpen}
options={menuOptions}
anchorRef={menuRef}
position="bottom-right"
position="top-right"
offset={{ x: 0, y: 0 }}
/>
</div>
<div className="w-full">
<div className="flex w-full items-center gap-1 px-1">
<img
src={agent.image && agent.image.trim() !== '' ? agent.image : Robot}
src={agent.image ?? Robot}
alt={`${agent.name}`}
className="h-7 w-7 rounded-full object-contain"
className="h-7 w-7 rounded-full"
/>
{agent.status === 'draft' && (
<p className="text-xs text-black opacity-50 dark:text-[#E0E0E0]">{`(Draft)`}</p>
@@ -448,11 +444,11 @@ function AgentCard({
<div className="mt-2">
<p
title={agent.name}
className="truncate px-1 text-[13px] leading-relaxed font-semibold text-[#020617] capitalize dark:text-[#E0E0E0]"
className="truncate px-1 text-[13px] font-semibold capitalize leading-relaxed text-[#020617] dark:text-[#E0E0E0]"
>
{agent.name}
</p>
<p className="dark:text-sonic-silver-light mt-1 h-20 overflow-auto px-1 text-[12px] leading-relaxed text-[#64748B]">
<p className="mt-1 h-20 overflow-auto px-1 text-[12px] leading-relaxed text-[#64748B] dark:text-sonic-silver-light">
{agent.description}
</p>
</div>

View File

@@ -10,7 +10,6 @@ export type Agent = {
description: string;
image: string;
source: string;
sources?: string[];
chunks: string;
retriever: string;
prompt_id: string;
@@ -27,5 +26,4 @@ export type Agent = {
created_at?: string;
updated_at?: string;
last_used_at?: string;
json_schema?: object;
};

View File

@@ -1,21 +1,16 @@
export const baseURL =
import.meta.env.VITE_API_HOST || 'https://docsapi.arc53.com';
const getHeaders = (
token: string | null,
customHeaders = {},
isFormData = false,
): HeadersInit => {
const headers: HeadersInit = {
const defaultHeaders = {
'Content-Type': 'application/json',
};
const getHeaders = (token: string | null, customHeaders = {}): HeadersInit => {
return {
...defaultHeaders,
...(token ? { Authorization: `Bearer ${token}` } : {}),
...customHeaders,
};
if (!isFormData) {
headers['Content-Type'] = 'application/json';
}
return headers;
};
const apiClient = {
@@ -49,21 +44,6 @@ const apiClient = {
return response;
}),
postFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> => {
return fetch(`${baseURL}${url}`, {
method: 'POST',
headers: getHeaders(token, headers, true),
body: formData,
signal,
});
},
put: (
url: string,
data: any,
@@ -80,21 +60,6 @@ const apiClient = {
return response;
}),
putFormData: (
url: string,
formData: FormData,
token: string | null,
headers = {},
signal?: AbortSignal,
): Promise<Response> => {
return fetch(`${baseURL}${url}`, {
method: 'PUT',
headers: getHeaders(token, headers, true),
body: formData,
signal,
});
},
delete: (
url: string,
token: string | null,

View File

@@ -38,29 +38,13 @@ const endpoints = {
UPDATE_TOOL_STATUS: '/api/update_tool_status',
UPDATE_TOOL: '/api/update_tool',
DELETE_TOOL: '/api/delete_tool',
SYNC_CONNECTOR: '/api/connectors/sync',
GET_CHUNKS: (
docId: string,
page: number,
per_page: number,
path?: string,
search?: string,
) =>
`/api/get_chunks?id=${docId}&page=${page}&per_page=${per_page}${
path ? `&path=${encodeURIComponent(path)}` : ''
}${search ? `&search=${encodeURIComponent(search)}` : ''}`,
GET_CHUNKS: (docId: string, page: number, per_page: number) =>
`/api/get_chunks?id=${docId}&page=${page}&per_page=${per_page}`,
ADD_CHUNK: '/api/add_chunk',
DELETE_CHUNK: (docId: string, chunkId: string) =>
`/api/delete_chunk?id=${docId}&chunk_id=${chunkId}`,
UPDATE_CHUNK: '/api/update_chunk',
STORE_ATTACHMENT: '/api/store_attachment',
DIRECTORY_STRUCTURE: (docId: string) =>
`/api/directory_structure?id=${docId}`,
MANAGE_SOURCE_FILES: '/api/manage_source_files',
MCP_TEST_CONNECTION: '/api/mcp_server/test',
MCP_SAVE_SERVER: '/api/mcp_server/save',
MCP_OAUTH_STATUS: (task_id: string) =>
`/api/mcp_server/oauth_status/${task_id}`,
},
CONVERSATION: {
ANSWER: '/api/answer',

View File

@@ -1,4 +1,3 @@
import { getSessionToken } from '../../utils/providerUtils';
import apiClient from '../client';
import endpoints from '../endpoints';
@@ -23,13 +22,13 @@ const userService = {
getAgents: (token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.AGENTS, token),
createAgent: (data: any, token: string | null): Promise<any> =>
apiClient.postFormData(endpoints.USER.CREATE_AGENT, data, token),
apiClient.post(endpoints.USER.CREATE_AGENT, data, token),
updateAgent: (
agent_id: string,
data: any,
token: string | null,
): Promise<any> =>
apiClient.putFormData(endpoints.USER.UPDATE_AGENT(agent_id), data, token),
apiClient.put(endpoints.USER.UPDATE_AGENT(agent_id), data, token),
deleteAgent: (id: string, token: string | null): Promise<any> =>
apiClient.delete(endpoints.USER.DELETE_AGENT(id), token),
getPinnedAgents: (token: string | null): Promise<any> =>
@@ -87,13 +86,8 @@ const userService = {
page: number,
perPage: number,
token: string | null,
path?: string,
search?: string,
): Promise<any> =>
apiClient.get(
endpoints.USER.GET_CHUNKS(docId, page, perPage, path, search),
token,
),
apiClient.get(endpoints.USER.GET_CHUNKS(docId, page, perPage), token),
addChunk: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.ADD_CHUNK, data, token),
deleteChunk: (
@@ -104,32 +98,6 @@ const userService = {
apiClient.delete(endpoints.USER.DELETE_CHUNK(docId, chunkId), token),
updateChunk: (data: any, token: string | null): Promise<any> =>
apiClient.put(endpoints.USER.UPDATE_CHUNK, data, token),
getDirectoryStructure: (docId: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.DIRECTORY_STRUCTURE(docId), token),
manageSourceFiles: (data: FormData, token: string | null): Promise<any> =>
apiClient.postFormData(endpoints.USER.MANAGE_SOURCE_FILES, data, token),
testMCPConnection: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_TEST_CONNECTION, data, token),
saveMCPServer: (data: any, token: string | null): Promise<any> =>
apiClient.post(endpoints.USER.MCP_SAVE_SERVER, data, token),
getMCPOAuthStatus: (task_id: string, token: string | null): Promise<any> =>
apiClient.get(endpoints.USER.MCP_OAUTH_STATUS(task_id), token),
syncConnector: (
docId: string,
provider: string,
token: string | null,
): Promise<any> => {
const sessionToken = getSessionToken(provider);
return apiClient.post(
endpoints.USER.SYNC_CONNECTOR,
{
source_id: docId,
session_token: sessionToken,
provider: provider,
},
token,
);
},
};
export default userService;

View File

@@ -1,4 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M6 7.5C6 7.36739 5.94732 7.24021 5.85355 7.14645C5.75979 7.05268 5.63261 7 5.5 7H4.5C4.36739 7 4.24021 7.05268 4.14645 7.14645C4.05268 7.24021 4 7.36739 4 7.5V8.5C4 8.63261 4.05268 8.75979 4.14645 8.85355C4.24021 8.94732 4.36739 9 4.5 9H5.5C5.63261 9 5.75979 8.94732 5.85355 8.85355C5.94732 8.75979 6 8.63261 6 8.5V7.5ZM6 10.5C6 10.3674 5.94732 10.2402 5.85355 10.1464C5.75979 10.0527 5.63261 10 5.5 10H4.5C4.36739 10 4.24021 10.0527 4.14645 10.1464C4.05268 10.2402 4 10.3674 4 10.5V11.5C4 11.6326 4.05268 11.7598 4.14645 11.8536C4.24021 11.9473 4.36739 12 4.5 12H5.5C5.63261 12 5.75979 11.9473 5.85355 11.8536C5.94732 11.7598 6 11.6326 6 11.5V10.5ZM7.5 7H8.5C8.63261 7 8.75979 7.05268 8.85355 7.14645C8.94732 7.24021 9 7.36739 9 7.5V8.5C9 8.63261 8.94732 8.75979 8.85355 8.85355C8.75979 8.94732 8.63261 9 8.5 9H7.5C7.36739 9 7.24021 8.94732 7.14645 8.85355C7.05268 8.75979 7 8.63261 7 8.5V7.5C7 7.36739 7.05268 7.24021 7.14645 7.14645C7.24021 7.05268 7.36739 7 7.5 7ZM8.5 10H7.5C7.36739 10 7.24021 10.0527 7.14645 10.1464C7.05268 10.2402 7 10.3674 7 10.5V11.5C7 11.6326 7.05268 11.7598 7.14645 11.8536C7.24021 11.9473 7.36739 12 7.5 12H8.5C8.63261 12 8.75979 11.9473 8.85355 11.8536C8.94732 11.7598 9 11.6326 9 11.5V10.5C9 10.3674 8.94732 10.2402 8.85355 10.1464C8.75979 10.0527 8.63261 10 8.5 10ZM10 7.5C10 7.36739 10.0527 7.24021 10.1464 7.14645C10.2402 7.05268 10.3674 7 10.5 7H11.5C11.6326 7 11.7598 7.05268 11.8536 7.14645C11.9473 7.24021 12 7.36739 12 7.5V8.5C12 8.63261 11.9473 8.75979 11.8536 8.85355C11.7598 8.94732 11.6326 9 11.5 9H10.5C10.3674 9 10.2402 8.94732 10.1464 8.85355C10.0527 8.75979 10 8.63261 10 8.5V7.5Z" fill="#848484"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M4.5 0C4.63261 0 4.75979 0.0526784 4.85355 0.146447C4.94732 0.240215 5 0.367392 5 0.5V1H11V0.5C11 0.367392 11.0527 0.240215 11.1464 0.146447C11.2402 0.0526784 11.3674 0 11.5 0C11.6326 0 11.7598 0.0526784 11.8536 0.146447C11.9473 0.240215 12 0.367392 12 0.5V1C13.66 1 15 2.34 15 4V12C15 13.66 13.66 15 12 15H4C2.34 15 1 13.66 1 12V4C1 2.34 2.34 1 4 1V0.5C4 0.367392 4.05268 0.240215 4.14645 0.146447C4.24021 0.0526784 4.36739 0 4.5 0ZM14 4V5H2V4C2 2.9 2.895 2 4 2V2.5C4 2.63261 4.05268 2.75979 4.14645 2.85355C4.24021 2.94732 4.36739 3 4.5 3C4.63261 3 4.75979 2.94732 4.85355 2.85355C4.94732 2.75979 5 2.63261 5 2.5V2H11V2.5C11 2.63261 11.0527 2.75979 11.1464 2.85355C11.2402 2.94732 11.3674 3 11.5 3C11.6326 3 11.7598 2.94732 11.8536 2.85355C11.9473 2.75979 12 2.63261 12 2.5V2C13.1 2 14 2.895 14 4ZM2 12V6H14V12C14 13.1 13.105 14 12 14H4C2.9 14 2 13.105 2 12Z" fill="#848484"/>
</svg>

Before

Width:  |  Height:  |  Size: 2.6 KiB

View File

@@ -1 +1 @@
<svg width="16px" height="16px" viewBox="0 0 1024 1024" class="icon" version="1.1" xmlns="http://www.w3.org/2000/svg" fill="#11ee1c" stroke="#11ee1c" stroke-width="83.96799999999999"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path d="M866.133333 258.133333L362.666667 761.6l-204.8-204.8L98.133333 618.666667 362.666667 881.066667l563.2-563.2z" fill="#0C9D35"></path></g></svg>
<svg width="16px" height="16px" viewBox="0 0 1024 1024" class="icon" version="1.1" xmlns="http://www.w3.org/2000/svg" fill="#11ee1c" stroke="#11ee1c" stroke-width="83.96799999999999"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"><path d="M866.133333 258.133333L362.666667 761.6l-204.8-204.8L98.133333 618.666667 362.666667 881.066667l563.2-563.2z" fill="#11ee1c"></path></g></svg>

Before

Width:  |  Height:  |  Size: 490 B

After

Width:  |  Height:  |  Size: 490 B

View File

@@ -1,3 +0,0 @@
<svg width="22" height="22" viewBox="0 0 22 22" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M20.2891 15.81L21.7091 14.39L18.4991 11.21L15.4991 10.36L17.4091 10.1L21.5991 6.89999L20.3991 5.29998L16.5891 8.14999L13.9091 8.59999L17.1091 5.40999L15.9991 0.859985L13.9991 1.33999L14.8591 4.78999L13.7591 5.92999C13.5285 5.38882 13.144 4.92736 12.6533 4.60302C12.1625 4.27867 11.5873 4.10574 10.9991 4.10574C10.4108 4.10574 9.83559 4.27867 9.34487 4.60302C8.85414 4.92736 8.4696 5.38882 8.23906 5.92999L7.10906 4.78999L7.99906 1.33999L5.99906 0.859985L4.88906 5.40999L8.08906 8.59999L5.39906 8.14999L1.59906 5.29998L0.399063 6.89999L4.59906 10.1L6.45906 10.41L3.45906 11.26L0.289062 14.39L1.70906 15.81L4.49906 12.99L6.86906 12.32L2.99906 15.64V21.1H4.99906V16.56L6.55906 15.22C6.73264 16.2723 7.27432 17.2287 8.08751 17.9188C8.90071 18.6088 9.93255 18.9876 10.9991 18.9876C12.0656 18.9876 13.0974 18.6088 13.9106 17.9188C14.7238 17.2287 15.2655 16.2723 15.4391 15.22L16.9991 16.56V21.1H18.9991V15.64L15.1291 12.32L17.4991 12.99L20.2891 15.81Z" fill="black"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.0 KiB

Some files were not shown because too many files have changed in this diff Show More