Compare commits

...

31 Commits

Author SHA1 Message Date
Alex
fcdb4fb5e8 feat: faster ebook parsing 2026-04-09 18:31:06 +01:00
Alex
e787c896eb upd Security.md 2026-04-08 12:49:20 +01:00
Alex
23aeaff5db Merge pull request #2362 from arc53/v1-mini-improvements
feat: history overwrite
2026-04-06 15:02:32 +01:00
Alex
689dd79597 fix: lang 2026-04-06 14:57:51 +01:00
Alex
0c15af90b1 feat: history overwrite 2026-04-06 14:42:01 +01:00
Alex
cdd6ff6557 chore: bump deps 2026-04-04 12:45:34 +01:00
Alex
72b3d94453 fix: tests 2026-04-03 18:30:46 +01:00
Alex
7e88d09e5d Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-04-03 18:26:37 +01:00
Alex
74a4a237dc fix: bump deps 2026-04-03 18:26:29 +01:00
Alex
c3f01c6619 Merge pull request #2347 from ManishMadan2882/main
Minor frontend updates
2026-04-03 18:17:27 +01:00
Alex
6b408823d4 fix: mini theme color edits 2026-04-03 18:16:07 +01:00
Alex
3fc81ac5d8 fix: clean error 2026-04-03 18:08:38 +01:00
Alex
2652f8a5b0 fix: chatwoot 2026-04-03 18:04:49 +01:00
Alex
d711eefe96 patch: agent usage limits 2026-04-03 18:03:31 +01:00
Alex
79206f3919 fix: harden faiss 2026-04-03 17:57:49 +01:00
Alex
de971d9452 fix: validate mcp url 2026-04-03 17:52:48 +01:00
Alex
1b4d5ca0dd patch: mcp identity 2026-04-03 17:40:22 +01:00
Alex
81989e8258 fix: patch /v1/models 2026-04-03 17:37:09 +01:00
Alex
dc262d1698 patch: error 2026-04-03 17:30:23 +01:00
Alex
69f9c93869 patch: s3 2026-04-03 17:28:09 +01:00
Alex
74bf80b25c patch: sharing convos 2026-04-03 17:20:06 +01:00
Alex
d9a92a7208 feat: improve setup scripts 2026-04-03 17:15:21 +01:00
Alex
02e93d993d patch: available tools 2026-04-03 17:12:36 +01:00
Alex
6b6495f48c patch: key 2026-04-03 17:06:35 +01:00
Alex
249dd9ce37 patch: paths 2026-04-03 16:45:03 +01:00
Alex
9134ab0478 Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-04-03 16:40:50 +01:00
Alex
10ef68c9d0 Revise vulnerability reporting process
Updated vulnerability reporting instructions to use GitHub's private reporting flow.
2026-04-03 16:36:10 +01:00
Alex
7d65cf1c2b chore: bump deps 2026-04-03 16:35:10 +01:00
Alex
13c6cc59c1 Merge pull request #2349 from arc53/messages-format
Messages format
2026-04-03 16:26:57 +01:00
ManishMadan2882
648b3f1d20 (fix) lint/fe 2026-04-01 03:30:44 +05:30
ManishMadan2882
a75a9e23f9 (feat:fe) minor good things 2026-04-01 03:19:03 +05:30
54 changed files with 1256 additions and 304 deletions

View File

@@ -2,13 +2,21 @@
## Supported Versions
Supported Versions:
Currently, we support security patches by committing changes and bumping the version published on GitHub.
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.
## Reporting a Vulnerability
Found a vulnerability? Please email us:
Preferred method: use GitHub's private vulnerability reporting flow:
https://github.com/arc53/DocsGPT/security
security@arc53.com
Then click **Report a vulnerability**.
Alternatively, email us at: security@arc53.com
We aim to acknowledge reports within 48 hours.
## Incident Handling
Arc53 maintains internal incident response procedures. If you believe an active exploit is occurring, include **URGENT** in your report subject line.

View File

@@ -24,6 +24,7 @@ from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
@@ -61,7 +62,8 @@ class MCPTool(Tool):
"""
self.config = config
self.user_id = user_id
self.server_url = config.get("server_url", "")
raw_url = config.get("server_url", "")
self.server_url = self._validate_server_url(raw_url) if raw_url else ""
self.transport_type = config.get("transport_type", "auto")
self.auth_type = config.get("auth_type", "none")
self.timeout = config.get("timeout", 30)
@@ -87,6 +89,18 @@ class MCPTool(Tool):
if self.server_url and self.auth_type != "oauth":
self._setup_client()
@staticmethod
def _validate_server_url(server_url: str) -> str:
    """Run SSRF validation on an MCP server URL before it is stored.

    Returns:
        The validated URL, unchanged, when it passes the check.

    Raises:
        ValueError: If the URL points to a private/internal address.
    """
    try:
        return validate_url(server_url)
    except SSRFError as exc:
        # Surface a caller-friendly ValueError while keeping the
        # original SSRF failure chained for debugging.
        raise ValueError(f"Invalid MCP server URL: {exc}") from exc
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
if configured_redirect_uri:
return configured_redirect_uri.rstrip("/")
@@ -108,8 +122,9 @@ class MCPTool(Tool):
auth_key = ""
if self.auth_type == "oauth":
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
oauth_identity = self.user_id or self.oauth_task_id or "anonymous"
auth_key = (
f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
f"oauth:{oauth_identity}:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
)
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(

View File

@@ -85,6 +85,10 @@ class AnswerResource(Resource, BaseAnswerResource):
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question="",
agent=agent,

View File

@@ -92,6 +92,14 @@ class StreamResource(Resource, BaseAnswerResource):
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
status=401,
mimetype="text/event-stream",
)
if error := self.check_usage(processor.agent_config):
return error
return Response(
self.complete_stream(
question="",

View File

@@ -112,6 +112,7 @@ class StreamProcessor:
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
self.compressed_summary: Optional[str] = None
self.compressed_summary_tokens: int = 0
self._agent_data: Optional[Dict[str, Any]] = None
def initialize(self):
"""Initialize all required components for processing"""
@@ -359,22 +360,29 @@ class StreamProcessor:
return data
def _configure_source(self):
"""Configure the source based on agent data"""
api_key = self.data.get("api_key") or self.agent_key
"""Configure the source based on agent data.
if api_key:
agent_data = self._get_data_from_api_key(api_key)
The literal string ``"default"`` is a placeholder meaning "no
ingested source" and is normalized to an empty source so that no
retrieval is attempted.
"""
if self._agent_data:
agent_data = self._agent_data
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
source_ids = [
source["id"] for source in agent_data["sources"] if source.get("id")
source["id"]
for source in agent_data["sources"]
if source.get("id") and source["id"] != "default"
]
if source_ids:
self.source = {"active_docs": source_ids}
else:
self.source = {}
self.all_sources = agent_data["sources"]
elif agent_data.get("source"):
self.all_sources = [
s for s in agent_data["sources"] if s.get("id") != "default"
]
elif agent_data.get("source") and agent_data["source"] != "default":
self.source = {"active_docs": agent_data["source"]}
self.all_sources = [
{
@@ -387,11 +395,24 @@ class StreamProcessor:
self.all_sources = []
return
if "active_docs" in self.data:
self.source = {"active_docs": self.data["active_docs"]}
active_docs = self.data["active_docs"]
if active_docs and active_docs != "default":
self.source = {"active_docs": active_docs}
else:
self.source = {}
return
self.source = {}
self.all_sources = []
def _has_active_docs(self) -> bool:
"""Return True if a real document source is configured for retrieval."""
active_docs = self.source.get("active_docs") if self.source else None
if not active_docs:
return False
if active_docs == "default":
return False
return True
def _resolve_agent_id(self) -> Optional[str]:
"""Resolve agent_id from request, then fall back to conversation context."""
request_agent_id = self.data.get("agent_id")
@@ -433,48 +454,39 @@ class StreamProcessor:
effective_key = self.data.get("api_key") or self.agent_key
if effective_key:
data_key = self._get_data_from_api_key(effective_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self._agent_data = self._get_data_from_api_key(effective_key)
if self._agent_data.get("_id"):
self.agent_id = str(self._agent_data.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"prompt_id": self._agent_data.get("prompt_id", "default"),
"agent_type": self._agent_data.get("agent_type", settings.AGENT_NAME),
"user_api_key": effective_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
"models": data_key.get("models", []),
"json_schema": self._agent_data.get("json_schema"),
"default_model_id": self._agent_data.get("default_model_id", ""),
"models": self._agent_data.get("models", []),
"allow_system_prompt_override": self._agent_data.get(
"allow_system_prompt_override", False
),
}
)
# Set identity context
if self.data.get("api_key"):
# External API key: use the key owner's identity
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
self.initial_user_id = self._agent_data.get("user")
self.decoded_token = {"sub": self._agent_data.get("user")}
elif self.is_shared_usage:
# Shared agent: keep the caller's identity
pass
else:
# Owner using their own agent
self.decoded_token = {"sub": data_key.get("user")}
self.decoded_token = {"sub": self._agent_data.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
if self._agent_data.get("workflow"):
self.agent_config["workflow"] = self._agent_data["workflow"]
self.agent_config["workflow_owner"] = self._agent_data.get("user")
else:
# No API key — default/workflow configuration
agent_type = settings.AGENT_NAME
@@ -497,14 +509,45 @@ class StreamProcessor:
)
def _configure_retriever(self):
"""Assemble retriever config with precedence: request > agent > default."""
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
# Start with defaults
retriever_name = "classic"
chunks = 2
# Layer agent-level config (if present)
if self._agent_data:
if self._agent_data.get("retriever"):
retriever_name = self._agent_data["retriever"]
if self._agent_data.get("chunks") is not None:
try:
chunks = int(self._agent_data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid agent chunks value: {self._agent_data['chunks']}, "
"using default value 2"
)
# Explicit request values win over agent config
if "retriever" in self.data:
retriever_name = self.data["retriever"]
if "chunks" in self.data:
try:
chunks = int(self.data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid request chunks value: {self.data['chunks']}, "
"using default value 2"
)
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"retriever_name": retriever_name,
"chunks": chunks,
"doc_token_limit": doc_token_limit,
}
# isNoneDoc without an API key forces no retrieval
api_key = self.data.get("api_key") or self.agent_key
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
@@ -528,6 +571,9 @@ class StreamProcessor:
if self.data.get("isNoneDoc", False) and not self.agent_id:
logger.info("Pre-fetch skipped: isNoneDoc=True")
return None, None
if not self._has_active_docs():
logger.info("Pre-fetch skipped: no active docs configured")
return None, None
try:
retriever = self.create_retriever()
logger.info(
@@ -910,15 +956,23 @@ class StreamProcessor:
raw_prompt = get_prompt(prompt_id, self.prompts_collection)
self._prompt_content = raw_prompt
rendered_prompt = self.prompt_renderer.render_prompt(
prompt_content=raw_prompt,
user_id=self.initial_user_id,
request_id=self.data.get("request_id"),
passthrough_data=self.data.get("passthrough"),
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
# Allow API callers to override the system prompt when the agent
# has opted in via allow_system_prompt_override.
if (
self.agent_config.get("allow_system_prompt_override", False)
and self.data.get("system_prompt_override")
):
rendered_prompt = self.data["system_prompt_override"]
else:
rendered_prompt = self.prompt_renderer.render_prompt(
prompt_content=raw_prompt,
user_id=self.initial_user_id,
request_id=self.data.get("request_id"),
passthrough_data=self.data.get("passthrough"),
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
provider = (
get_provider_from_model_id(self.model_id)

View File

@@ -26,12 +26,20 @@ internal = Blueprint("internal", __name__)
@internal.before_request
def verify_internal_key():
"""Verify INTERNAL_KEY for all internal endpoint requests."""
if settings.INTERNAL_KEY:
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
"""Verify INTERNAL_KEY for all internal endpoint requests.
Deny by default: if INTERNAL_KEY is not configured, reject all requests.
"""
if not settings.INTERNAL_KEY:
logger.warning(
f"Internal API request rejected from {request.remote_addr}: "
"INTERNAL_KEY is not configured"
)
return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
@internal.route("/api/download", methods=["get"])

View File

@@ -73,6 +73,7 @@ AGENT_TYPE_SCHEMAS = {
"token_limit",
"limited_request_mode",
"request_limit",
"allow_system_prompt_override",
"createdAt",
"updatedAt",
"lastUsedAt",
@@ -96,6 +97,7 @@ AGENT_TYPE_SCHEMAS = {
"token_limit",
"limited_request_mode",
"request_limit",
"allow_system_prompt_override",
"createdAt",
"updatedAt",
"lastUsedAt",
@@ -220,6 +222,12 @@ def build_agent_document(
base_doc["request_limit"] = int(
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
if "allow_system_prompt_override" in allowed_fields:
base_doc["allow_system_prompt_override"] = (
data.get("allow_system_prompt_override") == "True"
if isinstance(data.get("allow_system_prompt_override"), str)
else bool(data.get("allow_system_prompt_override", False))
)
return {k: v for k, v in base_doc.items() if k in allowed_fields}
@@ -292,6 +300,9 @@ class GetAgent(Resource):
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
"allow_system_prompt_override": agent.get(
"allow_system_prompt_override", False
),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -373,6 +384,9 @@ class GetAgents(Resource):
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
"allow_system_prompt_override": agent.get(
"allow_system_prompt_override", False
),
}
for agent in agents
if "source" in agent
@@ -450,6 +464,10 @@ class CreateAgent(Resource):
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
"allow_system_prompt_override": fields.Boolean(
required=False,
description="Allow API callers to override the system prompt via the v1 endpoint",
),
},
)
@@ -491,9 +509,9 @@ class CreateAgent(Resource):
data["json_schema"] = normalize_json_schema_payload(
data.get("json_schema")
)
except JsonSchemaValidationError as exc:
except JsonSchemaValidationError:
return make_response(
jsonify({"success": False, "message": f"JSON schema {exc}"}),
jsonify({"success": False, "message": "Invalid JSON schema"}),
400,
)
if data.get("status") not in ["draft", "published"]:
@@ -674,6 +692,10 @@ class UpdateAgent(Resource):
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
"allow_system_prompt_override": fields.Boolean(
required=False,
description="Allow API callers to override the system prompt via the v1 endpoint",
),
},
)
@@ -765,6 +787,7 @@ class UpdateAgent(Resource):
"default_model_id",
"folder_id",
"workflow",
"allow_system_prompt_override",
]
for field in allowed_fields:
@@ -872,9 +895,9 @@ class UpdateAgent(Resource):
update_fields[field] = normalize_json_schema_payload(
json_schema
)
except JsonSchemaValidationError as exc:
except JsonSchemaValidationError:
return make_response(
jsonify({"success": False, "message": f"JSON schema {exc}"}),
jsonify({"success": False, "message": "Invalid JSON schema"}),
400,
)
else:
@@ -983,6 +1006,13 @@ class UpdateAgent(Resource):
if workflow_error:
return workflow_error
update_fields[field] = workflow_id
elif field == "allow_system_prompt_override":
raw_value = data.get("allow_system_prompt_override", False)
update_fields[field] = (
raw_value == "True"
if isinstance(raw_value, str)
else bool(raw_value)
)
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:

View File

@@ -612,6 +612,10 @@ class LiveSpeechToTextFinish(Resource):
class ServeImage(Resource):
@api.doc(description="Serve an image from storage")
def get(self, image_path):
if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
try:
from application.api.user.base import storage
@@ -629,6 +633,10 @@ class ServeImage(Resource):
return make_response(
jsonify({"success": False, "message": "Image not found"}), 404
)
except ValueError:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
except Exception as e:
current_app.logger.error(f"Error serving image: {e}")
return make_response(

View File

@@ -57,7 +57,7 @@ class ShareConversation(Resource):
try:
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id)}
{"_id": ObjectId(conversation_id), "user": user}
)
if conversation is None:
return make_response(

View File

@@ -463,6 +463,16 @@ class ManageSourceFiles(Resource):
removed_files = []
map_updated = False
for file_path in file_paths:
if ".." in str(file_path) or str(file_path).startswith("/"):
return make_response(
jsonify(
{
"success": False,
"message": "Invalid file path",
}
),
400,
)
full_path = f"{source_file_path}/{file_path}"
# Remove from storage

View File

@@ -14,6 +14,7 @@ from application.api.user.tools.routes import transform_actions
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields
@@ -63,6 +64,21 @@ def _extract_auth_credentials(config):
return auth_credentials
def _validate_mcp_server_url(config: dict) -> None:
"""Validate the server_url in an MCP config to prevent SSRF.
Raises:
ValueError: If the URL is missing or points to a blocked address.
"""
server_url = (config.get("server_url") or "").strip()
if not server_url:
raise ValueError("server_url is required")
try:
validate_url(server_url)
except SSRFError as exc:
raise ValueError(f"Invalid server URL: {exc}") from exc
@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
@api.expect(
@@ -97,6 +113,8 @@ class TestMCPServerConfig(Resource):
400,
)
_validate_mcp_server_url(config)
auth_credentials = _extract_auth_credentials(config)
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
@@ -105,15 +123,41 @@ class TestMCPServerConfig(Resource):
result = mcp_tool.test_connection()
if result.get("requires_oauth"):
return make_response(jsonify(result), 200)
safe_result = {
k: v
for k, v in result.items()
if k in ("success", "requires_oauth", "auth_url")
}
return make_response(jsonify(safe_result), 200)
if not result.get("success") and "message" in result:
if not result.get("success"):
current_app.logger.error(
f"MCP connection test failed: {result.get('message')}"
)
result["message"] = "Connection test failed"
return make_response(
jsonify(
{
"success": False,
"message": "Connection test failed",
"tools_count": 0,
}
),
200,
)
return make_response(jsonify(result), 200)
safe_result = {
"success": True,
"message": result.get("message", "Connection successful"),
"tools_count": result.get("tools_count", 0),
"tools": result.get("tools", []),
}
return make_response(jsonify(safe_result), 200)
except ValueError as e:
current_app.logger.warning(f"Invalid MCP server test request: {e}")
return make_response(
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
400,
)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
@@ -165,6 +209,8 @@ class MCPServerSave(Resource):
400,
)
_validate_mcp_server_url(config)
auth_credentials = _extract_auth_credentials(config)
auth_type = config.get("auth_type", "none")
mcp_config = config.copy()
@@ -279,6 +325,12 @@ class MCPServerSave(Resource):
"tools_count": len(transformed_actions),
}
return make_response(jsonify(response_data), 200)
except ValueError as e:
current_app.logger.warning(f"Invalid MCP server save request: {e}")
return make_response(
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
400,
)
except Exception as e:
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(

View File

@@ -8,6 +8,7 @@ from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields, validate_function_name
@@ -130,6 +131,8 @@ tools_ns = Namespace("tools", description="Tool management operations", path="/a
class AvailableTools(Resource):
@api.doc(description="Get available tools for a user")
def get(self):
if not request.decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
tools_metadata = []
for tool_name, tool_instance in tool_manager.tools.items():
@@ -236,6 +239,16 @@ class CreateTool(Resource):
if missing_fields:
return missing_fields
try:
if data["name"] == "mcp_tool":
server_url = (data.get("config", {}).get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(data["name"])
if not tool_instance:
return make_response(
@@ -421,6 +434,16 @@ class UpdateToolConfig(Resource):
return make_response(jsonify({"success": False}), 404)
tool_name = tool_doc.get("name")
if tool_name == "mcp_tool":
server_url = (data["config"].get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}

View File

@@ -138,10 +138,18 @@ def chat_completions():
if usage_error:
return usage_error
should_save_conversation = bool(internal_data.get("save_conversation", False))
if is_stream:
return Response(
_stream_response(
helper, question, agent, processor, model_name, continuation
helper,
question,
agent,
processor,
model_name,
continuation,
should_save_conversation,
),
mimetype="text/event-stream",
headers={
@@ -151,7 +159,13 @@ def chat_completions():
)
else:
return _non_stream_response(
helper, question, agent, processor, model_name, continuation
helper,
question,
agent,
processor,
model_name,
continuation,
should_save_conversation,
)
except ValueError as e:
@@ -181,6 +195,7 @@ def _stream_response(
processor: StreamProcessor,
model_name: str,
continuation: Optional[Dict],
should_save_conversation: bool,
) -> Generator[str, None, None]:
"""Generate translated SSE chunks for streaming response."""
completion_id = f"chatcmpl-{int(time.time())}"
@@ -193,6 +208,7 @@ def _stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
@@ -225,6 +241,7 @@ def _non_stream_response(
processor: StreamProcessor,
model_name: str,
continuation: Optional[Dict],
should_save_conversation: bool,
) -> Response:
"""Collect full response and return as single JSON."""
stream = helper.complete_stream(
@@ -235,6 +252,7 @@ def _non_stream_response(
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
should_save_conversation=should_save_conversation,
_continuation=continuation,
)
@@ -293,8 +311,9 @@ def list_models():
for ag in user_agents:
created = ag.get("createdAt")
created_ts = int(created.timestamp()) if created else int(time.time())
model_id = str(ag.get("_id") or ag.get("id") or "")
models.append({
"id": str(ag.get("key", "")),
"id": model_id,
"object": "model",
"created": created_ts,
"owned_by": "docsgpt",

View File

@@ -80,6 +80,17 @@ def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
return None
def extract_system_prompt(messages: List[Dict]) -> Optional[str]:
    """Return the content of the first system message in *messages*.

    A system message missing a ``content`` key yields an empty string.
    Returns None when no system message is present at all.
    """
    system_msgs = (m for m in messages if m.get("role") == "system")
    first = next(system_msgs, None)
    if first is None:
        return None
    return first.get("content", "")
def convert_history(messages: List[Dict]) -> List[Dict]:
"""Convert chat completions messages array to DocsGPT history format.
@@ -148,20 +159,27 @@ def translate_request(
break
history = convert_history(messages)
system_prompt_override = extract_system_prompt(messages)
docsgpt = data.get("docsgpt", {})
result = {
"question": question,
"api_key": api_key,
"history": json.dumps(history),
"save_conversation": True,
# Conversations are NOT persisted by default on the v1 endpoint.
# Callers opt in via ``docsgpt.save_conversation: true``.
"save_conversation": bool(docsgpt.get("save_conversation", False)),
}
if system_prompt_override is not None:
result["system_prompt_override"] = system_prompt_override
# Client tools
if data.get("tools"):
result["client_tools"] = data["tools"]
# DocsGPT extensions
docsgpt = data.get("docsgpt", {})
if docsgpt.get("attachments"):
result["attachments"] = docsgpt["attachments"]

View File

@@ -19,25 +19,10 @@ class EpubParser(BaseParser):
def parse_file(self, file: Path, errors: str = "ignore") -> str:
"""Parse file."""
try:
import ebooklib
from ebooklib import epub
from fast_ebook import epub
except ImportError:
raise ValueError("`EbookLib` is required to read Epub files.")
try:
import html2text
except ImportError:
raise ValueError("`html2text` is required to parse Epub files.")
raise ValueError("`fast-ebook` is required to read Epub files.")
text_list = []
book = epub.read_epub(file, options={"ignore_ncx": True})
# Iterate through all chapters.
for item in book.get_items():
# Chapters are typically located in epub documents items.
if item.get_type() == ebooklib.ITEM_DOCUMENT:
text_list.append(
html2text.html2text(item.get_content().decode("utf-8"))
)
text = "\n".join(text_list)
book = epub.read_epub(file)
text = book.to_markdown()
return text

View File

@@ -1,5 +1,5 @@
anthropic==0.86.0
boto3==1.42.24
anthropic==0.88.0
boto3==1.42.83
beautifulsoup4==4.14.3
cel-python==0.5.0
celery==5.6.3
@@ -11,11 +11,11 @@ rapidocr>=1.4.0
onnxruntime>=1.19.0
docx2txt==0.9
ddgs>=8.0.0
ebooklib==0.20
elevenlabs==2.40.0
fast-ebook
elevenlabs==2.41.0
Flask==3.1.3
faiss-cpu==1.13.2
fastmcp==2.14.6
fastmcp==3.2.0
flask-restx==1.3.2
google-genai==1.69.0
google-api-python-client==2.193.0
@@ -23,10 +23,9 @@ google-auth-httplib2==0.3.1
google-auth-oauthlib==1.3.1
gTTS==2.5.4
gunicorn==25.3.0
html2text==2025.4.15
jinja2==3.1.6
jiter==0.13.0
jmespath==1.0.1
jmespath==1.1.0
joblib==1.5.3
jsonpatch==1.33
jsonpointer==3.0.0
@@ -34,7 +33,7 @@ kombu==5.6.2
langchain==1.2.3
langchain-community==0.4.1
langchain-core==1.2.23
langchain-openai==1.1.7
langchain-openai==1.1.12
langchain-text-splitters==1.1.1
langsmith==0.7.23
lazy-object-proxy==1.12.0
@@ -53,10 +52,10 @@ orjson==3.11.7
packaging==26.0
pandas==3.0.2
openpyxl==3.1.5
pathable==0.4.4
pathable==0.5.0
pdf2image>=1.17.0
pillow
portalocker>=2.7.0,<3.0.0
portalocker>=2.7.0,<4.0.0
prompt-toolkit==3.0.52
protobuf==7.34.1
psycopg2-binary==2.9.11
@@ -65,14 +64,14 @@ pydantic
pydantic-core
pydantic-settings
pymongo==4.16.0
pypdf==6.6.0
pypdf==6.9.2
python-dateutil==2.9.0.post0
python-dotenv
python-jose==3.5.0
python-pptx==1.0.2
redis==7.4.0
referencing>=0.28.0,<0.38.0
regex==2026.3.32
regex==2026.4.4
requests==2.33.1
retry==0.9.2
sentence-transformers==5.3.0

View File

@@ -21,10 +21,19 @@ class LocalStorage(BaseStorage):
)
def _get_full_path(self, path: str) -> str:
"""Get absolute path by combining base_dir and path."""
"""Get absolute path by combining base_dir and path.
Raises:
ValueError: If the resolved path escapes base_dir (path traversal).
"""
if os.path.isabs(path):
return path
return os.path.join(self.base_dir, path)
resolved = os.path.realpath(path)
else:
resolved = os.path.realpath(os.path.join(self.base_dir, path))
base = os.path.realpath(self.base_dir)
if not resolved.startswith(base + os.sep) and resolved != base:
raise ValueError(f"Path traversal detected: {path}")
return resolved
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
"""Save a file to local storage."""

View File

@@ -2,6 +2,7 @@
import io
import os
import posixpath
from typing import BinaryIO, Callable, List
import boto3
@@ -14,6 +15,20 @@ from botocore.exceptions import ClientError
class S3Storage(BaseStorage):
"""AWS S3 storage implementation."""
@staticmethod
def _validate_path(path: str) -> str:
"""Validate and normalize an S3 key to prevent path traversal.
Raises:
ValueError: If the path contains traversal sequences or is absolute.
"""
if "\x00" in path:
raise ValueError(f"Null byte in path: {path}")
normalized = posixpath.normpath(path)
if normalized.startswith("/") or normalized.startswith(".."):
raise ValueError(f"Path traversal detected: {path}")
return normalized
def __init__(self, bucket_name=None):
"""
Initialize S3 storage.
@@ -46,6 +61,7 @@ class S3Storage(BaseStorage):
**kwargs,
) -> dict:
"""Save a file to S3 storage."""
path = self._validate_path(path)
self.s3.upload_fileobj(
file_data, self.bucket_name, path, ExtraArgs={"StorageClass": storage_class}
)
@@ -61,6 +77,7 @@ class S3Storage(BaseStorage):
def get_file(self, path: str) -> BinaryIO:
"""Get a file from S3 storage."""
path = self._validate_path(path)
if not self.file_exists(path):
raise FileNotFoundError(f"File not found: {path}")
file_obj = io.BytesIO()
@@ -70,6 +87,7 @@ class S3Storage(BaseStorage):
def delete_file(self, path: str) -> bool:
"""Delete a file from S3 storage."""
path = self._validate_path(path)
try:
self.s3.delete_object(Bucket=self.bucket_name, Key=path)
return True
@@ -78,6 +96,7 @@ class S3Storage(BaseStorage):
def file_exists(self, path: str) -> bool:
"""Check if a file exists in S3 storage."""
path = self._validate_path(path)
try:
self.s3.head_object(Bucket=self.bucket_name, Key=path)
return True
@@ -115,6 +134,7 @@ class S3Storage(BaseStorage):
import logging
import tempfile
path = self._validate_path(path)
if not self.file_exists(path):
raise FileNotFoundError(f"File not found in S3: {path}")
with tempfile.NamedTemporaryFile(

View File

@@ -11,11 +11,33 @@ from application.storage.storage_creator import StorageCreator
def get_vectorstore(path: str) -> str:
    """Build a safe local path for a FAISS index.

    Args:
        path: Source identifier provided by the caller.

    Returns:
        The validated vectorstore path rooted under ``indexes``.

    Raises:
        ValueError: If ``path`` escapes the ``indexes`` directory.
    """
    base_dir = "indexes"
    if not path:
        return base_dir
    normalized = str(path).strip()
    # Backslashes are rejected outright: they act as path separators on
    # Windows and have no legitimate use in a source identifier.
    if "\\" in normalized:
        raise ValueError("Invalid source_id path")
    candidate = os.path.normpath(os.path.join(base_dir, normalized))
    # Compare absolute forms so "indexes/../x" style inputs are caught.
    base_abs = os.path.abspath(base_dir)
    candidate_abs = os.path.abspath(candidate)
    if not candidate_abs.startswith(base_abs + os.sep) and candidate_abs != base_abs:
        raise ValueError("Invalid source_id path")
    return candidate
class FaissStore(BaseVectorStore):

View File

@@ -7,6 +7,10 @@ export default {
"title": "🔌 Agent API",
"href": "/Agents/api"
},
"openai-compatible": {
"title": "🔄 OpenAI-Compatible API",
"href": "/Agents/openai-compatible"
},
"webhooks": {
"title": "🪝 Agent Webhooks",
"href": "/Agents/webhooks"

View File

@@ -15,6 +15,10 @@ DocsGPT Agents can be accessed programmatically through API endpoints. This page
When you use an agent `api_key`, DocsGPT loads that agent's configuration automatically (prompt, tools, sources, default model). You usually only need to send `question` and `api_key`.
<Callout type="info">
Looking to connect an existing OpenAI-compatible client (opencode, aider, the OpenAI SDKs, etc.) to a DocsGPT Agent? Use the [OpenAI-Compatible Chat Completions API](/Agents/openai-compatible) — it speaks the standard chat completions protocol so no adapter code is required.
</Callout>
## Base URL
<Callout type="info">

View File

@@ -111,6 +111,7 @@ Once an agent is created, you can:
* Modify any of its configuration settings (name, description, source, prompt, tools, type).
* **Generate a Public Link:** From the edit screen, you can create a shareable public link that allows others to import and use your agent.
* **Get a Webhook URL:** You can also obtain a Webhook URL for the agent. This allows external applications or services to trigger the agent and receive responses programmatically, enabling powerful integrations and automations.
* **Use it via API:** Every agent exposes an API key that can be used with the native [Agent API](/Agents/api) or the [OpenAI-Compatible API](/Agents/openai-compatible) so you can drop DocsGPT Agents into any tool that already speaks the chat completions protocol.
## Seeding Premade Agents from YAML

View File

@@ -0,0 +1,93 @@
---
title: OpenAI-Compatible API
description: Connect any OpenAI-compatible client to DocsGPT Agents via /v1/chat/completions.
---
import { Callout, Tabs } from 'nextra/components';
# OpenAI-Compatible API
DocsGPT exposes `/v1/chat/completions` following the standard chat completions protocol. Point any compatible client — **opencode**, **Aider**, **LibreChat** or the OpenAI SDKs — at your DocsGPT Agent by changing only the base URL and API key.
## Quick Start
<Tabs items={['Python', 'cURL']}>
<Tabs.Tab>
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:7091/v1", # or https://gptcloud.arc53.com/v1
api_key="your_agent_api_key",
)
response = client.chat.completions.create(
model="docsgpt-agent",
messages=[{"role": "user", "content": "Summarize our refund policy"}],
)
print(response.choices[0].message.content)
```
</Tabs.Tab>
<Tabs.Tab>
```bash
curl -X POST http://localhost:7091/v1/chat/completions \
-H "Authorization: Bearer your_agent_api_key" \
-H "Content-Type: application/json" \
-d '{"model":"docsgpt-agent","messages":[{"role":"user","content":"Summarize our refund policy"}]}'
```
</Tabs.Tab>
</Tabs>
The `model` field is accepted but ignored — the agent bound to your API key determines the model. The agent's prompt, sources, tools, and default model are loaded automatically.
## Base URL & Auth
| Environment | Base URL |
| --- | --- |
| Local | `http://localhost:7091/v1` |
| Cloud | `https://gptcloud.arc53.com/v1` |
Authenticate with `Authorization: Bearer <agent_api_key>`.
## Endpoints
| Method | Path | Description |
| --- | --- | --- |
| `POST` | `/v1/chat/completions` | Chat request (streaming or non-streaming) |
| `GET` | `/v1/models` | List agents available to your key |
## Streaming
Set `"stream": true`. You'll receive SSE chunks with `choices[0].delta.content`. DocsGPT-specific events (sources, tool calls) arrive as extra frames with a `docsgpt` key — standard clients ignore them.
```python
stream = client.chat.completions.create(
model="docsgpt-agent",
stream=True,
messages=[{"role": "user", "content": "Explain vector search"}],
)
for chunk in stream:
print(chunk.choices[0].delta.content or "", end="", flush=True)
```
## System Prompt Override
System messages are **dropped by default** — the agent's configured prompt is used. To allow callers to override it, enable **Allow prompt override** in the agent's Advanced settings.
<Callout type="warning">
When an override is active, the agent's prompt template is replaced wholesale — template variables like `{summaries}` are not substituted.
</Callout>
## Conversation Persistence
Conversations are **not persisted by default** (stateless, like most OpenAI clients expect). Opt in per request:
```json
{ "docsgpt": { "save_conversation": true } }
```
The response will include `docsgpt.conversation_id`.
## When to Use Native Endpoints Instead
Use [`/api/answer` or `/stream`](/Agents/api) if you need server-side attachments, `passthrough` template variables, explicit `conversation_id` reuse, or persistence by default.

View File

@@ -1,3 +1,5 @@
import hashlib
import hmac
import os
import pprint
@@ -10,6 +12,7 @@ docsgpt_url = os.getenv("docsgpt_url")
chatwoot_url = os.getenv("chatwoot_url")
docsgpt_key = os.getenv("docsgpt_key")
chatwoot_token = os.getenv("chatwoot_token")
chatwoot_webhook_secret = os.getenv("chatwoot_webhook_secret", "")
# account_id = os.getenv("account_id")
# assignee_id = os.getenv("assignee_id")
label_stop = "human-requested"
@@ -45,12 +48,35 @@ def send_to_chatwoot(account, conversation, message):
return r.json()
def is_valid_chatwoot_signature(raw_body: bytes, signature_header: str | None) -> bool:
    """Validate a Chatwoot webhook signature using the shared secret.

    Returns False when no secret is configured or no signature header was
    sent; otherwise compares an HMAC-SHA256 hex digest in constant time.
    """
    if not chatwoot_webhook_secret or not signature_header:
        return False
    digest = hmac.new(
        chatwoot_webhook_secret.encode("utf-8"), raw_body, hashlib.sha256
    ).hexdigest()
    candidate = signature_header.strip()
    prefix = "sha256="
    if candidate.startswith(prefix):
        candidate = candidate[len(prefix):]
    # compare_digest avoids timing side channels on the comparison.
    return hmac.compare_digest(candidate, digest)
app = Flask(__name__)
@app.route('/docsgpt', methods=['POST'])
def docsgpt():
data = request.get_json()
raw_body = request.get_data()
signature = request.headers.get("X-Chatwoot-Signature")
if not is_valid_chatwoot_signature(raw_body, signature):
return "Unauthorized", 401
data = request.get_json(silent=True)
if not isinstance(data, dict):
return "Invalid payload", 400
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(data)
try:

View File

@@ -73,6 +73,7 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
token_limit: undefined,
limited_request_mode: false,
request_limit: undefined,
allow_system_prompt_override: false,
models: [],
default_model_id: '',
});
@@ -241,6 +242,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
formData.append('request_limit', '0');
}
formData.append(
'allow_system_prompt_override',
agent.allow_system_prompt_override ? 'True' : 'False',
);
if (imageFile) formData.append('image', imageFile);
if (agent.tools && agent.tools.length > 0)
@@ -361,6 +367,11 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
formData.append('request_limit', '0');
}
formData.append(
'allow_system_prompt_override',
agent.allow_system_prompt_override ? 'True' : 'False',
);
if (agent.models && agent.models.length > 0) {
formData.append('models', JSON.stringify(agent.models));
}
@@ -1266,6 +1277,43 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
}`}
/>
</div>
<div className="mt-6">
<div className="flex items-center justify-between gap-4">
<div className="min-w-0 flex-1">
<h2 className="text-sm font-medium">
{t('agents.form.advanced.systemPromptOverride')}
</h2>
<p className="mt-1 text-xs text-gray-600 dark:text-gray-400">
{t(
'agents.form.advanced.systemPromptOverrideDescription',
)}
</p>
</div>
<button
onClick={() =>
setAgent({
...agent,
allow_system_prompt_override:
!agent.allow_system_prompt_override,
})
}
className={`relative h-6 w-11 shrink-0 rounded-full transition-colors ${
agent.allow_system_prompt_override
? 'bg-primary'
: 'bg-gray-300 dark:bg-gray-600'
}`}
>
<span
className={`absolute top-0.5 h-5 w-5 transform rounded-full bg-white transition-transform ${
agent.allow_system_prompt_override
? ''
: '-translate-x-5'
}`}
/>
</button>
</div>
</div>
</div>
)}
</div>

View File

@@ -36,6 +36,7 @@ export type Agent = {
default_model_id?: string;
folder_id?: string;
workflow?: string;
allow_system_prompt_override?: boolean;
};
export type AgentFolder = {

View File

@@ -18,6 +18,7 @@ import {
X,
} from 'lucide-react';
import { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import { useNavigate, useParams, useSearchParams } from 'react-router-dom';
import ReactFlow, {
@@ -301,6 +302,7 @@ function createWorkflowPayload(
}
function WorkflowBuilderInner() {
const { t } = useTranslation();
const navigate = useNavigate();
const token = useSelector(selectToken);
const sourceDocs = useSelector(selectSourceDocs);
@@ -1142,6 +1144,10 @@ function WorkflowBuilderInner() {
workflowDescription || `Workflow agent: ${workflowName}`,
);
agentFormData.append('status', 'published');
agentFormData.append(
'allow_system_prompt_override',
currentAgent.allow_system_prompt_override ? 'True' : 'False',
);
if (imageFile) {
agentFormData.append('image', imageFile);
}
@@ -1203,6 +1209,10 @@ function WorkflowBuilderInner() {
agentFormData.append('agent_type', 'workflow');
agentFormData.append('status', 'published');
agentFormData.append('workflow', savedWorkflowId || '');
agentFormData.append(
'allow_system_prompt_override',
currentAgent.allow_system_prompt_override ? 'True' : 'False',
);
if (imageFile) {
agentFormData.append('image', imageFile);
}
@@ -1454,6 +1464,40 @@ function WorkflowBuilderInner() {
Image updates are included the next time you save.
</p>
</div>
<div className="mb-3">
<div className="flex items-center justify-between">
<div>
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300">
{t('agents.form.advanced.systemPromptOverride')}
</label>
<p className="mt-0.5 text-[11px] text-gray-500 dark:text-gray-400">
{t('agents.form.advanced.systemPromptOverrideDescription')}
</p>
</div>
<button
onClick={() =>
setCurrentAgent((prev) => ({
...prev,
allow_system_prompt_override:
!prev.allow_system_prompt_override,
}))
}
className={`relative h-6 w-11 shrink-0 rounded-full transition-colors ${
currentAgent.allow_system_prompt_override
? 'bg-primary'
: 'bg-gray-300 dark:bg-gray-600'
}`}
>
<span
className={`absolute top-0.5 h-5 w-5 transform rounded-full bg-white transition-transform ${
currentAgent.allow_system_prompt_override
? ''
: '-translate-x-5'
}`}
/>
</button>
</div>
</div>
<button
onClick={handleWorkflowSettingsDone}
disabled={isPublishing}

View File

@@ -78,6 +78,7 @@ function Dropdown<T extends DropdownOption>({
const searchRef = useRef<HTMLInputElement>(null);
const [open, setOpen] = useState(false);
const [query, setQuery] = useState('');
const [dropUp, setDropUp] = useState(false);
const radius = rounded === '3xl' ? 'rounded-3xl' : 'rounded-xl';
const radiusTop = rounded === '3xl' ? 'rounded-t-3xl' : 'rounded-t-xl';
@@ -90,14 +91,23 @@ function Dropdown<T extends DropdownOption>({
setQuery('');
}
};
document.addEventListener('mousedown', handler);
return () => document.removeEventListener('mousedown', handler);
document.addEventListener('mousedown', handler, true);
return () => document.removeEventListener('mousedown', handler, true);
}, []);
useEffect(() => {
if (open && searchable && searchRef.current) searchRef.current.focus();
}, [open, searchable]);
const handleToggle = () => {
if (!open && ref.current) {
const rect = ref.current.getBoundingClientRect();
const spaceBelow = window.innerHeight - rect.bottom;
setDropUp(spaceBelow < 220);
}
setOpen((v) => !v);
};
const filtered = useMemo(() => {
if (!searchable || !query.trim()) return options;
const q = query.toLowerCase();
@@ -110,8 +120,8 @@ function Dropdown<T extends DropdownOption>({
<div className={`relative ${size}`} ref={ref}>
<button
type="button"
onClick={() => setOpen((v) => !v)}
className={`border-border bg-card text-foreground flex w-full cursor-pointer items-center justify-between border px-5 py-3 ${open ? radiusTop : radius}`}
onClick={handleToggle}
className={`border-border bg-card text-foreground flex w-full cursor-pointer items-center justify-between border px-5 py-3 ${open ? (dropUp ? radiusBottom : radiusTop) : radius}`}
>
<span
className={`truncate ${contentSize} ${displayValue ? '' : 'text-muted-foreground'}`}
@@ -125,7 +135,11 @@ function Dropdown<T extends DropdownOption>({
{open && (
<div
className={`border-border bg-card absolute inset-x-0 z-20 -mt-px overflow-hidden border border-t-0 shadow-lg ${radiusBottom}`}
className={`border-border bg-card absolute inset-x-0 z-20 overflow-hidden border shadow-lg ${
dropUp
? `bottom-full -mt-px border-b-0 ${radiusTop}`
: `-mt-px border-t-0 ${radiusBottom}`
}`}
>
{searchable && (
<div className="flex items-center px-3 py-2">

View File

@@ -10,7 +10,9 @@ interface SkeletonLoaderProps {
| 'chatbot'
| 'dropdown'
| 'chunkCards'
| 'sourceCards';
| 'sourceCards'
| 'toolCards'
| 'addToolCards';
}
const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
@@ -237,6 +239,55 @@ const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
</>
);
// Pulse-animated placeholder cards matching the add-tool modal's card layout.
const renderAddToolCards = () => (
  <>
    {Array.from({ length: count }).map((_, idx) => (
      <div
        key={`add-tool-skel-${idx}`}
        className="border-light-gainsboro dark:border-arsenic flex h-52 w-full animate-pulse flex-col justify-between rounded-2xl border p-6"
      >
        <div className="w-full">
          <div className="flex w-full items-center justify-between px-1">
            {/* Tool icon placeholder */}
            <div className="h-6 w-6 rounded bg-gray-300 dark:bg-gray-600"></div>
          </div>
          <div className="mt-[9px] space-y-2 px-1">
            {/* Title line followed by three shorter description lines */}
            <div className="h-4 w-2/3 rounded bg-gray-300 dark:bg-gray-600"></div>
            <div className="h-3 w-full rounded bg-gray-200 dark:bg-gray-700"></div>
            <div className="h-3 w-5/6 rounded bg-gray-200 dark:bg-gray-700"></div>
            <div className="h-3 w-3/4 rounded bg-gray-200 dark:bg-gray-700"></div>
          </div>
        </div>
      </div>
    ))}
  </>
);
// Pulse-animated placeholder cards matching the Tools page card layout,
// including the toggle-switch pill in the bottom-right corner.
const renderToolCards = () => (
  <>
    {Array.from({ length: count }).map((_, idx) => (
      <div
        key={`tool-skel-${idx}`}
        className="bg-muted flex h-52 w-[300px] animate-pulse flex-col justify-between rounded-2xl p-6"
      >
        <div className="w-full">
          <div className="flex items-center gap-2 px-1">
            {/* Tool icon placeholder */}
            <div className="h-6 w-6 rounded bg-gray-300 dark:bg-gray-600"></div>
          </div>
          <div className="mt-[9px] space-y-2 px-1">
            {/* Title line followed by three shorter description lines */}
            <div className="h-4 w-2/3 rounded bg-gray-300 dark:bg-gray-600"></div>
            <div className="h-3 w-full rounded bg-gray-200 dark:bg-gray-700"></div>
            <div className="h-3 w-5/6 rounded bg-gray-200 dark:bg-gray-700"></div>
            <div className="h-3 w-3/4 rounded bg-gray-200 dark:bg-gray-700"></div>
          </div>
        </div>
        <div className="flex justify-end">
          {/* Enable/disable toggle placeholder */}
          <div className="h-5 w-9 rounded-full bg-gray-300 dark:bg-gray-600"></div>
        </div>
      </div>
    ))}
  </>
);
const componentMap = {
fileTable: renderTable,
chatbot: renderChatbot,
@@ -246,6 +297,8 @@ const SkeletonLoader: React.FC<SkeletonLoaderProps> = ({
analysis: renderAnalysis,
chunkCards: renderChunkCards,
sourceCards: renderSourceCards,
toolCards: renderToolCards,
addToolCards: renderAddToolCards,
};
const render = componentMap[component] || componentMap.default;

View File

@@ -619,7 +619,9 @@
"tokenLimiting": "Token-Limitierung",
"tokenLimitingDescription": "Begrenze die täglich von diesem Agenten verwendbaren Tokens",
"requestLimiting": "Anfrage-Limitierung",
"requestLimitingDescription": "Begrenze die täglich an diesen Agenten gestellten Anfragen"
"requestLimitingDescription": "Begrenze die täglich an diesen Agenten gestellten Anfragen",
"systemPromptOverride": "Prompt-Überschreibung erlauben",
"systemPromptOverrideDescription": "Erlaubt v1-API-Aufrufern, den System-Prompt dieses Agenten zu ersetzen"
},
"preview": {
"publishedPreview": "Veröffentlichte Agenten können hier in der Vorschau angezeigt werden"

View File

@@ -653,7 +653,9 @@
"tokenLimiting": "Token limiting",
"tokenLimitingDescription": "Limit daily total tokens that can be used by this agent",
"requestLimiting": "Request limiting",
"requestLimitingDescription": "Limit daily total requests that can be made to this agent"
"requestLimitingDescription": "Limit daily total requests that can be made to this agent",
"systemPromptOverride": "Allow prompt override",
"systemPromptOverrideDescription": "Let v1 API callers replace this agent's system prompt"
},
"preview": {
"publishedPreview": "Published agents can be previewed here"

View File

@@ -641,7 +641,9 @@
"tokenLimiting": "Límite de tokens",
"tokenLimitingDescription": "Limita el total diario de tokens que puede usar este agente",
"requestLimiting": "Límite de solicitudes",
"requestLimitingDescription": "Limita el total diario de solicitudes que se pueden hacer a este agente"
"requestLimitingDescription": "Limita el total diario de solicitudes que se pueden hacer a este agente",
"systemPromptOverride": "Permitir sobrescribir el prompt",
"systemPromptOverrideDescription": "Permitir que los llamadores de la API v1 reemplacen el prompt del sistema de este agente"
},
"preview": {
"publishedPreview": "Los agentes publicados se pueden previsualizar aquí"

View File

@@ -641,7 +641,9 @@
"tokenLimiting": "トークン制限",
"tokenLimitingDescription": "このエージェントが使用できる1日の合計トークン数を制限します",
"requestLimiting": "リクエスト制限",
"requestLimitingDescription": "このエージェントに対して行える1日の合計リクエスト数を制限します"
"requestLimitingDescription": "このエージェントに対して行える1日の合計リクエスト数を制限します",
"systemPromptOverride": "プロンプトの上書きを許可",
"systemPromptOverrideDescription": "v1 API呼び出し元がこのエージェントのシステムプロンプトを置き換えることを許可します"
},
"preview": {
"publishedPreview": "公開されたエージェントはここでプレビューできます"

View File

@@ -641,7 +641,9 @@
"tokenLimiting": "Лимит токенов",
"tokenLimitingDescription": "Ограничить ежедневное общее количество токенов, которые может использовать этот агент",
"requestLimiting": "Лимит запросов",
"requestLimitingDescription": "Ограничить ежедневное общее количество запросов, которые можно сделать к этому агенту"
"requestLimitingDescription": "Ограничить ежедневное общее количество запросов, которые можно сделать к этому агенту",
"systemPromptOverride": "Разрешить замену промпта",
"systemPromptOverrideDescription": "Разрешить вызовам API v1 заменять системный промпт этого агента"
},
"preview": {
"publishedPreview": "Опубликованные агенты можно просмотреть здесь"

View File

@@ -641,7 +641,9 @@
"tokenLimiting": "權杖限制",
"tokenLimitingDescription": "限制此代理每天可使用的總權杖數",
"requestLimiting": "請求限制",
"requestLimitingDescription": "限制每天可向此代理發出的總請求數"
"requestLimitingDescription": "限制每天可向此代理發出的總請求數",
"systemPromptOverride": "允許覆蓋提示詞",
"systemPromptOverrideDescription": "允許 v1 API 呼叫者替換此代理的系統提示詞"
},
"preview": {
"publishedPreview": "已發佈的代理可以在此處預覽"

View File

@@ -641,7 +641,9 @@
"tokenLimiting": "令牌限制",
"tokenLimitingDescription": "限制此代理每天可使用的总令牌数",
"requestLimiting": "请求限制",
"requestLimitingDescription": "限制每天可向此代理发出的总请求数"
"requestLimitingDescription": "限制每天可向此代理发出的总请求数",
"systemPromptOverride": "允许覆盖提示词",
"systemPromptOverrideDescription": "允许 v1 API 调用者替换此代理的系统提示词"
},
"preview": {
"publishedPreview": "已发布的代理可以在此处预览"

View File

@@ -3,8 +3,8 @@ import { useTranslation } from 'react-i18next';
import { useSelector } from 'react-redux';
import userService from '../api/services/userService';
import Spinner from '../components/Spinner';
import { useOutsideAlerter } from '../hooks';
import SkeletonLoader from '../components/SkeletonLoader';
import { useLoaderState, useOutsideAlerter } from '../hooks';
import { ActiveState } from '../models/misc';
import { selectToken } from '../preferences/preferenceSlice';
import ConfigToolModal from './ConfigToolModal';
@@ -37,7 +37,7 @@ export default function AddToolModal({
React.useState<ActiveState>('INACTIVE');
const [mcpModalState, setMcpModalState] =
React.useState<ActiveState>('INACTIVE');
const [loading, setLoading] = React.useState(false);
const [loading, setLoading] = useLoaderState(false);
useOutsideAlerter(modalRef, () => {
if (modalState === 'ACTIVE') {
@@ -121,8 +121,8 @@ export default function AddToolModal({
</h2>
<div className="mt-5 h-[73vh] overflow-auto px-3 py-px">
{loading ? (
<div className="flex h-full items-center justify-center">
<Spinner />
<div className="grid auto-rows-fr grid-cols-1 gap-4 pb-2 sm:grid-cols-2 lg:grid-cols-3">
<SkeletonLoader component="addToolCards" count={6} />
</div>
) : (
<div className="grid auto-rows-fr grid-cols-1 gap-4 pb-2 sm:grid-cols-2 lg:grid-cols-3">

View File

@@ -10,9 +10,9 @@ import NoFilesIcon from '../assets/no-files.svg';
import SearchIcon from '../assets/search.svg';
import ThreeDotsIcon from '../assets/three-dots.svg';
import ContextMenu, { MenuOption } from '../components/ContextMenu';
import Spinner from '../components/Spinner';
import SkeletonLoader from '../components/SkeletonLoader';
import ToggleSwitch from '../components/ToggleSwitch';
import { useDarkTheme } from '../hooks';
import { useDarkTheme, useLoaderState } from '../hooks';
import AddToolModal from '../modals/AddToolModal';
import ConfirmationModal from '../modals/ConfirmationModal';
import MCPServerModal from '../modals/MCPServerModal';
@@ -33,7 +33,7 @@ export default function Tools() {
const [selectedTool, setSelectedTool] = React.useState<
UserToolType | APIToolType | null
>(null);
const [loading, setLoading] = React.useState(false);
const [loading, setLoading] = useLoaderState(false);
const [activeMenuId, setActiveMenuId] = React.useState<string | null>(null);
const menuRefs = React.useRef<{
[key: string]: React.RefObject<HTMLDivElement | null>;
@@ -242,10 +242,8 @@ export default function Tools() {
</div>
<div className="border-border dark:border-border mt-5 mb-8 border-b" />
{loading ? (
<div className="grid grid-cols-2 gap-6 lg:grid-cols-3">
<div className="col-span-2 mt-24 flex h-32 items-center justify-center lg:col-span-3">
<Spinner />
</div>
<div className="flex flex-wrap justify-center gap-4 sm:justify-start">
<SkeletonLoader component="toolCards" count={6} />
</div>
) : (
<div className="flex flex-wrap justify-center gap-4 sm:justify-start">

View File

@@ -543,8 +543,20 @@ function Configure-TTS {
}
}
# Generate INTERNAL_KEY for worker-to-backend auth if not already present
function Ensure-InternalKey {
    # Read the current env file contents (empty string when it does not exist yet).
    $content = if (Test-Path $ENV_FILE) { Get-Content $ENV_FILE -Raw } else { "" }
    # Only generate a key when no INTERNAL_KEY line exists; (?m) anchors ^ per line.
    if ($content -notmatch "(?m)^INTERNAL_KEY=") {
        # 32 cryptographically secure random bytes -> 64-char lowercase hex string.
        $bytes = New-Object byte[] 32
        [System.Security.Cryptography.RandomNumberGenerator]::Fill($bytes)
        $internal_key = ($bytes | ForEach-Object { $_.ToString("x2") }) -join ""
        "INTERNAL_KEY=$internal_key" | Add-Content -Path $ENV_FILE -Encoding utf8
    }
}
# Main advanced settings menu
function Prompt-AdvancedSettings {
Ensure-InternalKey
Write-Host ""
$configure_advanced = Read-Host "Would you like to configure advanced settings? (y/N)"
if ($configure_advanced -ne "y" -and $configure_advanced -ne "Y") {

View File

@@ -396,8 +396,18 @@ configure_tts() {
esac
}
# Generate INTERNAL_KEY for worker-to-backend auth if not already present
ensure_internal_key() {
    # Idempotent: skip generation when an INTERNAL_KEY line already exists.
    if ! grep -q "^INTERNAL_KEY=" "$ENV_FILE" 2>/dev/null; then
        local internal_key
        # Prefer openssl; fall back to /dev/urandom via od when openssl is missing.
        internal_key=$(openssl rand -hex 32 2>/dev/null || head -c 64 /dev/urandom | od -An -tx1 | tr -d ' \n')
        echo "INTERNAL_KEY=$internal_key" >> "$ENV_FILE"
    fi
}
# Main advanced settings menu
prompt_advanced_settings() {
ensure_internal_key
echo
read -p "$(echo -e "${DEFAULT_FG}Would you like to configure advanced settings? (y/N): ${NC}")" configure_advanced
if [[ ! "$configure_advanced" =~ ^[yY]$ ]]; then

View File

@@ -28,6 +28,7 @@ def _patch_mcp_globals(monkeypatch):
monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
monkeypatch.setattr(mcp_mod, "db", mock_db)
monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
monkeypatch.setattr(mcp_mod, "validate_url", lambda url: url)
@pytest.fixture

View File

@@ -33,6 +33,8 @@ def _patch_mcp_globals(monkeypatch):
monkeypatch.setattr(mcp_mod, "mongo", mock_mongo)
monkeypatch.setattr(mcp_mod, "db", mock_db)
monkeypatch.setattr(mcp_mod, "_mcp_clients_cache", {})
# Bypass DNS-resolving URL validation for tests using fake hostnames.
monkeypatch.setattr(mcp_mod, "validate_url", lambda u, **kw: u)
@pytest.fixture
@@ -136,6 +138,47 @@ class TestMCPToolInit:
})
assert tool.custom_headers == {"X-Custom": "val"}
def test_rejects_metadata_ip(self, monkeypatch):
    """Cloud metadata endpoint (169.254.169.254) must be rejected as an SSRF target."""
    from application.agents.tools.mcp_tool import MCPTool
    from application.core.url_validation import validate_url as real_validate_url
    import application.agents.tools.mcp_tool as mcp_mod

    # Restore the real validator — the shared fixture stubs it out for other tests.
    monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
    with pytest.raises(ValueError, match="Invalid MCP server URL"):
        MCPTool(config={"server_url": "http://169.254.169.254/latest/meta-data", "auth_type": "none"})
def test_rejects_localhost(self, monkeypatch):
    """Localhost URLs must be rejected to prevent SSRF against internal services."""
    from application.agents.tools.mcp_tool import MCPTool
    from application.core.url_validation import validate_url as real_validate_url
    import application.agents.tools.mcp_tool as mcp_mod

    # Restore the real validator — the shared fixture stubs it out for other tests.
    monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
    with pytest.raises(ValueError, match="Invalid MCP server URL"):
        MCPTool(config={"server_url": "http://localhost:8080/mcp", "auth_type": "none"})
def test_rejects_private_ip(self, monkeypatch):
    """RFC 1918 private addresses (e.g. 10.0.0.1) must be rejected as SSRF targets."""
    from application.agents.tools.mcp_tool import MCPTool
    from application.core.url_validation import validate_url as real_validate_url
    import application.agents.tools.mcp_tool as mcp_mod

    # Restore the real validator — the shared fixture stubs it out for other tests.
    monkeypatch.setattr(mcp_mod, "validate_url", real_validate_url)
    with pytest.raises(ValueError, match="Invalid MCP server URL"):
        MCPTool(config={"server_url": "http://10.0.0.1/mcp", "auth_type": "none"})
def test_accepts_public_url(self):
    """A public hostname passes validation and is stored on the tool unchanged."""
    tool = _make_tool({
        "server_url": "https://mcp.example.com/api",
        "auth_type": "none",
    })
    assert tool.server_url == "https://mcp.example.com/api"
def test_empty_server_url_allowed(self):
    """An empty server_url is permitted; client setup is patched out of __init__."""
    from application.agents.tools.mcp_tool import MCPTool

    with patch.object(MCPTool, "_setup_client"):
        tool = MCPTool(config={"server_url": "", "auth_type": "none"})
        assert tool.server_url == ""
# =====================================================================
# Redirect URI Resolution

View File

@@ -330,6 +330,170 @@ class TestStreamProcessorDocPrefetch:
assert docs_together is not None
assert "Agent doc content" in docs_together
def test_configure_source_treats_default_string_as_no_docs(self, mock_mongo_db):
    """An agent whose ``source`` is the literal string "default" yields no docs."""
    from application.api.answer.services.stream_processor import StreamProcessor
    from application.core.settings import settings

    db = mock_mongo_db[settings.MONGO_DB_NAME]
    agents_collection = db["agents"]
    agent_id = ObjectId()
    agents_collection.insert_one(
        {
            "_id": agent_id,
            "key": "agent_default_source_key",
            "user": "user_123",
            "prompt_id": "default",
            "agent_type": "classic",
            "source": "default",
        }
    )
    processor = StreamProcessor(
        {"question": "Hi", "api_key": "agent_default_source_key"},
        None,
    )
    processor._configure_agent()
    processor._configure_source()
    # "default" is a sentinel, not a real source id, so both stay empty.
    assert processor.source == {}
    assert processor.all_sources == []
def test_prefetch_skipped_when_no_active_docs(self, mock_mongo_db):
    """pre_fetch_docs is a no-op when the request names no active docs."""
    from unittest.mock import MagicMock

    from application.api.answer.services.stream_processor import StreamProcessor

    processor = StreamProcessor(
        {"question": "Hi there"},
        {"sub": "user_123"},
    )
    processor.initialize()
    processor.create_retriever = MagicMock()
    docs_together, docs = processor.pre_fetch_docs("Hi there")
    # No retriever should even be constructed when there is nothing to fetch.
    processor.create_retriever.assert_not_called()
    assert docs_together is None
    assert docs is None
def test_prefetch_skipped_when_active_docs_is_default(self, mock_mongo_db):
    """pre_fetch_docs is also skipped when active_docs is the "default" sentinel."""
    from unittest.mock import MagicMock

    from application.api.answer.services.stream_processor import StreamProcessor

    processor = StreamProcessor(
        {"question": "Hi", "active_docs": "default"},
        {"sub": "user_123"},
    )
    processor.initialize()
    processor.create_retriever = MagicMock()
    docs_together, docs = processor.pre_fetch_docs("Hi")
    # "default" means no concrete source, so no retrieval happens.
    processor.create_retriever.assert_not_called()
    assert docs_together is None
    assert docs is None
def test_agent_retriever_and_chunks_propagate_to_retriever_config(self, mock_mongo_db):
    """Agent-level retriever/chunks settings flow into retriever_config."""
    from application.api.answer.services.stream_processor import StreamProcessor
    from application.core.settings import settings

    db = mock_mongo_db[settings.MONGO_DB_NAME]
    agents_collection = db["agents"]
    source_id = ObjectId()
    db["sources"].insert_one(
        {"_id": source_id, "name": "src", "retriever": "hybrid", "chunks": 5}
    )
    agent_id = ObjectId()
    agents_collection.insert_one(
        {
            "_id": agent_id,
            "key": "agent_ret_key",
            "user": "user_123",
            "prompt_id": "default",
            "agent_type": "classic",
            "retriever": "hybrid",
            "chunks": 5,
            # Source is stored as a DBRef, mirroring production documents.
            "source": DBRef("sources", source_id),
        }
    )
    processor = StreamProcessor(
        {"question": "Test", "api_key": "agent_ret_key"},
        None,
    )
    processor.initialize()
    assert processor.retriever_config["retriever_name"] == "hybrid"
    assert processor.retriever_config["chunks"] == 5
def test_request_retriever_and_chunks_override_agent_config(self, mock_mongo_db):
    """Per-request retriever/chunks values take precedence over agent defaults."""
    from application.api.answer.services.stream_processor import StreamProcessor
    from application.core.settings import settings

    db = mock_mongo_db[settings.MONGO_DB_NAME]
    agents_collection = db["agents"]
    agent_id = ObjectId()
    agents_collection.insert_one(
        {
            "_id": agent_id,
            "key": "agent_override_key",
            "user": "user_123",
            "prompt_id": "default",
            "agent_type": "classic",
            "retriever": "hybrid",
            "chunks": 5,
        }
    )
    processor = StreamProcessor(
        {
            "question": "Test",
            "api_key": "agent_override_key",
            # Request explicitly overrides the agent's hybrid/5 settings.
            "retriever": "classic",
            "chunks": 7,
        },
        None,
    )
    processor.initialize()
    assert processor.retriever_config["retriever_name"] == "classic"
    assert processor.retriever_config["chunks"] == 7
def test_agent_data_fetched_once_per_request(self, mock_mongo_db):
    """_get_data_from_api_key must be called exactly once during initialize()."""
    from unittest.mock import patch

    from application.api.answer.services.stream_processor import StreamProcessor
    from application.core.settings import settings

    db = mock_mongo_db[settings.MONGO_DB_NAME]
    agents_collection = db["agents"]
    agent_id = ObjectId()
    agents_collection.insert_one(
        {
            "_id": agent_id,
            "key": "agent_once_key",
            "user": "user_123",
            "prompt_id": "default",
            "agent_type": "classic",
        }
    )
    processor = StreamProcessor(
        {"question": "Test", "api_key": "agent_once_key"},
        None,
    )
    # Wrap the real method so behavior is unchanged but calls are counted.
    with patch.object(
        processor, "_get_data_from_api_key", wraps=processor._get_data_from_api_key
    ) as spy:
        processor.initialize()
    assert spy.call_count == 1
@pytest.mark.unit
class TestStreamProcessorAttachments:

View File

@@ -566,14 +566,13 @@ class TestConfigureSource:
decoded_token={"sub": "u"},
)
sp.agent_key = None
agent_data = {
sp._agent_data = {
"sources": [
{"id": "src1", "retriever": "classic"},
{"id": "src2", "retriever": "hybrid"},
],
"source": None,
}
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
sp._configure_source()
assert sp.source == {"active_docs": ["src1", "src2"]}
assert len(sp.all_sources) == 2
@@ -593,12 +592,11 @@ class TestConfigureSource:
decoded_token={"sub": "u"},
)
sp.agent_key = None
agent_data = {
sp._agent_data = {
"sources": [],
"source": "single_src",
"retriever": "classic",
}
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
sp._configure_source()
assert sp.source == {"active_docs": "single_src"}
assert len(sp.all_sources) == 1
@@ -618,8 +616,7 @@ class TestConfigureSource:
decoded_token={"sub": "u"},
)
sp.agent_key = None
agent_data = {"sources": [], "source": None}
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
sp._agent_data = {"sources": [], "source": None}
sp._configure_source()
assert sp.source == {}
assert sp.all_sources == []
@@ -639,11 +636,10 @@ class TestConfigureSource:
decoded_token={"sub": "u"},
)
sp.agent_key = "agent_key_123"
agent_data = {
sp._agent_data = {
"sources": [{"id": "s1", "retriever": "classic"}],
"source": None,
}
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
sp._configure_source()
assert sp.source == {"active_docs": ["s1"]}
@@ -662,11 +658,10 @@ class TestConfigureSource:
decoded_token={"sub": "u"},
)
sp.agent_key = None
agent_data = {
sp._agent_data = {
"sources": [{"id": None}, {"retriever": "classic"}],
"source": None,
}
sp._get_data_from_api_key = MagicMock(return_value=agent_data)
sp._configure_source()
assert sp.source == {}
@@ -1189,6 +1184,8 @@ class TestConfigureAgent:
"chunks": "5",
})
sp._configure_agent()
sp.model_id = "test-model"
sp._configure_retriever()
assert sp.agent_config["workflow"] == "wf_123"
assert sp.agent_config["workflow_owner"] == "user1"
assert sp.retriever_config["retriever_name"] == "hybrid"
@@ -1211,6 +1208,8 @@ class TestConfigureAgent:
"chunks": "not_a_number",
})
sp._configure_agent()
sp.model_id = "test-model"
sp._configure_retriever()
assert sp.retriever_config["chunks"] == 2
@@ -1763,8 +1762,8 @@ class TestConfigureAgentAdditionalPaths:
assert sp.decoded_token == {"sub": "owner_user"}
@pytest.mark.unit
def test_configure_agent_with_source_in_data_key(self):
"""Cover line 463-464: data_key has 'source' set."""
def test_configure_source_picks_up_cached_agent_data(self):
"""After _configure_agent caches _agent_data, _configure_source uses it."""
sp = self._make_sp()
sp._resolve_agent_id = MagicMock(return_value="agent_id_1")
sp._get_agent_key = MagicMock(return_value=("agent_key", False, None))
@@ -1780,6 +1779,7 @@ class TestConfigureAgentAdditionalPaths:
"source": "my_source",
})
sp._configure_agent()
sp._configure_source()
assert sp.source == {"active_docs": "my_source"}
@pytest.mark.unit
@@ -2067,7 +2067,7 @@ class TestPreFetchDocsFullPaths:
"chunks": 2,
"doc_token_limit": 50000,
}
sp.source = {}
sp.source = {"active_docs": ["src1"]}
sp.model_id = "test-model"
sp.agent_id = None
return sp

View File

@@ -48,7 +48,7 @@ def internal_app(monkeypatch, mock_mongo_db):
@pytest.mark.unit
class TestVerifyInternalKey:
def test_no_internal_key_configured_allows_access(
def test_no_internal_key_configured_rejects_access(
self, internal_app, monkeypatch
):
app, db = internal_app
@@ -63,9 +63,8 @@ class TestVerifyInternalKey:
),
)
with app.test_client() as client:
# download will fail for missing file but should not be 401
resp = client.get("/api/download?user=u&name=n&file=f")
assert resp.status_code != 401
assert resp.status_code == 401
def test_missing_key_returns_401(self, internal_app, monkeypatch):
app, db = internal_app
@@ -131,9 +130,12 @@ class TestVerifyInternalKey:
@pytest.mark.unit
class TestUploadIndex:
_TEST_INTERNAL_KEY = "test-internal-key"
_AUTH_HEADERS = {"X-Internal-Key": "test-internal-key"}
def _make_settings(self, vector_store="faiss"):
return MagicMock(
INTERNAL_KEY=None,
INTERNAL_KEY=self._TEST_INTERNAL_KEY,
UPLOAD_FOLDER="uploads",
VECTOR_STORE=vector_store,
EMBEDDINGS_NAME="test_embeddings",
@@ -146,7 +148,7 @@ class TestUploadIndex:
"application.api.internal.routes.settings", self._make_settings()
)
with app.test_client() as client:
resp = client.post("/api/upload_index", data={})
resp = client.post("/api/upload_index", data={}, headers=self._AUTH_HEADERS)
assert resp.json["status"] == "no user"
def test_missing_name_returns_no_name(self, internal_app, monkeypatch):
@@ -155,7 +157,7 @@ class TestUploadIndex:
"application.api.internal.routes.settings", self._make_settings()
)
with app.test_client() as client:
resp = client.post("/api/upload_index", data={"user": "testuser"})
resp = client.post("/api/upload_index", data={"user": "testuser"}, headers=self._AUTH_HEADERS)
assert resp.json["status"] == "no name"
def test_creates_new_source_entry(self, internal_app, monkeypatch):
@@ -182,6 +184,7 @@ class TestUploadIndex:
"id": doc_id,
"type": "local",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -219,6 +222,7 @@ class TestUploadIndex:
"id": str(doc_id),
"type": "remote",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -252,6 +256,7 @@ class TestUploadIndex:
"type": "local",
"directory_structure": json.dumps(dir_struct),
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -285,6 +290,7 @@ class TestUploadIndex:
"type": "local",
"directory_structure": "not valid json",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -317,6 +323,7 @@ class TestUploadIndex:
"type": "local",
"file_name_map": json.dumps(fmap),
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -349,6 +356,7 @@ class TestUploadIndex:
"id": doc_id,
"type": "local",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "no file"
@@ -379,6 +387,7 @@ class TestUploadIndex:
"type": "local",
"file_faiss": (io.BytesIO(b""), ""),
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "no file name"
@@ -408,6 +417,7 @@ class TestUploadIndex:
"remote_data": '{"url":"http://example.com"}',
"sync_frequency": "daily",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -443,6 +453,7 @@ class TestUploadIndex:
"file_pkl": (io.BytesIO(b"pkl data"), "index.pkl"),
},
content_type="multipart/form-data",
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -477,6 +488,7 @@ class TestUploadIndex:
"file_faiss": (io.BytesIO(b"faiss data"), "index.faiss"),
},
content_type="multipart/form-data",
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "no file"
@@ -508,9 +520,27 @@ class TestUploadIndex:
"file_pkl": (io.BytesIO(b""), ""),
},
content_type="multipart/form-data",
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "no file name"
def test_no_internal_key_rejects_upload(self, internal_app, monkeypatch):
"""Verify that upload_index is rejected when INTERNAL_KEY is not set."""
app, db = internal_app
monkeypatch.setattr(
"application.api.internal.routes.settings",
MagicMock(
INTERNAL_KEY=None,
UPLOAD_FOLDER="uploads",
VECTOR_STORE="faiss",
EMBEDDINGS_NAME="test",
MONGO_DB_NAME="docsgpt",
),
)
with app.test_client() as client:
resp = client.post("/api/upload_index", data={"user": "attacker"})
assert resp.status_code == 401
def test_update_existing_with_file_name_map(self, internal_app, monkeypatch):
"""Cover line 124: update existing entry with file_name_map."""
app, db = internal_app
@@ -540,6 +570,7 @@ class TestUploadIndex:
"type": "local",
"file_name_map": json.dumps(fmap),
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"
@@ -572,6 +603,7 @@ class TestUploadIndex:
"type": "local",
"file_name_map": "not valid json{{{",
},
headers=self._AUTH_HEADERS,
)
assert resp.json["status"] == "ok"

View File

@@ -14,6 +14,13 @@ def app():
return app
@pytest.fixture(autouse=True)
def _bypass_url_validation():
"""Bypass SSRF URL validation so tests using localhost URLs can proceed."""
with patch("application.api.user.tools.mcp.validate_url"):
yield
# ---------------------------------------------------------------------------
# Helper: _sanitize_mcp_transport
# ---------------------------------------------------------------------------

View File

@@ -395,6 +395,9 @@ class TestAvailableTools:
"application.api.user.tools.routes.tool_manager", mock_manager
):
with app.test_request_context("/api/available_tools"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AvailableTools().get()
assert response.status_code == 200
@@ -419,6 +422,9 @@ class TestAvailableTools:
"application.api.user.tools.routes.tool_manager", mock_manager
):
with app.test_request_context("/api/available_tools"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AvailableTools().get()
assert response.status_code == 400
@@ -438,6 +444,9 @@ class TestAvailableTools:
"application.api.user.tools.routes.tool_manager", mock_manager
):
with app.test_request_context("/api/available_tools"):
from flask import request
request.decoded_token = {"sub": "user1"}
response = AvailableTools().get()
assert response.status_code == 200

0
tests/api/v1/__init__.py Normal file
View File

View File

@@ -0,0 +1,64 @@
from flask import Flask
from application.api.v1.routes import v1_bp
class _FakeCollection:
def __init__(self, docs):
self.docs = docs
def find_one(self, query):
for doc in self.docs:
if all(doc.get(k) == v for k, v in query.items()):
return doc
return None
def find(self, query):
return [doc for doc in self.docs if all(doc.get(k) == v for k, v in query.items())]
def _build_app():
app = Flask(__name__)
app.register_blueprint(v1_bp)
return app
def test_v1_models_does_not_expose_agent_keys(monkeypatch):
docs = [
{"_id": "agent-1", "key": "key-1", "user": "user-1", "name": "Agent One"},
{"_id": "agent-2", "key": "key-2", "user": "user-1", "name": "Agent Two"},
]
fake_mongo = {"testdb": {"agents": _FakeCollection(docs)}}
monkeypatch.setattr("application.api.v1.routes.MongoDB.get_client", lambda: fake_mongo)
monkeypatch.setattr("application.api.v1.routes.settings.MONGO_DB_NAME", "testdb")
app = _build_app()
client = app.test_client()
response = client.get("/v1/models", headers={"Authorization": "Bearer key-1"})
assert response.status_code == 200
payload = response.get_json()
assert payload["object"] == "list"
assert len(payload["data"]) == 2
assert payload["data"][0]["id"] == "agent-1"
assert payload["data"][1]["id"] == "agent-2"
# Keys must never appear as model IDs
assert all(model["id"] != "key-1" for model in payload["data"])
assert all(model["id"] != "key-2" for model in payload["data"])
def test_v1_models_invalid_key_returns_401(monkeypatch):
docs = [
{"_id": "agent-1", "key": "key-1", "user": "user-1", "name": "Agent One"},
]
fake_mongo = {"testdb": {"agents": _FakeCollection(docs)}}
monkeypatch.setattr("application.api.v1.routes.MongoDB.get_client", lambda: fake_mongo)
monkeypatch.setattr("application.api.v1.routes.settings.MONGO_DB_NAME", "testdb")
app = _build_app()
client = app.test_client()
response = client.get("/v1/models", headers={"Authorization": "Bearer wrong-key"})
assert response.status_code == 401

View File

@@ -20,133 +20,49 @@ def test_epub_init_parser():
assert parser.parser_config_set
def test_epub_parser_ebooklib_import_error(epub_parser):
"""Test that ImportError is raised when ebooklib is not available."""
with patch.dict(sys.modules, {"ebooklib": None}):
with pytest.raises(ValueError, match="`EbookLib` is required to read Epub files"):
def test_epub_parser_fast_ebook_import_error(epub_parser):
"""Test that ImportError is raised when fast-ebook is not available."""
with patch.dict(sys.modules, {"fast_ebook": None}):
with pytest.raises(ValueError, match="`fast-ebook` is required to read Epub files"):
epub_parser.parse_file(Path("test.epub"))
def test_epub_parser_html2text_import_error(epub_parser):
"""Test that ImportError is raised when html2text is not available."""
fake_ebooklib = types.ModuleType("ebooklib")
fake_epub = types.ModuleType("ebooklib.epub")
fake_ebooklib.epub = fake_epub
with patch.dict(sys.modules, {"ebooklib": fake_ebooklib, "ebooklib.epub": fake_epub}):
with patch.dict(sys.modules, {"html2text": None}):
with pytest.raises(ValueError, match="`html2text` is required to parse Epub files"):
epub_parser.parse_file(Path("test.epub"))
def test_epub_parser_successful_parsing(epub_parser):
"""Test successful parsing of an epub file."""
fake_fast_ebook = types.ModuleType("fast_ebook")
fake_epub = types.ModuleType("fast_ebook.epub")
fake_fast_ebook.epub = fake_epub
fake_ebooklib = types.ModuleType("ebooklib")
fake_epub = types.ModuleType("ebooklib.epub")
fake_html2text = types.ModuleType("html2text")
# Mock ebooklib constants
fake_ebooklib.ITEM_DOCUMENT = "document"
fake_ebooklib.epub = fake_epub
mock_item1 = MagicMock()
mock_item1.get_type.return_value = "document"
mock_item1.get_content.return_value = b"<h1>Chapter 1</h1><p>Content 1</p>"
mock_item2 = MagicMock()
mock_item2.get_type.return_value = "document"
mock_item2.get_content.return_value = b"<h1>Chapter 2</h1><p>Content 2</p>"
mock_item3 = MagicMock()
mock_item3.get_type.return_value = "other" # Should be ignored
mock_item3.get_content.return_value = b"<p>Other content</p>"
mock_book = MagicMock()
mock_book.get_items.return_value = [mock_item1, mock_item2, mock_item3]
mock_book.to_markdown.return_value = "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
fake_epub.read_epub = MagicMock(return_value=mock_book)
def mock_html2text_func(html_content):
if "Chapter 1" in html_content:
return "# Chapter 1\n\nContent 1\n"
elif "Chapter 2" in html_content:
return "# Chapter 2\n\nContent 2\n"
return "Other content\n"
fake_html2text.html2text = mock_html2text_func
with patch.dict(sys.modules, {
"ebooklib": fake_ebooklib,
"ebooklib.epub": fake_epub,
"html2text": fake_html2text
"fast_ebook": fake_fast_ebook,
"fast_ebook.epub": fake_epub,
}):
result = epub_parser.parse_file(Path("test.epub"))
expected_result = "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
assert result == expected_result
# Verify epub.read_epub was called with correct parameters
fake_epub.read_epub.assert_called_once_with(Path("test.epub"), options={"ignore_ncx": True})
assert result == "# Chapter 1\n\nContent 1\n\n# Chapter 2\n\nContent 2\n"
fake_epub.read_epub.assert_called_once_with(Path("test.epub"))
def test_epub_parser_empty_book(epub_parser):
"""Test parsing an epub file with no document items."""
# Create mock modules
fake_ebooklib = types.ModuleType("ebooklib")
fake_epub = types.ModuleType("ebooklib.epub")
fake_html2text = types.ModuleType("html2text")
fake_ebooklib.ITEM_DOCUMENT = "document"
fake_ebooklib.epub = fake_epub
# Create mock book with no document items
"""Test parsing an epub file with no content."""
fake_fast_ebook = types.ModuleType("fast_ebook")
fake_epub = types.ModuleType("fast_ebook.epub")
fake_fast_ebook.epub = fake_epub
mock_book = MagicMock()
mock_book.get_items.return_value = []
mock_book.to_markdown.return_value = ""
fake_epub.read_epub = MagicMock(return_value=mock_book)
fake_html2text.html2text = MagicMock()
with patch.dict(sys.modules, {
"ebooklib": fake_ebooklib,
"ebooklib.epub": fake_epub,
"html2text": fake_html2text
"fast_ebook": fake_fast_ebook,
"fast_ebook.epub": fake_epub,
}):
result = epub_parser.parse_file(Path("empty.epub"))
assert result == ""
fake_html2text.html2text.assert_not_called()
def test_epub_parser_non_document_items_ignored(epub_parser):
"""Test that non-document items are ignored during parsing."""
fake_ebooklib = types.ModuleType("ebooklib")
fake_epub = types.ModuleType("ebooklib.epub")
fake_html2text = types.ModuleType("html2text")
fake_ebooklib.ITEM_DOCUMENT = "document"
fake_ebooklib.epub = fake_epub
mock_doc_item = MagicMock()
mock_doc_item.get_type.return_value = "document"
mock_doc_item.get_content.return_value = b"<p>Document content</p>"
mock_other_item = MagicMock()
mock_other_item.get_type.return_value = "image" # Not a document
mock_book = MagicMock()
mock_book.get_items.return_value = [mock_other_item, mock_doc_item]
fake_epub.read_epub = MagicMock(return_value=mock_book)
fake_html2text.html2text = MagicMock(return_value="Document content\n")
with patch.dict(sys.modules, {
"ebooklib": fake_ebooklib,
"ebooklib.epub": fake_epub,
"html2text": fake_html2text
}):
result = epub_parser.parse_file(Path("test.epub"))
assert result == "Document content\n"
fake_html2text.html2text.assert_called_once_with("<p>Document content</p>")

View File

@@ -8,7 +8,7 @@ from application.storage.local import LocalStorage
@pytest.fixture
def temp_base_dir():
return "/tmp/test_storage"
return os.path.realpath("/tmp/test_storage")
@pytest.fixture
@@ -30,12 +30,12 @@ class TestLocalStorageInitialization:
def test_get_full_path_with_relative_path(self, local_storage):
result = local_storage._get_full_path("documents/test.txt")
expected = os.path.join("/tmp/test_storage", "documents/test.txt")
assert os.path.normpath(result) == os.path.normpath(expected)
expected = os.path.realpath(os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt"))
assert result == expected
def test_get_full_path_with_absolute_path(self, local_storage):
result = local_storage._get_full_path("/absolute/path/test.txt")
assert result == "/absolute/path/test.txt"
def test_get_full_path_with_absolute_path_outside_base_raises(self, local_storage):
with pytest.raises(ValueError, match="Path traversal detected"):
local_storage._get_full_path("/absolute/path/test.txt")
@patch("os.makedirs")
@patch("builtins.open", new_callable=mock_open)
@@ -48,8 +48,8 @@ class TestLocalStorageInitialization:
result = local_storage.save_file(file_data, path)
expected_dir = os.path.join("/tmp/test_storage", "documents")
expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_dir = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
expected_file = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert mock_makedirs.call_count == 1
assert os.path.normpath(mock_makedirs.call_args[0][0]) == os.path.normpath(
@@ -74,25 +74,19 @@ class TestLocalStorageInitialization:
result = local_storage.save_file(file_data, path)
expected_file = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_file = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert file_data.save.call_count == 1
assert os.path.normpath(file_data.save.call_args[0][0]) == os.path.normpath(
expected_file
)
assert result == {"storage_type": "local"}
@patch("os.makedirs")
@patch("builtins.open", new_callable=mock_open)
def test_save_file_with_absolute_path(
self, mock_file, mock_makedirs, local_storage
):
def test_save_file_with_absolute_path_outside_base_raises(self, local_storage):
file_data = io.BytesIO(b"test content")
path = "/absolute/path/test.txt"
local_storage.save_file(file_data, path)
mock_makedirs.assert_called_once_with("/absolute/path", exist_ok=True)
mock_file.assert_called_once_with("/absolute/path/test.txt", "wb")
with pytest.raises(ValueError, match="Path traversal detected"):
local_storage.save_file(file_data, path)
@pytest.mark.unit
@@ -105,7 +99,7 @@ class TestLocalStorageGetFile:
result = local_storage.get_file(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
@@ -122,7 +116,7 @@ class TestLocalStorageGetFile:
with pytest.raises(FileNotFoundError, match="File not found"):
local_storage.get_file(path)
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
expected_path
@@ -141,7 +135,7 @@ class TestLocalStorageDeleteFile:
result = local_storage.delete_file(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert result is True
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -158,7 +152,7 @@ class TestLocalStorageDeleteFile:
result = local_storage.delete_file(path)
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
assert result is False
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -175,7 +169,7 @@ class TestLocalStorageFileExists:
result = local_storage.file_exists(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert result is True
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -188,7 +182,7 @@ class TestLocalStorageFileExists:
result = local_storage.file_exists(path)
expected_path = os.path.join("/tmp/test_storage", "documents/nonexistent.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/nonexistent.txt")
assert result is False
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -205,7 +199,7 @@ class TestLocalStorageListFiles:
self, mock_exists, mock_walk, local_storage
):
directory = "documents"
base_dir = os.path.join("/tmp/test_storage", "documents")
base_dir = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
mock_walk.return_value = [
(base_dir, ["subdir"], ["file1.txt", "file2.txt"]),
@@ -228,7 +222,7 @@ class TestLocalStorageListFiles:
result = local_storage.list_files(directory)
expected_path = os.path.join("/tmp/test_storage", "nonexistent")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "nonexistent")
assert result == []
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -248,7 +242,7 @@ class TestLocalStorageProcessFile:
result = local_storage.process_file(path, processor_func, extra_arg="value")
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert result == "processed"
assert processor_func.call_count == 1
call_kwargs = processor_func.call_args[1]
@@ -280,7 +274,7 @@ class TestLocalStorageIsDirectory:
result = local_storage.is_directory(path)
expected_path = os.path.join("/tmp/test_storage", "documents")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
assert result is True
assert mock_isdir.call_count == 1
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
@@ -295,7 +289,7 @@ class TestLocalStorageIsDirectory:
result = local_storage.is_directory(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert result is False
assert mock_isdir.call_count == 1
assert os.path.normpath(mock_isdir.call_args[0][0]) == os.path.normpath(
@@ -316,7 +310,7 @@ class TestLocalStorageRemoveDirectory:
result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "documents")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
assert result is True
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -339,7 +333,7 @@ class TestLocalStorageRemoveDirectory:
result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "nonexistent")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "nonexistent")
assert result is False
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -355,7 +349,7 @@ class TestLocalStorageRemoveDirectory:
result = local_storage.remove_directory(path)
expected_path = os.path.join("/tmp/test_storage", "documents/test.txt")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents/test.txt")
assert result is False
assert mock_exists.call_count == 1
assert os.path.normpath(mock_exists.call_args[0][0]) == os.path.normpath(
@@ -376,7 +370,7 @@ class TestLocalStorageRemoveDirectory:
result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "documents")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
assert result is False
assert mock_rmtree.call_count == 1
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(
@@ -393,7 +387,7 @@ class TestLocalStorageRemoveDirectory:
result = local_storage.remove_directory(directory)
expected_path = os.path.join("/tmp/test_storage", "documents")
expected_path = os.path.join(os.path.realpath("/tmp/test_storage"), "documents")
assert result is False
assert mock_rmtree.call_count == 1
assert os.path.normpath(mock_rmtree.call_args[0][0]) == os.path.normpath(

View File

@@ -2105,23 +2105,35 @@ class TestInternalRoutes:
app.register_blueprint(internal)
return app
_TEST_KEY = "test-key"
_AUTH_HEADERS = {"X-Internal-Key": "test-key"}
def test_upload_index_no_user(self, app):
with app.test_client() as client:
with patch(
"application.api.internal.routes.settings"
) as ms:
ms.INTERNAL_KEY = None
resp = client.post("/api/upload_index")
ms.INTERNAL_KEY = self._TEST_KEY
resp = client.post("/api/upload_index", headers=self._AUTH_HEADERS)
assert resp.get_json()["status"] == "no user"
def test_upload_index_no_name(self, app):
with app.test_client() as client:
with patch(
"application.api.internal.routes.settings"
) as ms:
ms.INTERNAL_KEY = self._TEST_KEY
resp = client.post("/api/upload_index", data={"user": "u1"}, headers=self._AUTH_HEADERS)
assert resp.get_json()["status"] == "no name"
def test_upload_index_rejected_without_internal_key(self, app):
with app.test_client() as client:
with patch(
"application.api.internal.routes.settings"
) as ms:
ms.INTERNAL_KEY = None
resp = client.post("/api/upload_index", data={"user": "u1"})
assert resp.get_json()["status"] == "no name"
assert resp.status_code == 401
# ---------------------------------------------------------------------------

View File

@@ -11,6 +11,7 @@ import pytest
from application.api.v1.translator import (
_get_client_tool_name,
convert_history,
extract_system_prompt,
extract_tool_results,
is_continuation,
translate_request,
@@ -148,6 +149,48 @@ class TestConvertHistory:
assert history == []
# ---------------------------------------------------------------------------
# extract_system_prompt
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestExtractSystemPrompt:
def test_extracts_first_system_message(self):
messages = [
{"role": "system", "content": "You are a pirate"},
{"role": "user", "content": "Hello"},
]
assert extract_system_prompt(messages) == "You are a pirate"
def test_returns_none_when_no_system_message(self):
messages = [{"role": "user", "content": "Hello"}]
assert extract_system_prompt(messages) is None
def test_returns_first_of_multiple_system_messages(self):
messages = [
{"role": "system", "content": "First"},
{"role": "system", "content": "Second"},
{"role": "user", "content": "Hello"},
]
assert extract_system_prompt(messages) == "First"
def test_empty_content_returns_empty_string(self):
messages = [
{"role": "system", "content": ""},
{"role": "user", "content": "Hello"},
]
assert extract_system_prompt(messages) == ""
def test_missing_content_returns_empty_string(self):
messages = [
{"role": "system"},
{"role": "user", "content": "Hello"},
]
assert extract_system_prompt(messages) == ""
# ---------------------------------------------------------------------------
# translate_request
# ---------------------------------------------------------------------------
@@ -167,11 +210,25 @@ class TestTranslateRequest:
result = translate_request(data, "test-key")
assert result["question"] == "What's 2+2?"
assert result["api_key"] == "test-key"
assert result["save_conversation"] is True
# Conversations are not persisted by default on the v1 endpoint.
assert result["save_conversation"] is False
history = json.loads(result["history"])
assert len(history) == 1
assert history[0]["prompt"] == "Hello"
def test_save_conversation_opt_in_via_docsgpt_extension(self):
data = {
"messages": [{"role": "user", "content": "Hi"}],
"docsgpt": {"save_conversation": True},
}
result = translate_request(data, "key")
assert result["save_conversation"] is True
def test_save_conversation_default_false(self):
data = {"messages": [{"role": "user", "content": "Hi"}]}
result = translate_request(data, "key")
assert result["save_conversation"] is False
def test_continuation_request(self):
data = {
"messages": [
@@ -237,6 +294,23 @@ class TestTranslateRequest:
result = translate_request(data, "key")
assert result["attachments"] == ["att1", "att2"]
def test_system_prompt_override_included_when_present(self):
data = {
"messages": [
{"role": "system", "content": "Custom prompt"},
{"role": "user", "content": "Hello"},
],
}
result = translate_request(data, "key")
assert result["system_prompt_override"] == "Custom prompt"
def test_system_prompt_override_absent_when_no_system_message(self):
data = {
"messages": [{"role": "user", "content": "Hello"}],
}
result = translate_request(data, "key")
assert "system_prompt_override" not in result
# ---------------------------------------------------------------------------
# translate_response

View File

@@ -363,6 +363,28 @@ class TestGetVectorstore:
assert get_vectorstore("user/source123") == "indexes/user/source123"
@pytest.mark.parametrize(
"malicious_path",
[
"../outside",
"../../etc/passwd",
"nested/../../../outside",
"/tmp/evil",
"..\\outside",
"valid/../../escape",
],
)
def test_rejects_path_traversal(self, malicious_path):
from application.vectorstore.faiss import get_vectorstore
with pytest.raises(ValueError, match="Invalid source_id path"):
get_vectorstore(malicious_path)
def test_allows_mongodb_style_ids(self):
from application.vectorstore.faiss import get_vectorstore
assert get_vectorstore("65e8f6a8a7a96b1bdad4154f") == "indexes/65e8f6a8a7a96b1bdad4154f"
@pytest.mark.unit
class TestFaissStoreAddChunk: